Notes on Writing a Scala 3 Compiler Plugin

26 Jun, 2023

Earlier this year, I worked on the compiler plugin for scourt, a project to implement coroutines in Scala alongside Unwrapped, which was proposed in PRE-SIP: Suspended functions and continuations. As this was my first time working with the compiler, I took notes on the things I found useful, and I’d like to share them with you to help make your life easier as well.

Intro

To begin, it would be helpful to explain a few terms that we will see later.

  • Trees: the Abstract Syntax Tree (AST) is the data structure the compiler builds to represent the code. This structure is transformed as it passes through Phases. For example, for a variable and a method like
private val hello: String = "Hello"

def exampleMethod: Int = {
  1
}

the corresponding Trees are

// using ctx: Context
val valdef = tpd.ValDef(
  Symbols.newSymbol(ctx.owner, termName("hello"), Flags.Private, defn.StringType),
  tpd.Literal(Constant("Hello"))
)

val defdef =
  tpd.DefDef(
    Symbols.newSymbol(ctx.owner, termName("exampleMethod"), Flags.Method, defn.IntType),
    tpd.Block(List.empty, tpd.Literal(Constant(1)))
  )

You can also use the AST Explorer to see the corresponding Trees of some Scala code.

  • Phases: these transform the Tree. Looking at the Scala 3 Compiler Phases, we see, for example, Inlining, which supports the inline functionality. When we create a plugin, we add a new Phase to the compiler in order to transform the Tree, like in our case with ContinuationsPhase.
  • Symbols: these uniquely identify an entity like a class, a method, or a type and include, among other things, the name of the entity. A ClassSymbol can be used to describe a class, a trait, or an object. In our previous example, you can see that we create the symbols for our Trees with the Symbols.newSymbol method, giving them a name, flags, and a type.
  • Denotations: Symbols themselves don’t have a structure. That is why they are associated with Denotations, which represent the meaning of a symbol in a specific compiler Phase. In our code, we used denotations in order to compare Trees:
    tree1.symbol.denot.matches(tree2.symbol)
  • Names: wrapped strings that are interned, so only one copy of each is stored. As part of the Symbol, they are
    either term names (variables, methods, or class parameters) or type names (classes, traits, type parameters).
  • Flags: these define the semantics of a symbol (e.g., whether it is mutable, a method, an implicit, etc.).
    In the example above, we used Flags to say that a symbol should have the private access modifier or that it is a method.
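
To make the Phases bullet a bit more concrete, here is a minimal sketch of how a plugin might register a Phase. It is not the code of our plugin: ExamplePlugin, ExamplePhase, and the phase name are hypothetical, and the sketch assumes the scala3-compiler artifact is on the classpath and that you are on the 3.3-era plugin API (newer compiler releases may prefer a different entry point than init).

```scala
// Assumption: the scala3-compiler artifact (3.3.x) is on the classpath.
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.plugins.{PluginPhase, StandardPlugin}
import dotty.tools.dotc.transform.Pickler

// Hypothetical plugin: the class and phase names are illustrative only.
class ExamplePlugin extends StandardPlugin:
  val name: String = "examplePlugin"
  override val description: String = "prints every method the phase visits"

  // Returns the Phases this plugin contributes to the compiler pipeline.
  override def init(options: List[String]): List[PluginPhase] =
    List(new ExamplePhase)

class ExamplePhase extends PluginPhase:
  val phaseName: String = "examplePhase"

  // Run after the pickler phase, so we see fully typed Trees.
  override val runsAfter: Set[String] = Set(Pickler.name)

  // Called for every DefDef; returning the Tree unchanged makes this a no-op.
  override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree =
    println(tree.show)
    tree
```

The compiler discovers the plugin class through a plugin.properties resource containing pluginClass=ExamplePlugin, and the packaged jar is passed to the compiler with -Xplugin:<path to jar>.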

Helper methods and suggestions

Now we can start with some suggestions for where to find methods that will help you write Trees for the compiler.

  • To start, always make sure to look around in the compiler codebase for helper methods. A few places to look for useful extension methods on Symbols, Types, and Names, or for general helpers, are SymUtils, TypeUtils, NameOps, and Decorators. These methods have helped me especially with debugging and identifying patterns.

  • Another place to keep in mind is StdNames, in particular StdNames.nme. Here you can find common method names like nme.asInstanceOf_, nme.CONSTRUCTOR, nme.OR, nme.apply, or even nme.x_0. These are already term names, so there is no need to wrap them in termName either. And if you want to find some ready-made definitions of symbols or types that you can use, look at Definitions or use Symbols.defn, for example defn.IntType or defn.AnyType.

  • If you are creating or handling a Tree, make sure to use the helper methods from TreeOps. For example, prefer tree.select, tree.appliedTo, and tree.appliedToType over building Trees.Select, Trees.Apply, and Trees.TypeApply directly, where tree is a value such as val tree: tpd.ValDef = ??? and Trees is the object that defines the structure. In tpd, you can find all kinds of useful Trees like the underscore or helper methods like isInstance.

  • Do you want to see the signature of the Tree you are working with? Then .symbol.showDcl is what you need. You will notice there is also a show method, but the two can show different signatures with different parameters. You may think that the one from show is what you are working with and try to call the method with those parameters, but the compiler will call the method with the signature reported by the symbol, resulting in an error. Note that neither method shows the Tree’s Flags, so be sure to check symbol.flags as well (for something that is implicit, lazy, private, etc.).

  • To reference existing code like classes, traits, or objects, you can use Symbols.requiredClass, Symbols.requiredModule, Symbols.requiredPackage, etc.
    For example

    Symbols.requiredClassRef("scala.util.Either").appliedTo(defn.ThrowableType, defn.IntType)

    to create

    Either[Throwable, Int]
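  • If you have a Tree that defines something (one with a symbol, such as a ValDef or DefDef) and you want to reference it later in your code, then you can use ref(tree.symbol). A bare Literal has no symbol, so the value has to be wrapped in a definition first. For example, for

    val x = 1
    val y = x

    you can do

    // using ctx: Context
    val x = tpd.ValDef(newSymbol(ctx.owner, termName("x"), Flags.EmptyFlags, defn.IntType), tpd.Literal(Constant(1)))
    val y = tpd.ValDef(newSymbol(ctx.owner, termName("y"), Flags.EmptyFlags, defn.IntType), ref(x.symbol))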
  • Do you want to see if your Tree is a method? Then all you need is tree.symbol.is(Flags.Method). By the way, a var is just a ValDef with the Mutable flag.

  • A really useful class is TreeTypeMap, as it helps with anything from transforming a whole Tree to changing just a symbol inside the Tree.

  • But how do you even define a constant value? What you need is tpd.Literal(Constant("Hello World")).

  • Generally, try to stick to either tpd or Trees for importing things like ValDef or DefDef. Otherwise, the code can become a bit hard to read. Admittedly, sometimes there is no other option than to mix them. In the same context, if you want to pattern match a Tree to see its type but not use any params, then you can still use just tpd:

    tree match { case t: tpd.DefDef => println(t.show) }

    instead of

    tree match { case t @ Trees.DefDef(_, _, _, _) => println(t.show) }
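
Several of the suggestions above (showDcl versus show, checking symbol.flags, and telling a var from a val) can be combined into a small debugging Phase. This is only a sketch under assumptions: InspectPhase and its logTree helper are hypothetical names, not part of our plugin or the compiler, and the scala3-compiler artifact is assumed to be on the classpath.

```scala
// Assumption: the scala3-compiler artifact is on the classpath.
import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.Context
import dotty.tools.dotc.core.Flags
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.plugins.PluginPhase
import dotty.tools.dotc.transform.Pickler

// Hypothetical debugging phase: it prints information and returns each Tree unchanged.
class InspectPhase extends PluginPhase:
  val phaseName: String = "inspect"
  override val runsAfter: Set[String] = Set(Pickler.name)

  override def transformDefDef(tree: tpd.DefDef)(using Context): tpd.Tree =
    logTree("def", tree)

  override def transformValDef(tree: tpd.ValDef)(using Context): tpd.Tree =
    // A var is a ValDef whose symbol carries the Mutable flag.
    logTree(if tree.symbol.is(Flags.Mutable) then "var" else "val", tree)

  // Hypothetical helper: prints the three views of a definition side by side.
  private def logTree(kind: String, tree: tpd.Tree)(using Context): tpd.Tree =
    val sym = tree.symbol
    println(s"$kind symbol: ${sym.showDcl}")           // signature as the symbol sees it
    println(s"$kind flags:  ${sym.flags.flagsString}") // not visible in show or showDcl
    println(s"$kind tree:   ${tree.show}")             // the Tree as printed source
    tree
```

Printing the three views next to each other makes it easier to spot the cases where show and showDcl disagree.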

Things to look out for

Be careful, some things are not so obvious!

  • Always remember to enter the symbols you create with symbol.entered or one of the similar methods. If their owner is a class, then the symbol will be in scope. Otherwise, it won’t. That goes hand in hand with making sure to use TreeTypeMap to change owners if needed. A symbol being attached to an old owner was one of the most common issues we had to deal with, and, unfortunately, figuring out the problem and which symbol needs updating was not an easy task (usually, you will see something like java.lang.IllegalArgumentException: Could not find proxy:...).

  • When traversing a Tree, take into account that the compiler creates anonymous functions for cases like lambdas (e.g., the argument passed to def func(f: Int => Boolean)), context functions, etc. In our case, we were extracting pieces of code to use somewhere else, so we had to make sure that we could identify these synthetic functions, take the inner tree, and then change its owners as well.

  • Flags play an important role, so make sure to understand which ones you need, as the compiler uses them to decide how to transform a tree. For example, if you create the Tree for a class with a parameter but forget to add Flags.LocalParamAccessor to it (like for the params in our synthetic class), then the Constructors phase will ignore this parameter. After that, the compiler may add its own synthetic parameter (e.g., when our new class is nested inside another class), and it will then assume the class has only one parameter (the one the compiler added) and not two. Later, it will pick what it presumes is that single parameter, but it will actually be our parameter, not the one added by the compiler. Unfortunately, from what I have seen, creating a class and adding a constructor with a parameter won’t add the flag or enter the symbol in the scope either. We have to do these steps explicitly later on, once the constructor has created the symbol for the parameter.

  • If you add parameters to a def and then want to use those parameters inside its body, then you need to take them from the def itself and not use the initial variable that has not been assigned to the def. For example, for

    val x = ???
    val defdef = DefDef(newSymbol(???), paramss = List(List(x.symbol)), ???)

    we can’t use the x val inside the defdef body. Instead, we need to get the method params with val paramsToUse = defdef.paramss or val paramsToUse = defdef.termParamss and find the x one before using it. Hopefully, looking at some code will make this clearer:

val param: Symbol = newSymbol(ctx.owner, termName("param"), Flags.LocalParam, defn.IntType)

val defdef: tpd.DefDef =
  tpd.DefDef(
    sym = newSymbol(
      ctx.owner,
      Names.termName("exampleMethod"),
      Flags.Method,
      MethodType.fromSymbols(List(param), defn.IntType)
    )
  )

val defdefParam: Symbol =
  defdef.termParamss.flatten.find(_.symbol.denot.matches(param)).get.symbol

val defdefWithBody =
  cpy.DefDef(defdef)(rhs = ref(defdefParam).select(defn.Int_+).appliedTo(tpd.Literal(Constant(1))))

tpd.Block(List(defdefWithBody), ref(defdefWithBody.symbol).appliedTo(tpd.Literal(Constant(1))))

When compiled, this Tree produces

{
  def exampleMethod(param: Int): Int = param.+(1)
  exampleMethod(1)
}

However, if we replace the defdefParam inside the method body with param

val defdefWithBody =
  cpy.DefDef(defdef)(rhs = ref(param).select(defn.Int_+).appliedTo(tpd.Literal(Constant(1))))

compilation will fail with
Exception in thread "main" java.util.NoSuchElementException: val param.
You can see this scenario being used in our plugin here. In this example, we also used cpy in order to copy an existing Tree and change only one of its fields, the body (rhs).

  • If you want to transform or traverse Trees, you can use phase methods like transformDefDef or Tree methods like filterSubTrees, shallowFold, or deepFold. Bear in mind that these are depth-first traversals following the order of the tree; the first branch will get fully traversed before continuing to the next one.
    For example, if we have
val block = tpd.Block(
  List(
    tpd.Block(List(tpd.Literal(Constant("A"))), tpd.Literal(Constant(1))),
    tpd.Block(List(tpd.Literal(Constant("B"))), tpd.Literal(Constant(2)))
  ),
  tpd.Literal(Constant(3))
)

TreeTypeMap(treeMap = tree => {
  println(tree.show)
  tree
})(block)

the outcome will be

{ { "A" 1 } { "B" 2 }  3 } -> the whole tree
{ "A" 1 } 
"A"
1
{ "B" 2 }
"B"
2
3

Make sure to watch Jack’s talk to see how you can easily test this yourself.
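
If you don’t have a compiler Context at hand, the traversal order above can be reproduced with a toy model in plain Scala. To be clear about the assumptions: Toy, Lit, Block, show, and visit are made-up names for illustration, and this is not how the compiler’s tpd Trees or TreeTypeMap are implemented. It only mimics the pre-order, depth-first visiting order.

```scala
// A toy model only: these are NOT the compiler's tpd Trees.
enum Toy:
  case Lit(value: Any)
  case Block(stats: List[Toy], expr: Toy)

import Toy.*

// Renders a toy tree roughly the way the output above prints it.
def show(t: Toy): String = t match
  case Lit(s: String)     => "\"" + s + "\""
  case Lit(v)             => v.toString
  case Block(stats, expr) => (stats :+ expr).map(show).mkString("{ ", " ", " }")

// Pre-order, depth-first: record the node itself, then each child from left to right.
def visit(t: Toy): List[String] = t match
  case l: Lit                 => List(show(l))
  case b @ Block(stats, expr) => show(b) :: (stats :+ expr).flatMap(visit)

val block = Block(
  List(
    Block(List(Lit("A")), Lit(1)),
    Block(List(Lit("B")), Lit(2))
  ),
  Lit(3)
)

// Prints the whole tree first, then the "A" branch completely, then the "B" branch, then 3.
@main def showTraversalOrder(): Unit =
  visit(block).foreach(println)
```

The first branch is exhausted before the second one is entered, which is the behaviour to keep in mind when a transformation depends on the order in which Trees are seen.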

However, there are some special cases that initially surprised me. If you have a Trees.Inlined, then using shallowFold will apply the transformation to the expanded Tree, the one you get after the inline has been applied, and not to the initial Tree that is supposed to be replaced. So if you want to find an Inlined and change that existing Tree, you have to do it explicitly: do a pattern match and take the call from Inlined(call,...), as seen here. The same applies to class constructor parameters and anonymous functions, as they are not part of the traversal. For those, we can similarly pattern match explicitly against tpd.Apply or use the transformParams, transformParamss, and transformDefDef methods.

Tricks

Finally, let’s look at some examples, mainly from the compiler codebase, that are not always obvious.

  • Let’s say that you want to create the Tree that throws an exception with a specific message. It sounds like a simple task, but you actually have to follow the pattern used in the compiler here, or the one from our plugin code, which also shows how to throw the exception, here:
val IllegalArgumentExceptionClass = requiredClass("java.lang.IllegalArgumentException") // or defn.IllegalArgumentExceptionClass

val IllegalArgumentExceptionClass_stringConstructor: TermSymbol =
  IllegalArgumentExceptionClass
    .info
    .member(nme.CONSTRUCTOR)
    .suchThat(_.info.firstParamTypes match {
      case List(pt) => pt.stripNull.isRef(defn.StringClass)
      case _ => false
    })
    .symbol
    .asTerm

val throwException = tpd.Throw(
  tpd.New(
    IllegalArgumentExceptionClass.typeRef, // or defn.IllegalArgumentExceptionType
    IllegalArgumentExceptionClass_stringConstructor,
    List(tpd.Literal(Constant("wrong argument")))
  )
)

println(throwException.show) // throw new IllegalArgumentException("wrong argument")

However, we later realized that you can also disambiguate an overloaded member by using select with a Symbol => Boolean predicate as the disambiguation handler:

val IllegalArgumentExceptionClass_stringConstructor: TermSymbol =
  ref(IllegalArgumentExceptionClass).select(
    nme.CONSTRUCTOR,
    _.info.firstParamTypes match {
      case List(pt) => pt.stripNull.isRef(defn.StringClass)
      case _ => false
    }).symbol.asTerm
  • Now, if you want to create the Tree for an addition like 1 + 1, there are a few ways, including trying to require the Int type and then searching for the plus method. But this has already been done for you. So you can try

    tpd.Literal(Constant(1)).select(defn.Int_+).appliedTo(tpd.Literal(Constant(1)))

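    But what if you want to do an OR like 9 | 10? There is no ready-made symbol for this, but the pattern is the same: look up the method symbol yourself and then use it with select and appliedTo exactly as with defn.Int_+. We need

    val Int_| = defn.IntClass.requiredMethod(nme.OR, List(defn.IntType))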
  • Next is how we can get something like the minimum value of an integer Int.MinValue, and for that, we can do

    val Int_Min = ref(requiredModuleRef("scala.Int").select(Names.termName("MinValue")).symbol)
  • Another interesting bit we had to do was to call the constructor of a class, something like new AClass[Int](2). We have seen that when we were throwing an exception, but just for clarity, here is the code:

    tpd.New(ref(AClassSymbol)).select(nme.CONSTRUCTOR).appliedToType(defn.IntType).appliedTo(tpd.Literal(Constant(2)))
  • In this scenario, we want to do a pattern match against the type, like value match { case x: Int => ??? }, and the question is how we can define the x: Int part. Looking in Trees.scala, we can see there is Bind or Typed. But looking in the compiler code, we see that the suggestion is to use BindTyped. Overall, the Trees for value match { case x$0: Int => ??? } will look like

    val caseParam = newSymbol(owner, nme.x_0, Flags.Case | Flags.CaseAccessor, defn.IntType)
    tpd.Match(ref(value.symbol), List(tpd.CaseDef(tpd.BindTyped(caseParam, caseParam.info), ???, ???)))

    and for value match { case x$0 => ??? } we will have

    val caseParam = newSymbol(owner, nme.x_0, Flags.Case | Flags.CaseAccessor, defn.IntType)
    tpd.Match(ref(value.symbol), List(tpd.CaseDef(tpd.Bind(caseParam, tpd.EmptyTree), ???, ???)))

    In our plugin transformations, we also used a combination of Bind and Typed

    val caseParam = newSymbol(owner, nme.x_0, Flags.Case | Flags.CaseAccessor, defn.IntType)
    tpd.Match(ref(value.symbol), List(tpd.CaseDef(tpd.Bind(caseParam, tpd.Typed(ref(param), ref(defn.IntType))), ???, ???)))

    which can also be seen here.

Conclusion

Working with the compiler is not easy, or maybe it is just different from developing HTTP APIs and streaming platforms. Code is being extracted, lifted, moved out or inside, or even deleted in various Phases. Details matter even when they don’t look like it; you may not have registered a new symbol, or you may have used an extra Flag, and the compiler treats it in a way you didn’t expect.

But overall, it was an interesting experience. I feel like I learned a lot, including how to navigate around the Scala 3 compiler codebase, and in the end, I was really happy to see our plugin working.

If you are interested in writing a plugin yourself, I would say it is worth it to try and have a look around the compiler codebase, and also look at some existing plugins. When the time comes, you will be able to recognize the patterns you need.

For more ideas that will help you write a testable compiler plugin, make sure to watch Jack Viers’ talk from Scala Days, Testable Compiler Plugin Development in Scala 3. (NOTE: We will update this post with the direct link to the video once the Scala Days 2023 presentations have been publicly published. You can also subscribe to the Scala Days YouTube channel to be notified when videos are published.)

Useful links

Dotty Docs
Compiler Plugin Development in Scala 3 | Let’s talk about Scala 3
Scala 3 Compiler Academy YouTube Channel
Scala Center Sprees
Compilers are Databases
