We are going to carry on with the examination of typeclasses in Scala 3. Last time we covered the nuts and bolts of what typeclasses are, why they’re good, and how they work in Scala 3.
Revisiting enum
Let’s look at our BinaryTree class from the enum article:
enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf
scala> BinaryTree.Node(5,
| BinaryTree.Leaf,
| BinaryTree.Node(10,
| BinaryTree.Leaf,
| BinaryTree.Leaf
| )
| )
val res0: BinaryTree[Int] = Node(5,Leaf,Node(10,Leaf,Leaf))
We can create a typeclass for equality:
trait Eq[A]:
  def eq(x: A, y: A): Boolean
This is the same concept as in the previous article, so hopefully implementing this shouldn’t be too much of a surprise:
given[A](using eqA: Eq[A]): Eq[BinaryTree[A]] with
  def eq(x: BinaryTree[A], y: BinaryTree[A]): Boolean = {
    (x, y) match {
      case (BinaryTree.Leaf, BinaryTree.Leaf) => true
      case (BinaryTree.Node(xv, xl, xr), BinaryTree.Node(yv, yl, yr)) => eqA.eq(xv, yv) && eq(xl, yl) && eq(xr, yr)
      case _ => false
    }
  }
There are a few things worth pointing out here:
- We have the new using keyword summoning the Eq instance for A: If we can’t measure the equality of A, then we can’t measure the equality of BinaryTree[A].
- We are pattern matching on the enum: This works as it did in Scala 2.
- We are recursively checking that the sub-trees are equal.
We’ve got Eq implemented for our binary tree: going forward, anyone can use our BinaryTree in a function requiring Eq, as long as an Eq instance for the underlying element type is in scope.
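As a minimal sketch of what "a function requiring Eq" might look like, the block below repeats the definitions so far and adds two names that are not from the original article: a given Eq[Int] instance and a hypothetical helper called areEqual.

```scala
trait Eq[A]:
  def eq(x: A, y: A): Boolean

enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

// Assumption: a simple Eq[Int] so that trees of Int can be compared.
given Eq[Int] with
  def eq(x: Int, y: Int): Boolean = x == y

// The instance from the article, written with Scala 3 indentation syntax.
given [A](using eqA: Eq[A]): Eq[BinaryTree[A]] with
  def eq(x: BinaryTree[A], y: BinaryTree[A]): Boolean =
    (x, y) match
      case (BinaryTree.Leaf, BinaryTree.Leaf) => true
      case (BinaryTree.Node(xv, xl, xr), BinaryTree.Node(yv, yl, yr)) =>
        eqA.eq(xv, yv) && eq(xl, yl) && eq(xr, yr)
      case _ => false

// Hypothetical helper: any function that asks for an Eq[A] via `using`.
def areEqual[A](x: A, y: A)(using e: Eq[A]): Boolean = e.eq(x, y)

val tree: BinaryTree[Int] =
  BinaryTree.Node(5, BinaryTree.Leaf, BinaryTree.Node(10, BinaryTree.Leaf, BinaryTree.Leaf))
val leaf: BinaryTree[Int] = BinaryTree.Leaf

val sameTree = areEqual(tree, tree)  // Eq[BinaryTree[Int]] is built from Eq[Int]
val treeVsLeaf = areEqual(tree, leaf)
```

Without the Eq[Int] instance in scope, the two calls to areEqual should not compile, because the compiler has no way to build an Eq[BinaryTree[Int]].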
Let’s turn our attention to implementing Show for BinaryTree[A]:
trait Show[A]:
  def show(a: A): String
Show will return a String representation of whatever the value A is.
This should seem very similar to the Eq approach:
- We will need the compiler to provide us with a Show[A] instance.
- Base case: work out what to do for Leaf: print the empty string.
- Cons case: invoke Show[A] for the value in the Node, appending whatever show returns for each subtree.
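The three steps above can be sketched as a hand-written instance. This is one possible implementation rather than code from the original article; the Show[Int] instance is an assumption added so the example can be run.

```scala
trait Show[A]:
  def show(a: A): String

enum BinaryTree[+A]:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

// Assumption: a Show[Int] so we have something to put in the tree.
given Show[Int] with
  def show(i: Int): String = i.toString

given [A](using showA: Show[A]): Show[BinaryTree[A]] with
  def show(t: BinaryTree[A]): String =
    t match
      // Base case: a Leaf prints as the empty string.
      case BinaryTree.Leaf => ""
      // Cons case: show the value, then append both subtrees.
      case BinaryTree.Node(v, l, r) => showA.show(v) + show(l) + show(r)

val tree: BinaryTree[Int] =
  BinaryTree.Node(5, BinaryTree.Leaf, BinaryTree.Node(10, BinaryTree.Leaf, BinaryTree.Leaf))

val shown = summon[Show[BinaryTree[Int]]].show(tree)  // "510"
```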
Conceptually, this is the same as our Eq implementation. Considering we’re doing the same thing, wouldn’t it be great if this could be automatically written for us?
Automatic Typeclass Derivation
Thankfully, we can tell Scala 3 to do exactly that. We break a type down into its constituent parts and Show those instead.
The Mirror typeclass
The Mirror typeclass gives us information about the type itself, the names of the fields, the types of the fields, and how it all fits together. The compiler is able to provide this information for:
- Case classes
- Case objects
- Sealed hierarchies containing case objects and case classes
- Enums
- Enum cases
For now, the most important part of the Mirror is the breakdown between Sum and Product.
Our BinaryTree is an Algebraic Data Type (ADT) built from a Sum of two variants: Node and Leaf.
Node is a Product, built from a combination of an A and two more BinaryTree instances.
It may be helpful to think of a Product as a tuple containing the values that type would hold, plus some metadata about the field names and class name.
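To make the "product as a tuple plus metadata" idea concrete, here is a small sketch using a hypothetical case class Prod (not the article's BinaryTree). It relies on Tuple.fromProductTyped and Mirror.ProductOf#fromProduct from the standard library, and on productElementNames and productPrefix from Product.

```scala
import scala.deriving.Mirror

// Hypothetical product type used only for this illustration.
case class Prod(s: String, i: Int)

// The compiler synthesises the Mirror for us.
val mirror = summon[Mirror.ProductOf[Prod]]

// The values of a product, viewed as a tuple of its field types.
val asTuple: (String, Int) = Tuple.fromProductTyped(Prod("a", 1))

// The mirror can go the other way and rebuild a Prod from any Product.
val rebuilt: Prod = mirror.fromProduct(("b", 2))

// Metadata that is also available at runtime through Product.
val fieldNames: List[String] = Prod("a", 1).productElementNames.toList
val className: String = Prod("a", 1).productPrefix
```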
We will be using these imports for this article:
import scala.deriving.*
import scala.compiletime.{erasedValue, summonInline}
In order for Scala to build Show typeclass instances for us, we need to tell the compiler how to build Show instances for Sums and Products:
Showing products
We can assume that there is a Show instance available for all of the parts of the product.
A possible definition for Showing a product could look like this:
def showProduct[T](shows: => List[Show[_]]): Show[T] = // (1)
  new Show[T]: // (2)
    def show(t: T): String = {
      (t.asInstanceOf[Product].productIterator).zip(shows.iterator).map { // (3)
        case (p, s) => s.asInstanceOf[Show[Any]].show(p) // (4)
      }.mkString
    }
There’s quite a lot going on here!
- (1) For this function, we are assuming (hoping?) the caller gives us a list of Show instances in the same order as the fields are defined in the product definition. So for case class Prod[A](a: A, s: String, i: Int), this list would contain, in this order, List(Show[A], Show[String], Show[Int]). This is why the list is typed as List[Show[_]]: the type parameter differs from element to element. The Show instances can be available through derivation (using the list of options described for Mirror above), or through an explicit given.
- (2) This function creates a new Show instance every time it is called. While this is extremely useful, if you are auto-deriving large trees, and doing this many times, then compilation times will noticeably increase. You may find it useful to make a reference to an auto-derived instance in your code, and then refer to that instead; this way, the auto-derivation code will not be re-run every time for the same underlying type.
- (3) We pair each field value of the product, of some type X, with its Show[X].
- (4) We show that value, and concatenate everything into a single string. It is important to note here that this is going to be the format of your string for every auto-derived type. If you want to do anything fancy for a particular type (e.g., printing a binary tree into a nice tree-like structure), you’ll be better off doing this in its own Show instance. This code is effectively going to be the default show format for all types that use it. We also need to cast our Show to a type. We are fine to cast this to Any; we don’t have any more type information at this point, and the compiler has made sure the instance is the correct type before we even get to this point.
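To see the ordering assumption from the first point in action, the sketch below calls showProduct by hand with an explicitly built list. The case class Prod and the two Show values are hypothetical; Show[?] is the Scala 3 wildcard spelling of Show[_].

```scala
trait Show[A]:
  def show(a: A): String

def showProduct[T](shows: => List[Show[?]]): Show[T] =
  new Show[T]:
    def show(t: T): String =
      // Pair each field value with the Show at the same position.
      t.asInstanceOf[Product].productIterator.zip(shows.iterator).map {
        case (p, s) => s.asInstanceOf[Show[Any]].show(p)
      }.mkString

// Hypothetical product and instances, for illustration only.
case class Prod(s: String, i: Int)
val showString: Show[String] = (s: String) => s"$s!"
val showInt: Show[Int] = (i: Int) => i.toString

// The list must follow the field order of Prod: String first, then Int.
val prodShow: Show[Prod] = showProduct[Prod](List(showString, showInt))
val shown = prodShow.show(Prod("Hello", 42))  // "Hello!42"
```

Nothing in the types stops a caller from passing the list in the wrong order, which is why the derivation machinery, rather than a human, should normally be the one building it.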
Showing sums
Some of the sum types are going to be made up from products. Again, we assume Show instances are available for all of the different sum variants, including auto-deriving products if necessary.
Our Show for sums could look like this:
def showSum[T](s: Mirror.SumOf[T], shows: => List[Show[_]]): Show[T] = // (1)
  new Show[T]:
    def show(t: T): String = {
      val index = s.ordinal(t) // (2)
      shows(index).asInstanceOf[Show[Any]].show(t) // (3)
    }
- (1) This function makes use of the SumOf type within the Mirror typeclass. This is going to give us information about the structure of the differing Sum variants. Note, for our product, we didn’t need a reference to any Mirror. There is a Mirror.ProductOf[T] that we didn’t need, but we could use it for more complicated typeclasses. Mirror.ProductOf[T] has the ability to create an instance of T from a Product, so we could, for instance, change the values in a product and construct a new T at this point.
- (2) s.ordinal(t) tells us which variant t is. Assuming again the List[Show[_]] are in the order that matches the order of the sum variant definitions, we grab the Show instance at that position in the list.
- (3) As with the product, we cast, we call show, and we’re done.
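The ordinal lookup in step (2) can be tried on its own. The sketch below uses a hypothetical Shape enum; for an enum, the ordinal reported by Mirror.SumOf follows the order in which the cases are declared.

```scala
import scala.deriving.Mirror

// Hypothetical sum type used only for this illustration.
enum Shape:
  case Circle(radius: Int)
  case Square(side: Int)
  case Point

val sumMirror = summon[Mirror.SumOf[Shape]]

// Each variant maps to its position in the declaration order.
val circleIndex = sumMirror.ordinal(Shape.Circle(1))
val squareIndex = sumMirror.ordinal(Shape.Square(2))
val pointIndex  = sumMirror.ordinal(Shape.Point)
```

This index is what showSum uses to pick the matching Show out of its list.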
Putting it all together
Now that we know how to derive sums and products, we need to put it together and tell the compiler how to use this automatically.
If we include a given named derived, then the compiler will use this function to try to derive a typeclass for us. It has a very specific signature:
inline given derived[T](using m: Mirror.Of[T]): Show[T]
First, we need a way to grab either all of the sum variant or product field typeclass instances we want to derive for:
import scala.compiletime.{erasedValue, summonInline}
inline def summonAll[T <: Tuple]: List[Show[_]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts) => summonInline[Show[t]] :: summonAll[ts]
summonAll summons the instances in a list, respecting the order of the fields as we have assumed for our showSum and showProduct functions.
We can implement the derived function: we have all of our needed Show instances, and we know how to derive sums and products. We will put the derived function, along with any "base" or default typeclass instances, in the companion object:
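As a quick check of that ordering, summonAll can be called directly on a tuple type. The two given instances below are assumptions added for the example, and Show[?] is the Scala 3 wildcard spelling of Show[_].

```scala
import scala.compiletime.{erasedValue, summonInline}

trait Show[A]:
  def show(a: A): String

// Assumed instances, for illustration only.
given Show[String] with
  def show(s: String): String = s"$s!"

given Show[Int] with
  def show(i: Int): String = i.toString

inline def summonAll[T <: Tuple]: List[Show[?]] =
  inline erasedValue[T] match
    case _: EmptyTuple => Nil
    case _: (t *: ts)  => summonInline[Show[t]] :: summonAll[ts]

// One Show per element type, in the same order as the tuple.
val instances: List[Show[?]] = summonAll[(String, Int)]

val first  = instances(0).asInstanceOf[Show[Any]].show("Hello")  // uses Show[String]
val second = instances(1).asInstanceOf[Show[Any]].show(42)       // uses Show[Int]
```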
object Show:
  inline given derived[T](using m: Mirror.Of[T]): Show[T] =
    lazy val shows = summonAll[m.MirroredElemTypes]
    inline m match
      case s: Mirror.SumOf[T] => showSum(s, shows)
      case _: Mirror.ProductOf[T] => showProduct(shows)
The lazy val means that the instances for any recursive definitions are not computed until they are needed; computing them eagerly would produce a StackOverflowError when trying to derive all of the cases of a recursive ADT.
This is also why the List[Show[_]] parameters to showProduct and showSum have a call-by-name marker on the definition.
That’s it! Now we can automatically derive any Show typeclass instance. We need to make a small adjustment to our type to say that it can derive Show:
enum BinaryTree[+A] derives Show:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf
Now, when we drop into a console, as long as we can derive a Show[A], we can derive a Show[BinaryTree[A]]:
scala> given Show[String] with
| def show(s: String): String = s"$s!"
|
// defined object given_Show_String
scala> summon[Show[BinaryTree[String]]]
val res0: Show[BinaryTree[String]] = Show$$anon$2@5043da15
scala> res0.show(BinaryTree.Leaf)
val res1: String = ""
scala> res0.show(BinaryTree.Node("Hello", BinaryTree.Leaf, BinaryTree.Leaf))
val res2: String = Hello!
scala> res0.show(BinaryTree.Node("Hello", BinaryTree.Leaf, BinaryTree.Node("World", BinaryTree.Leaf, BinaryTree.Leaf)))
val res3: String = Hello!World!
And, of course, this works for types that have nothing to do with BinaryTree.
Using the same Show[String] instance from above:
scala> case class Record(s: String) derives Show
// defined case class Record
scala> summon[Show[Record]]
val res4: Show[Record] = Show$$anon$1@4ec6634c
scala> res4.show(Record("New record"))
val res5: String = New record!
If there are other typeclass derivations we wish to include, the derives keyword takes a comma-separated list:
enum BinaryTree[+A] derives Show, Eq:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf
Think about how Eq would be implemented in terms of sum and product. How does it differ from Show? What does it mean for a Leaf to be compared to a Node? What about for different product types?
scala> summon[Eq[BinaryTree[Int]]]
val res0: Eq[BinaryTree[Int]] = Eq$$anon$1@7c4384c7
scala> res0.eq(BinaryTree.Node(1, BinaryTree.Leaf, BinaryTree.Leaf), BinaryTree.Node(2, BinaryTree.Leaf, BinaryTree.Leaf))
val res1: Boolean = false
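One possible answer to those questions is sketched below. It is not taken from the original article: it simply mirrors the Show derivation, with eqProduct and eqSum as hypothetical counterparts to showProduct and showSum, and an assumed Eq[Int] base instance. Two values of a sum are only equal if they are the same variant; two values of a product are equal if all their fields are.

```scala
import scala.deriving.*
import scala.compiletime.{erasedValue, summonInline}

trait Eq[A]:
  def eq(x: A, y: A): Boolean

object Eq:
  // Assumed base instance.
  given Eq[Int] with
    def eq(x: Int, y: Int): Boolean = x == y

  inline def summonAll[T <: Tuple]: List[Eq[?]] =
    inline erasedValue[T] match
      case _: EmptyTuple => Nil
      case _: (t *: ts)  => summonInline[Eq[t]] :: summonAll[ts]

  // Products: compare field by field, using the Eq at the same position.
  def eqProduct[T](eqs: => List[Eq[?]]): Eq[T] =
    new Eq[T]:
      def eq(x: T, y: T): Boolean =
        x.asInstanceOf[Product].productIterator
          .zip(y.asInstanceOf[Product].productIterator)
          .zip(eqs.iterator)
          .forall { case ((a, b), e) => e.asInstanceOf[Eq[Any]].eq(a, b) }

  // Sums: different variants are never equal; otherwise delegate to the variant's Eq.
  def eqSum[T](s: Mirror.SumOf[T], eqs: => List[Eq[?]]): Eq[T] =
    new Eq[T]:
      def eq(x: T, y: T): Boolean =
        val ord = s.ordinal(x)
        ord == s.ordinal(y) && eqs(ord).asInstanceOf[Eq[Any]].eq(x, y)

  inline given derived[T](using m: Mirror.Of[T]): Eq[T] =
    lazy val eqs = summonAll[m.MirroredElemTypes]
    inline m match
      case s: Mirror.SumOf[T]     => eqSum(s, eqs)
      case _: Mirror.ProductOf[T] => eqProduct(eqs)

enum BinaryTree[+A] derives Eq:
  case Node(value: A, left: BinaryTree[A], right: BinaryTree[A])
  case Leaf

val eqTree = summon[Eq[BinaryTree[Int]]]
val one: BinaryTree[Int] = BinaryTree.Node(1, BinaryTree.Leaf, BinaryTree.Leaf)
val two: BinaryTree[Int] = BinaryTree.Node(2, BinaryTree.Leaf, BinaryTree.Leaf)
val leaf: BinaryTree[Int] = BinaryTree.Leaf
```

The main difference from Show is that both arguments have to be inspected together: eqSum checks the ordinals of both values before delegating, so comparing a Leaf to a Node is false without ever looking inside either.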
We can also apply the typeclass derivation after the class is defined. This is useful if you do not have control over the class:
given[T: Show]: Show[BinaryTree[T]] = Show.derived
given[T: Eq] : Eq[BinaryTree[T]] = Eq.derived
This is, in essence, no different from defining any typeclass instance for a given type; it just uses the derived function that automatic typeclass derivation requires in the typeclass’s companion object.
Other typeclasses
Last time, we created a Bool typeclass. Can we use this mechanism to automatically derive Bools? It would be great to use and and or on a couple of binary trees and have all the plumbing for free.
Sadly, this isn’t possible. What does it mean to Leaf booleanAnd Node? There could be an argument made to discard the Leaf, but don’t forget this is a specific case; we would need to be correct for the general case. What does Some(x) booleanAnd Some(y) booleanAnd None mean, all while keeping the rules in place for binary trees too? Or a more tricky one: Left(List(false, true)) booleanOr Right(100)? Again, we could probably make rules that respect laws for this, but this is on a case-by-case basis.
This is the reason we don’t see automatic typeclass derivation for some of the more intricate typeclasses such as Monad or Applicative.
But there are cases where this makes sense. Equality, to-string, and ordering are all good candidates. There are also libraries already in place for JSON parsing with Circe, arbitrary value generation with ScalaCheck, and plenty of others.