Automatically Deriving Typeclass Instances in Scala 3

13 May, 2021

We are going to carry on with the examination of typeclasses in Scala 3. Last time we covered the nuts and bolts of what typeclasses are, why they’re good, and how they work in Scala 3.

Revisiting enum

Let’s look at our BinaryTree class from the enum article:

enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf
scala> BinaryTree.Node(5,
     |                 BinaryTree.Leaf,
     |                 BinaryTree.Node(10,
     |                                 BinaryTree.Leaf,
     |                                 BinaryTree.Leaf
     |                                )
     |                )
val res0: BinaryTree[Int] = Node(5,Leaf,Node(10,Leaf,Leaf))

We can create a typeclass for equality:

trait Eq[A]:
  def eq(x: A, y: A): Boolean

This is the same concept as in the previous article, so hopefully implementing this shouldn’t be too much of a surprise:

given[A](using eqA: Eq[A]): Eq[BinaryTree[A]] with
  def eq(x: BinaryTree[A], y: BinaryTree[A]): Boolean = {
    (x, y) match {
      case (BinaryTree.Leaf, BinaryTree.Leaf) => true
      case (BinaryTree.Node(xv, xl, xr), BinaryTree.Node(yv, yl, yr)) => eqA.eq(xv, yv) && eq(xl, yl) && eq(xr, yr)
      case _ => false
    }
  }

There are a few things worth pointing out here:

  • The new using keyword asks the compiler for the Eq instance for A: if we can’t measure the equality of A, then we can’t measure the equality of BinaryTree[A].
  • We are pattern matching on the enum: This works as it did in Scala 2.
  • We are recursively checking that the sub-trees are equal.

We’ve got Eq implemented for our binary tree: going forward, anyone can use our BinaryTree in a function requiring Eq, as long as the underlying type has an Eq instance in scope.
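To see the instance in use, here is a small self-contained sketch. It repeats the definitions above, adds a base Eq[Int] instance (assumed here; the article does not define one), and uses a hypothetical helper, containsEq, to stand in for "a function requiring Eq".

```scala
enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

trait Eq[A]:
  def eq(x: A, y: A): Boolean

// Assumed base instance: Int equality is plain ==
given Eq[Int] with
  def eq(x: Int, y: Int): Boolean = x == y

// The BinaryTree instance from above: requires an Eq for the element type
given [A](using eqA: Eq[A]): Eq[BinaryTree[A]] with
  def eq(x: BinaryTree[A], y: BinaryTree[A]): Boolean =
    (x, y) match
      case (BinaryTree.Leaf, BinaryTree.Leaf) => true
      case (BinaryTree.Node(xv, xl, xr), BinaryTree.Node(yv, yl, yr)) =>
        eqA.eq(xv, yv) && eq(xl, yl) && eq(xr, yr)
      case _ => false

// Hypothetical helper: works for any A that has an Eq instance in scope
def containsEq[A](xs: List[A], target: A)(using e: Eq[A]): Boolean =
  xs.exists(x => e.eq(x, target))

val small: BinaryTree[Int] = BinaryTree.Node(1, BinaryTree.Leaf, BinaryTree.Leaf)
val big: BinaryTree[Int] = BinaryTree.Node(5, BinaryTree.Leaf, small)
```

Because containsEq only asks for an Eq[A], calling it with a List[BinaryTree[Int]] makes the compiler build Eq[BinaryTree[Int]] from the two givens.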

Let’s turn our attention to implementing Show for BinaryTree[A]:

trait Show[A]:
  def show(a: A): String

Show will return a String representation of a value of type A.

This should seem very similar to the Eq approach:

  • We will need the compiler to provide us with a Show[A] instance.
  • Base case: work out what to do for Leaf: print the empty string.
  • Node case: use Show[A] for the value, then append whatever show returns for the left and right subtrees.
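Written out by hand, those steps might look like the sketch below. The Show[String] instance that appends "!" is an assumption made here so each shown value is easy to spot; any Show[String] would do.

```scala
enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

trait Show[A]:
  def show(a: A): String

// Assumed base instance: append "!" to each string
given Show[String] with
  def show(s: String): String = s"$s!"

given [A](using showA: Show[A]): Show[BinaryTree[A]] with
  def show(t: BinaryTree[A]): String = t match
    case BinaryTree.Leaf          => ""                                   // base case
    case BinaryTree.Node(v, l, r) => showA.show(v) + show(l) + show(r)    // value, then both subtrees
```

Apart from the recursion on the subtrees, this has exactly the same shape as the hand-written Eq instance.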

Conceptually, this is the same as our Eq implementation. Considering we’re doing the same thing,
wouldn’t it be great if this could be automatically written for us?

Automatic Typeclass Derivation

Thankfully, we can tell Scala 3 to do it for us: we break a type down into its constituent parts and show those instead.

The Mirror typeclass

The Mirror typeclass gives us information about the type itself, the names of the fields, the types of the fields, and how it all fits together. The compiler is able to provide this information for:

  • Case classes
  • Case objects
  • Sealed hierarchies containing case objects and case classes
  • Enums
  • Enum cases

For now, the most important part of the Mirror is the breakdown between Sum and Product.
Our BinaryTree is an Algebraic Data Type (ADT) built from a Sum of two variants: Node and Leaf.
Node is a Product, built from a combination of an A and two more BinaryTree instances.
It may be helpful to think of a Product as a tuple containing the values that type would hold,
plus some metadata about the field names and class name.
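A quick way to get a feel for the two views is to summon the mirrors directly. In the sketch below (the names treeMirror, nodeMirror and rebuilt are only for illustration), ordinal from Mirror.SumOf tells us which variant a value is, and fromProduct from Mirror.ProductOf rebuilds a Node from a plain tuple.

```scala
import scala.deriving.Mirror

enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

// Sum view: which variant is a given value?
val treeMirror: Mirror.SumOf[BinaryTree[Int]] = summon[Mirror.SumOf[BinaryTree[Int]]]

// Product view: Node[Int] is built from (Int, BinaryTree[Int], BinaryTree[Int])
val nodeMirror: Mirror.ProductOf[BinaryTree.Node[Int]] = summon[Mirror.ProductOf[BinaryTree.Node[Int]]]

val tree: BinaryTree[Int] = BinaryTree.Node(5, BinaryTree.Leaf, BinaryTree.Leaf)

// Variants are numbered in declaration order: Node is 0, Leaf is 1
val nodeIndex: Int = treeMirror.ordinal(tree)
val leafIndex: Int = treeMirror.ordinal(BinaryTree.Leaf)

// A tuple with the right shape can be turned back into a Node
val rebuilt: BinaryTree.Node[Int] = nodeMirror.fromProduct((10, BinaryTree.Leaf, BinaryTree.Leaf))
```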

We will be using these imports for this article:

import scala.deriving.*
import scala.compiletime.{erasedValue, summonInline}

In order for Scala to build Show typeclass instances for us, we need to tell the compiler how to build Show instances for sums and for products.

Showing products

We can assume that there is a Show instance available for all of the parts of the product.
A possible definition for Showing a product could look like this:

def showProduct[T](shows: => List[Show[_]]): Show[T] =                    // (1)
  new Show[T]:                                                            // (2)
    def show(t: T): String = {
      (t.asInstanceOf[Product].productIterator).zip(shows.iterator).map { // (3)
        case (p, s) => s.asInstanceOf[Show[Any]].show(p)                  // (4)
      }.mkString
    }

There’s quite a lot going on here!

  1. For this function, we are assuming (hoping?) the caller gives us a list of Show instances in the same order as the fields in the product definition. So for case class Prod[A](a: A, s: String, i: Int), this list would contain, in this order, List(Show[A], Show[String], Show[Int]). It should be clear why this is defined as Show[_]: the type parameter differs from one element of the list to the next. The Show instances can come from derivation (for any of the kinds of type listed for Mirror above) or from an explicit given.
  2. This function creates a new Show instance every time it is called. While this is extremely useful, if you auto-derive instances for large trees many times, compilation times will increase noticeably. You may find it useful to store an auto-derived instance in your code and refer to that instead; this way, the auto-derivation code will not be re-run every time for the same underlying type.
  3. We pair each field value of the product (say, of type X) with its Show[X].
  4. We show each value and concatenate everything into a single string. It is important to note that this is going to be the string format for every auto-derived type. If you want to do anything fancy for a particular type (e.g., printing a binary tree as a nice tree-like structure), you’ll be better off writing a dedicated Show instance for it; this code is effectively the default show format for all types that use it. We also need to cast each Show[_] to Show[Any] before calling show. This is fine: we don’t have any more type information at this point, and the compiler has already made sure each instance matches the type of its field before we get here.
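To make the ordering assumption concrete, here is a sketch that calls showProduct directly with a hand-built list for a hypothetical case class Prod(i: Int, s: String). The Show[Int] and Show[String] instances are assumptions for the example.

```scala
trait Show[A]:
  def show(a: A): String

def showProduct[T](shows: => List[Show[_]]): Show[T] =
  new Show[T]:
    def show(t: T): String =
      // Pair each field with the instance at the same position, show it, and concatenate
      t.asInstanceOf[Product].productIterator.zip(shows.iterator).map {
        case (p, s) => s.asInstanceOf[Show[Any]].show(p)
      }.mkString

// Assumed base instances
given Show[Int] with
  def show(i: Int): String = i.toString

given Show[String] with
  def show(s: String): String = s"$s!"

// Hypothetical product type for this sketch
case class Prod(i: Int, s: String)

// The list follows the field order of Prod: first the Int, then the String
val prodShow: Show[Prod] = showProduct[Prod](List(summon[Show[Int]], summon[Show[String]]))
```

Swapping the two instances in the list would still compile, but show would then fail at runtime, because each instance would be handed a value of the wrong type.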

Showing sums

Some of the sum types are going to be made up from products. Again, we assume Show instances are available for all of the different sum variants, including auto-deriving products if necessary.

Our Show for sums could look like this:

def showSum[T](s: Mirror.SumOf[T], shows: => List[Show[_]]): Show[T] = // (1)
  new Show[T]:
    def show(t: T): String = {
      val index = s.ordinal(t)                                         // (2)
      shows(index).asInstanceOf[Show[Any]].show(t)                     // (3)
    }

  1. This function makes use of the SumOf type within Mirror, which gives us information about the structure of the different sum variants, including which variant (its ordinal) a given value belongs to.
    Note that for our product, we didn’t need a reference to any Mirror. There is a Mirror.ProductOf[T] that we could use for more complicated typeclasses: Mirror.ProductOf[T] can create an instance of T from a Product (through its fromProduct method), so we could, for instance, change the values in a product and construct a new T from them.
  2. Assuming again that the List[Show[_]] is in the same order as the sum variant definitions, s.ordinal(t) gives us the position of t’s variant, and we grab the Show instance at that position in the list.
  3. As with the product, we cast, we call show, and we’re done.
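showSum can also be exercised by hand. In this sketch, Shape, circleShow and squareShow are hypothetical. The hand-written instances are listed in the same order as the cases of Shape, which is the order s.ordinal relies on.

```scala
import scala.deriving.Mirror

trait Show[A]:
  def show(a: A): String

def showSum[T](s: Mirror.SumOf[T], shows: => List[Show[_]]): Show[T] =
  new Show[T]:
    def show(t: T): String =
      val index = s.ordinal(t)                      // which variant is t?
      shows(index).asInstanceOf[Show[Any]].show(t)  // delegate to that variant's instance

// Hypothetical sum type for this sketch
enum Shape:
  case Circle(radius: Int)
  case Square(side: Int)

// Hand-written instances, one per variant
val circleShow: Show[Shape.Circle] = new Show[Shape.Circle]:
  def show(c: Shape.Circle): String = s"circle of radius ${c.radius}"

val squareShow: Show[Shape.Square] = new Show[Shape.Square]:
  def show(sq: Shape.Square): String = s"square of side ${sq.side}"

// Declaration order: Circle has ordinal 0, Square has ordinal 1
val shapeShow: Show[Shape] =
  showSum[Shape](summon[Mirror.SumOf[Shape]], List(circleShow, squareShow))
```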

Putting it all together

Now that we know how to derive sums and products, we need to put it together and tell the compiler how to use this automatically.

If the typeclass’s companion object contains a member named derived, then the compiler will use it to try to derive typeclass instances for us. Ours is an inline given with this signature:

inline given derived[T](using m: Mirror.Of[T]): Show[T]

First, we need a way to summon the Show instances for all of the sum variants, or for all of the product fields, of the type we want to derive for:

import scala.compiletime.{erasedValue, summonInline}

inline def summonAll[T <: Tuple]: List[Show[_]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts) => summonInline[Show[t]] :: summonAll[ts]

summonAll summons the instances in a list, respecting the order of the fields as we have assumed for our showSum and showProduct functions.
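Here is summonAll on its own, applied to an explicit tuple type, Int *: String *: Int *: EmptyTuple, which has the same shape as a MirroredElemTypes. The base instances in the Show companion (Show[Int] via toString, Show[String] appending "!") are assumptions for this sketch.

```scala
import scala.compiletime.{erasedValue, summonInline}

trait Show[A]:
  def show(a: A): String

object Show:
  // Assumed base instances, kept in the companion so they are always in implicit scope
  given Show[Int] with
    def show(i: Int): String = i.toString
  given Show[String] with
    def show(s: String): String = s"$s!"

inline def summonAll[T <: Tuple]: List[Show[_]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts)  => summonInline[Show[t]] :: summonAll[ts]

// One instance per element type, in the same order as the tuple
val shows: List[Show[_]] = summonAll[Int *: String *: Int *: EmptyTuple]

val rendered: List[String] = List(
  shows(0).asInstanceOf[Show[Int]].show(42),
  shows(1).asInstanceOf[Show[String]].show("hi"),
  shows(2).asInstanceOf[Show[Int]].show(7)
)
```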

We can implement the derived function: we have all of our needed Show instances, and we know how to derive sums and products. We will put the derived function, along with any "base" or default typeclass instances, in the companion object:

object Show:
  inline given derived[T](using m: Mirror.Of[T]): Show[T] =
    lazy val shows = summonAll[m.MirroredElemTypes]
    inline m match
      case s: Mirror.SumOf[T] => showSum(s, shows)
      case _: Mirror.ProductOf[T] => showProduct(shows)

The lazy val means that recursive definitions are not computed until they are needed; computing them eagerly would produce a StackOverflowError when trying to derive all of the cases of a recursive ADT like BinaryTree.
This is also why the List[Show[_]] parameters to showProduct and showSum are call-by-name (the => in their definitions): the instances are only evaluated once show is actually called.

That’s it! Now we can automatically derive a Show instance for any type the compiler can provide a Mirror for, as long as Show instances are available for its parts. We need to make a small adjustment to our type to say that it derives Show:

enum BinaryTree[+A] derives Show:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

Now, when we drop into a console, as long as we can derive a Show[A], we can derive a Show[BinaryTree[A]]:

scala> given Show[String] with
     |   def show(s: String): String = s"$s!"
// defined object given_Show_String

scala> summon[Show[BinaryTree[String]]]
val res0: Show[BinaryTree[String]] = Show$$anon$2@5043da15

scala> res0.show(BinaryTree.Leaf)
val res1: String = ""

scala> res0.show(BinaryTree.Node("Hello", BinaryTree.Leaf, BinaryTree.Leaf))
val res2: String = Hello!

scala> res0.show(BinaryTree.Node("Hello", BinaryTree.Leaf, BinaryTree.Node("World", BinaryTree.Leaf, BinaryTree.Leaf)))
val res3: String = Hello!World!

And, of course, this works for types that have nothing to do with BinaryTree.
Using the same Show[String] instance from above:

scala> case class Record(s: String) derives Show
// defined case class Record

scala> summon[Show[Record]]
val res4: Show[Record] = Show$$anon$1@4ec6634c

scala> res4.show(Record("New record"))
val res5: String = New record!
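For reference, here are the pieces above assembled into a single file. To keep the sketch small, it derives Show only for non-recursive types: the Record case class from the session above and a hypothetical Pet enum. It also places the Show[Int] and Show[String] base instances in the companion object, which is an assumption of this sketch rather than something the session above shows.

```scala
import scala.deriving.*
import scala.compiletime.{erasedValue, summonInline}

trait Show[A]:
  def show(a: A): String

// Summon one Show per element type, in order
inline def summonAll[T <: Tuple]: List[Show[_]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts)  => summonInline[Show[t]] :: summonAll[ts]

// Products: show each field with its instance and concatenate the results
def showProduct[T](shows: => List[Show[_]]): Show[T] =
  new Show[T]:
    def show(t: T): String =
      t.asInstanceOf[Product].productIterator.zip(shows.iterator).map {
        case (p, s) => s.asInstanceOf[Show[Any]].show(p)
      }.mkString

// Sums: delegate to the instance for the value's variant
def showSum[T](s: Mirror.SumOf[T], shows: => List[Show[_]]): Show[T] =
  new Show[T]:
    def show(t: T): String =
      shows(s.ordinal(t)).asInstanceOf[Show[Any]].show(t)

object Show:
  // Assumed base instances
  given Show[String] with
    def show(s: String): String = s"$s!"

  given Show[Int] with
    def show(i: Int): String = i.toString

  inline given derived[T](using m: Mirror.Of[T]): Show[T] =
    lazy val shows = summonAll[m.MirroredElemTypes]
    inline m match
      case s: Mirror.SumOf[T]     => showSum(s, shows)
      case _: Mirror.ProductOf[T] => showProduct(shows)

// Hypothetical, non-recursive types used to exercise the derivation
case class Record(s: String) derives Show

enum Pet derives Show:
  case Dog(name: String)
  case Cat(name: String, lives: Int)

@main def demo(): Unit =
  println(summon[Show[Record]].show(Record("New record"))) // New record!
  println(summon[Show[Pet]].show(Pet.Cat("Tom", 9)))       // Tom!9
```

Pet.Cat("Tom", 9) goes through showSum (ordinal 1) and then showProduct, which concatenates "Tom!" and "9".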

If there are other typeclass derivations we wish to include, the derives keyword takes a comma-separated list:

enum BinaryTree[+A] derives Show, Eq:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

Think about how Eq would be implemented in terms of sum and product. How does it differ from Show? What does it mean for a Leaf to be compared to a Node? What about for different product types?

scala> summon[Eq[BinaryTree[Int]]]                                                                                                                    
val res0: Eq[BinaryTree[Int]] = Eq$$anon$1@7c4384c7

scala> res0.eq(BinaryTree.Node(1, BinaryTree.Leaf, BinaryTree.Leaf), BinaryTree.Node(2, BinaryTree.Leaf, BinaryTree.Leaf))                 
val res1: Boolean = false
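Here is one possible answer to the Eq questions above, sketched along the same lines as Show and closely following the Eq example in the Scala 3 reference documentation on typeclass derivation. For a sum, two values can only be equal if they are the same variant, so a Leaf never equals a Node. For a product, the fields are compared pairwise. The Opt enum and the helper names (summonAll, check, eqSum, eqProduct) are hypothetical, and a non-recursive type keeps the sketch small.

```scala
import scala.deriving.*
import scala.compiletime.{erasedValue, summonInline}

trait Eq[A]:
  def eq(x: A, y: A): Boolean

inline def summonAll[T <: Tuple]: List[Eq[_]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts)  => summonInline[Eq[t]] :: summonAll[ts]

object Eq:
  // Assumed base instance
  given Eq[Int] with
    def eq(x: Int, y: Int): Boolean = x == y

  // Apply an instance to two untyped values; the instance was summoned for the right type
  def check(e: Eq[_])(x: Any, y: Any): Boolean =
    e.asInstanceOf[Eq[Any]].eq(x, y)

  def eqSum[T](s: Mirror.SumOf[T], elems: => List[Eq[_]]): Eq[T] =
    new Eq[T]:
      def eq(x: T, y: T): Boolean =
        val ordx = s.ordinal(x)
        // Different variants are never equal; same variant: compare with its instance
        ordx == s.ordinal(y) && check(elems(ordx))(x, y)

  def eqProduct[T](elems: => List[Eq[_]]): Eq[T] =
    new Eq[T]:
      def eq(x: T, y: T): Boolean =
        val xs = x.asInstanceOf[Product].productIterator
        val ys = y.asInstanceOf[Product].productIterator
        // Same product type, so compare the fields pairwise
        xs.zip(ys).zip(elems.iterator).forall { case ((a, b), e) => check(e)(a, b) }

  inline given derived[T](using m: Mirror.Of[T]): Eq[T] =
    lazy val elems = summonAll[m.MirroredElemTypes]
    inline m match
      case s: Mirror.SumOf[T]     => eqSum(s, elems)
      case _: Mirror.ProductOf[T] => eqProduct(elems)

// Hypothetical enum for this sketch
enum Opt[+T] derives Eq:
  case Sm(t: T)
  case Nn
```

Unlike Show, the sum case has to look at both values: checking the ordinals first is what answers "what does it mean for a Leaf to be compared to a Node?".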

We can also apply the typeclass derivation after the class is defined. This is useful if you do not have control over the class:

given[T: Show]: Show[BinaryTree[T]] = Show.derived
given[T: Eq] : Eq[BinaryTree[T]] = Eq.derived

This is, in essence, no different from defining any other typeclass instance for a given type; we simply use the derived function from the typeclass’s companion object that automatic typeclass derivation requires.

Other typeclasses

Last time, we created a Bool typeclass. Can we use this mechanism to automatically derive Bool instances? It would be great to use and and or on a couple of binary trees and have all the plumbing for free.

Sadly, this isn’t possible. What does Leaf booleanAnd Node mean? An argument could be made for discarding the Leaf, but don’t forget this is a specific case; we would need to be correct for the general case. What does Some(x) booleanAnd Some(y) booleanAnd None mean, all while keeping the rules in place for binary trees too? Or a trickier one: Left(List(false, true)) booleanOr Right(100)? Again, we could probably make rules that respect laws for this, but only on a case-by-case basis.
This is the reason we don’t see automatic typeclass derivation for some of the more intricate typeclasses such as Monad or Applicative.

But there are cases where this makes sense. Equality, to-string, and ordering are all good candidates. There are also libraries already in place for JSON codecs with Circe, arbitrary value generation with ScalaCheck, and plenty of others.

