
Generic Refinement Types in Scala 3

A major goal of strong typing in any programming language is to "make illegal states unrepresentable", as memorably put in Alexis King’s article "Parse, don’t validate". To this end, refinement types have become an increasingly popular way of constraining existing types, for example by restricting an integer type to only natural numbers.

What can refinement types do for us in Scala? As it turns out, a lot! Just looking at my own personal projects, I’ve been experimenting with:

  • A small hierarchy of generic numeric types (NonZero, NonNegative and Positive) and a set of arithmetic operations that use them, such as "safe" division, square root and calculation of vector lengths.
  • A custom Seq that uses NonNegative[Int] as its index and count types, with a NonEmptySeq subclass that uses a Positive[Int] count type.
  • A Real refinement of Double that excludes values that are not real numbers (i.e. positive/negative infinity and NaN), along with a set of "safe" trigonometry functions with no corner cases for the user to worry about.
  • Distinct types for position and motion vectors in a game world, with operations that enforce how they can combine (for example we can add two motions, or shift a position by a motion, but not add two positions by mistake).

In this article, we’re going to implement the first of these examples, using Scala 3’s opaque type aliases, as well as inlining and other compile-time operations, to build our numeric types and type-safe operations. We’ll make these opaque types generic and higher-kinded so that they can constrain a variety of base types.

Our focus will be on maximizing the type safety of non-constant values as they move through our various operations. But refinement types are also great for handling compile-time constants; a great tutorial on this topic and more is "Crafting Concise Constructors with Opaque Types in Scala 3" by my Xebia colleague David A. Gil Méndez.

Using the Scastie online editor

In the following code samples, I’ll be using Scala 3.5.2 and the Scastie online editor in Worksheet mode. This mode (the default) allows for top-level declarations as if we were coding inside a main method, and shows inlay type hints for all declarations, similarly to how the Scala REPL prints out types after each line. I’ll link to the completed worksheet after the last code sample, just before the conclusion section.

Note that since package declarations aren’t allowed in a Scastie worksheet, when working with opaque type aliases I’ll use a Scala object to scope their transparency. In a normal project we would use a package instead, or even just top-level declarations in a separate source file.

Laying the groundwork

Let’s get started! Since we’ll be creating several generic refinement types, let’s define a convenience base trait for their companion objects:

trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  def isValid[T](x: T): Boolean
  def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

In this trait we declared a higher-kinded R[_] type, an isValid method to check that a value of T conforms to our refinement, and a protected refined method to turn a validated value of T into a refined R[T]. We then wrapped these into public apply and option methods to ensure that all values of R[T] are valid.
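To see apply and option at work before we tackle validation that needs a typeclass, here is a minimal worksheet-style sketch. NonNull is a hypothetical refinement invented for illustration, not one of the types built in this article; its check needs no extra instances, so this first version of the trait is enough.

```scala
// First version of the Refinement trait, as defined above
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  def isValid[T](x: T): Boolean
  def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

object refinements:
  // Hypothetical refinement: any value that is not null
  opaque type NonNull[T] <: T = T
  object NonNull extends Refinement[NonNull]:
    override protected def refined[T](x: T): NonNull[T] = x
    override def isValid[T](x: T): Boolean = x != null

import refinements.*

// Hypothetical consumer that can rely on its argument not being null
def shout(s: NonNull[String]): String = s.toUpperCase

val parsed = NonNull.option("hello")        // typed Option[NonNull[String]]
val rejected = NonNull.option(null: String) // None: the check failed
val loud = parsed.map(shout)                // Option[String]
```

Calling NonNull(null: String) directly would instead throw an IllegalArgumentException from require, which is why option is the safer entry point for untrusted input.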

Our first refinement type

Let’s define our first refinement type, for non-zero values. Since our types are generic, we must first define a typeclass to hold the "zero" element of T to validate against:

final case class Zero[T](e: T)
def zero[T: Zero]: T = summon[Zero[T]].e

Let’s declare a generic given instance of Zero for Scala’s Numeric types:

given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

This will give us instances for common types such as Int, Double etc.
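As a quick sanity check, here is a sketch of which types now have a Zero. The String instance is a hypothetical addition, included only to show that a type outside Numeric can still opt in with its own given.

```scala
final case class Zero[T](e: T)
def zero[T: Zero]: T = summon[Zero[T]].e

// Every type with a Numeric instance gets a Zero for free
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

// Hypothetical: a non-Numeric type opting in explicitly
given Zero[String] = Zero("")

val intZero = zero[Int]
val doubleZero = zero[Double]
val bigZero = zero[BigDecimal]
val stringZero = zero[String] // resolved by the explicit instance, since there is no Numeric[String]
```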

Defining our type

With that done, here is our first draft of the NonZero type:

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override def isValid[T](x: T): Boolean = x != zero[T]

Here we used a (parameterized) opaque type alias that is internally represented by a value of T at runtime, without wrapping or boxing. This identity is transparent inside the refinements object, which allows us to simply return x itself in our refined implementation. Outside the refinements scope, NonZero[T] is viewed as a distinct type; but since it has T as a type bound, we’ll be able to use its values everywhere a T is expected.
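To make the subtyping concrete, here is a deliberately simplified sketch. Its Int-only option method is a hypothetical stand-in for the generic Refinement machinery, so it sidesteps the typeclass problem entirely; it only demonstrates how the opaque type behaves inside and outside its scope.

```scala
object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero:
    // Hypothetical Int-only smart constructor; inside `refinements`,
    // NonZero[Int] is just Int, so x can be returned as-is
    def option(x: Int): Option[NonZero[Int]] =
      if x != 0 then Some(x) else None

import refinements.*

val two: NonZero[Int] = NonZero.option(2).get
val asInt: Int = two         // OK: the <: T bound makes NonZero[Int] usable as an Int
val doubled = two * 2        // Int's own members remain available
// val bad: NonZero[Int] = 2 // would not compile: outside `refinements`, an Int is not a NonZero[Int]
```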

Inlining our methods

You may have noticed a problem in our isValid implementation: it needs a Zero[T] instance to compare against, and there’s no implicit parameter for it in the inherited signature. The code won’t compile if that parameter is missing (no Zero instance found) and won’t compile if it’s added (as the parent’s abstract method would remain unimplemented). We don’t want to add it to that parent signature either, since that would tie all our refinement types to Zero.

So how can we resolve this conundrum? Turns out we can mark the override inline and use scala.compiletime.summonInline so that an instance of Zero is required at the call site even if it’s not in the method signature. (It uses the same implicit resolution rules as a regular summon, in this case resolving to our Numeric-derived instance.) So after propagating the inlining where required by the compiler, our code becomes:

import scala.compiletime.*

trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None
end Refinement

final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero[T]
end refinements

given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)
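The effect of summonInline can be isolated in a smaller sketch. isZero is a hypothetical helper, not part of the article's code: the Zero[T] instance it needs appears nowhere in its signature, yet each call site must be able to provide one once the call has been inlined.

```scala
import scala.compiletime.*

final case class Zero[T](e: T)
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

// No context bound: Zero[T] is looked up only after isZero is inlined
// at a call site where T is known
inline def isZero[T](x: T): Boolean = x == summonInline[Zero[T]].e

val intCheck = isZero(0)      // T = Int: uses the Numeric-derived instance
val doubleCheck = isZero(1.5) // T = Double
// isZero("zero")             // would not compile: no Zero[String] at this call site
```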

Testing our type

Now let’s test our NonZero type by using it as the divisor parameter of a "safe division" function:

import refinements.*

def quotient(a: Int, b: NonZero[Int]): Int = a / b

quotient(4, 0) // does not compile: 0 is an Int, not a NonZero[Int]

Since we are outside the refinements object, the call quotient(4, 0) won’t compile because 0 is an Int and we require a NonZero[Int].

Now let’s see if we can trick the method into accepting 0 as a NonZero value:

quotient(4, NonZero(0))

Here the apply call fails with an ExceptionInInitializerError caused by "IllegalArgumentException: invalid 0". (In Scastie the initialization error might not be obvious at first; after hitting Run, look for a red squiggle at the very bottom of the worksheet.) Note that in a real application the developer won’t normally call this exception-throwing apply() method, instead using option() for parsing and validation at the edges of the system, and then having the type properly carried and transformed through the operations we’ll define.

Now let’s call our method with a valid NonZero value:

quotient(4, NonZero(2))

This returns the expected Int value of 2.
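Following the advice to validate with option() at the edges of the system, here is a sketch of what such a boundary might look like. parseDivisor is a hypothetical function; the rest repeats the definitions so far (with zero[T] written explicitly) so that the block runs on its own as a worksheet.

```scala
import scala.compiletime.*

// Definitions from the article, repeated so this sketch is self-contained
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero[T]
end refinements

import refinements.*

def quotient(a: Int, b: NonZero[Int]): Int = a / b

// Hypothetical edge-of-system parser: raw text in, validated divisor out
def parseDivisor(input: String): Option[NonZero[Int]] =
  input.trim.toIntOption.flatMap(n => NonZero.option(n))

// Only inputs that survive parsing and validation ever reach quotient
val results = List("2", "0", "two").map(s => parseDivisor(s).map(d => quotient(4, d)))
```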

Adding more types

Now let’s create another refinement type in our refinements object, for non-negative (zero or greater) values:

  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero)

Here our isValid method requires typeclass instances of both our Zero and Scala’s own Ordering, so we use the summonInline trick once again and check that our value is greater than or equal to zero.

We can use this new refinement to give a more precise type to our zero function itself, if we move it inside the refinements object:

object refinements:
  inline def zero[T]: NonNegative[T] = summonInline[Zero[T]].e
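A sketch of NonNegative on its own, with zero moved inside refinements as just described, confirms that the zero of a type now carries the NonNegative refinement. The other definitions are repeated from above so the block is self-contained.

```scala
import scala.compiletime.*

trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  // zero lives here, where NonNegative[T] is still just T
  inline def zero[T]: NonNegative[T] = summonInline[Zero[T]].e

  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero[T])
end refinements

import refinements.*

val z: NonNegative[Int] = zero[Int] // statically known to be non-negative
val checks = List(NonNegative.option(3), NonNegative.option(0), NonNegative.option(-3))
```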

Finally, let’s create a refinement for positive (strictly greater than zero) values. This one is interesting: it’s the combination of non-zero and non-negative… could we represent this combination in a more expressive way than using an opaque type? Turns out Scala 3 does offer another way: intersection types!

  type Positive[T] = NonZero[T] & NonNegative[T]
  object Positive extends Refinement[Positive]:
    override protected def refined[T](x: T): Positive[T] = x
    override inline def isValid[T](x: T): Boolean =
      NonZero.isValid(x) && NonNegative.isValid(x)

Here our Positive type isn’t opaque, it is simply an intersection type of NonZero and NonNegative. We’re still able to return the T value as-is in our refined implementation because we’re inside the refinements object, where both NonZero[T] and NonNegative[T] are aliases of T and thus their intersection reduces to T. And just like with our bounded opaque types, we’ll be able to pass a Positive[T] everywhere a T, NonZero[T] or NonNegative[T] is required.
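To check that subtyping claim, here is a self-contained sketch with all three refinements. divide and halve are hypothetical consumers that each ask only for the weaker refinement they need, and both accept the same Positive[Int] value.

```scala
import scala.compiletime.*

// Definitions from the article, repeated so this sketch is self-contained
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero[T]

  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero[T])

  type Positive[T] = NonZero[T] & NonNegative[T]
  object Positive extends Refinement[Positive]:
    override protected def refined[T](x: T): Positive[T] = x
    override inline def isValid[T](x: T): Boolean =
      NonZero.isValid(x) && NonNegative.isValid(x)
end refinements

import refinements.*

// Hypothetical consumers, each requiring only the refinement it needs
def divide(a: Int, b: NonZero[Int]): Int = a / b
def halve(x: NonNegative[Int]): Int = x / 2

val four: Positive[Int] = Positive(4)
val q = divide(12, four) // accepted as a NonZero[Int]
val h = halve(four)      // accepted as a NonNegative[Int]
val plain: Int = four    // and as a plain Int
```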

Defining operations

Using opaque types that encapsulate and enforce validation goes a long way towards "making illegal states unrepresentable" as we quoted in the introduction. However, since our refinements are numeric-focused we’d like to do arithmetic with them without having to constantly re-wrap the results. Let’s see if we can define some operations that can achieve this.

Our first "refined" operation

Let’s start with a simple binary operation: addition.

What should be the return type of this function? Let’s see:

  • Adding a positive and a non-negative (or vice versa, since addition is commutative) should yield a positive.
  • Adding two non-negatives (that aren’t known to be positive) should yield a non-negative.
  • For other input types such as non-zero numbers, we can’t guarantee anything about the result; for example, adding the non-zero values 2 and -2 yields 0, which is not a NonZero.

So here’s our first draft:

trait Addition[T <: Matchable]:
  protected def sumOf(x: T, y: T): T
  extension(x: T)
    def ++(y: T): T = (x, y) match
      case (_:    Positive[T], _: NonNegative[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _:    Positive[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(sumOf(x, y))
      case _                                      => sumOf(x, y)

As with our Refinement trait, we defined a protected "unsafe" method for the user to implement in terms of plain T, and wrapped it inside a public (extension) infix operator method named ++ (more on the choice of name below). That method implements our refinement logic using a pattern match, taking care to handle specific cases before falling back to the more general ones. Following modern Scala 3 best practices, we bound the type parameter T by Matchable, which most value and reference types satisfy, so that its values can safely be pattern-matched.

This gives us the correct runtime type for our return value, but it doesn’t change its compile-time type. For example, let’s declare an Addition[Int] instance and then try to feed its result to our quotient function:

given Addition[Int] with
  override protected def sumOf(x: Int, y: Int): Int = x + y

val x = Positive(1)
val y = Positive(2)
val q = quotient(9, x ++ y)

We find that this fails because x ++ y is of plain type Int instead of the NonZero[Int] expected by quotient.

So what do we do? More inlining! This time we’ll use the transparent inline modifier along with an inline match, so that the inlined code of ++ has the specific return type of the selected branch.

trait Addition[T <: Matchable]:
  protected def sumOf(x: T, y: T): T
  extension(x: T)
    transparent inline def ++(y: T): T = inline (x, y) match
      case (_:    Positive[T], _: NonNegative[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _:    Positive[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(sumOf(x, y))
      case _                                      => sumOf(x, y)

With this syntax, we can retry the test code above and see that it now compiles, with q evaluating to 3.
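The mechanism is easier to see outside the refinements. The hypothetical twice function below uses the same transparent inline plus inline match combination: the case is selected from the static type of the argument, and the call site sees the result type of that case rather than the declared Any.

```scala
// The case is chosen at compile time from the static type of x;
// `transparent` exposes the selected branch's type to the caller
transparent inline def twice(x: Any): Any = inline x match
  case i: Int    => i * 2
  case s: String => s + s

val doubledNumber = twice(21)    // static type Int
val doubledText = twice("ab")    // static type String
val len: Int = doubledText.length // compiles only because doubledText is statically a String
// twice(1.5)                    // would not compile: no case matches a Double
```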

Choosing our operator symbols

Note that we’re defining the addition operator using two plus symbols (++). Why not override the existing + operator? We’d love to, but Scala’s built-in operators take precedence for primitive types, even when those are aliased as (locally non-transparent) opaque types. So a refinement of Int such as Positive[Int] would use the built-in + operator instead of our override, and return a plain Int in all cases. The override would only be used in contexts where the refinement wraps a generic T type not known at compile time, but the distinction between these two scenarios would easily be lost on the user, causing confusion.

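A follow-up question could be: why choose ++ instead of something else, like the |+| used for semigroups in Cats? We do so in order to benefit from Scala’s existing operator precedence rules, which are based on the first character of the operator. Since ++ starts with +, it has the same precedence as +; likewise, the ** operator we’ll define for multiplication has the same precedence as *. An operator starting with |, such as |+|, would instead have the lowest precedence of all symbolic operators, lower even than comparisons such as ==.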
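The precedence rule can be observed directly with a few throwaway operators. These extension methods on plain Int are hypothetical, defined only to show how each operator's first character determines how an expression is grouped.

```scala
// Hypothetical operators on plain Int, defined only to observe precedence
extension (x: Int)
  def ++(y: Int): Int = x + y  // starts with '+': same level as +
  def **(y: Int): Int = x * y  // starts with '*': same level as *
  def |+|(y: Int): Int = x + y // starts with '|': lowest symbolic level

val mixed = 1 ++ 2 ** 3        // grouped as 1 ++ (2 ** 3)
val compared = 1 ++ 2 == 3     // grouped as (1 ++ 2) == 3
// val broken = 1 |+| 2 == 3   // would not compile: grouped as 1 |+| (2 == 3)
val bracketed = (1 |+| 2) == 3 // parentheses are required with |+|
```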

Adding more operations

Let’s follow this up with multiplication, which follows the same pattern as addition but with its own refinement rules:

trait Multiplication[T <: Matchable]:
  protected def productOf(x: T, y: T): T
  extension(x: T)
    transparent inline def **(y: T): T = inline (x, y) match
      case (_:    Positive[T], _:    Positive[T]) => Positive(productOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(productOf(x, y))
      case (_:     NonZero[T], _:     NonZero[T]) => NonZero(productOf(x, y))
      case _                                      => productOf(x, y)

In this typeclass, let’s also add a squared operation that leverages the assumption that all squares are non-negative (i.e. we’re not dealing with complex numbers here):

    transparent inline def squared: NonNegative[T] = inline x match
      case _: Positive[T] => Positive(productOf(x, x))
      case _              => NonNegative(productOf(x, x))
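Here is a self-contained sketch of the multiplication rules in action. To keep it short it declares an explicit Multiplication[Int] instance in the same style as the earlier Addition[Int] example, rather than the generic Numeric-based given shown further below; the type ascriptions on the left only compile if the refinements are carried through statically.

```scala
import scala.compiletime.*

// Definitions from the article, repeated so this sketch is self-contained
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero[T]

  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero[T])

  type Positive[T] = NonZero[T] & NonNegative[T]
  object Positive extends Refinement[Positive]:
    override protected def refined[T](x: T): Positive[T] = x
    override inline def isValid[T](x: T): Boolean =
      NonZero.isValid(x) && NonNegative.isValid(x)
end refinements

import refinements.*

trait Multiplication[T <: Matchable]:
  protected def productOf(x: T, y: T): T
  extension (x: T)
    transparent inline def **(y: T): T = inline (x, y) match
      case (_:    Positive[T], _:    Positive[T]) => Positive(productOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(productOf(x, y))
      case (_:     NonZero[T], _:     NonZero[T]) => NonZero(productOf(x, y))
      case _                                      => productOf(x, y)
    transparent inline def squared: NonNegative[T] = inline x match
      case _: Positive[T] => Positive(productOf(x, x))
      case _              => NonNegative(productOf(x, x))

given Multiplication[Int] with
  override protected def productOf(x: Int, y: Int): Int = x * y

val area: Positive[Int] = Positive(3) ** Positive(4)  // positive times positive
val product: NonZero[Int] = NonZero(-3) ** NonZero(5) // sign unknown, but never zero
val square: NonNegative[Int] = NonZero(-3).squared    // squares are never negative
val plain: Int = (-2) ** 3                            // nothing known: plain Int
```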

And speaking of squares, let’s throw in a square root operation:

trait SquareRoot[T <: Matchable]:
  protected def squareRootOf(x: T): T
  extension(x: NonNegative[T])
    transparent inline def squareRoot: NonNegative[T] = inline x match
      case _: Positive[T] => Positive(squareRootOf(x))
      case _              => NonNegative(squareRootOf(x))

Note that this operation returns the principal (non-negative) square root and thus uses the NonNegative refinement in both its input and output.
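A self-contained sketch of squareRoot on Double follows, again using an explicit given in place of the SAM shorthand shown next. Note that the receiver must already be known to be non-negative: a plain negative literal is rejected at compile time rather than producing NaN at runtime.

```scala
import scala.compiletime.*

// Definitions from the article, repeated so this sketch is self-contained
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero[T]

  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero[T])

  type Positive[T] = NonZero[T] & NonNegative[T]
  object Positive extends Refinement[Positive]:
    override protected def refined[T](x: T): Positive[T] = x
    override inline def isValid[T](x: T): Boolean =
      NonZero.isValid(x) && NonNegative.isValid(x)
end refinements

import refinements.*

trait SquareRoot[T <: Matchable]:
  protected def squareRootOf(x: T): T
  extension (x: NonNegative[T])
    transparent inline def squareRoot: NonNegative[T] = inline x match
      case _: Positive[T] => Positive(squareRootOf(x))
      case _              => NonNegative(squareRootOf(x))

given SquareRoot[Double] with
  override protected def squareRootOf(x: Double): Double = math.sqrt(x)

val three: Positive[Double] = Positive(9.0).squareRoot       // positive in, positive out
val root0: NonNegative[Double] = NonNegative(0.0).squareRoot // non-negative in, non-negative out
// (-4.0).squareRoot // would not compile: -4.0 is not known to be a NonNegative[Double]
```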

Providing given instances

As we did with our Zero typeclass, we can leverage Scala’s existing Numeric type to define generic typeclass instances for the common primitive types. We’ll use the single abstract method (SAM) shorthand notation:

given [T <: Matchable: Numeric]: Addition[T] = Numeric[T].plus(_, _)

given [T <: Matchable: Numeric]: Multiplication[T] = Numeric[T].times(_, _)

given SquareRoot[Double] = math.sqrt(_)

As we can see, our SquareRoot given is Double-specific, since there’s no such operation for integral types like Int. (Scala does provide a Fractional typeclass that defines a division operation, but not a square root.)

Putting it all together

Let’s see how we can use our types and operations to define and use an operation that returns the "length" of a value in the sense of its "distance from zero". This means the absolute value for scalars, the Euclidean length for vectors etc. To avoid confusion with other notions of length such as the number of elements in a collection, we’ll call it Norm after the mathematical term.

trait Norm[T <: Matchable, N]:
  protected def normOf(x: T): N
  extension(x: T)
    transparent inline def norm: NonNegative[N] = inline x match
      case _: NonZero[T] => Positive(normOf(x))
      case _             => NonNegative(normOf(x))

Note that this is our first operation with two different types, T for the value and N for the norm. But the basic principle is the same: our refinements here simply express that a value known (at compile time) to be non-zero has a positive "length", whereas for a value not known to be non-zero we can only promise a non-negative one.

Here’s our absolute-value instance for scalar Numeric values, where the value and the norm have the same type:

given [T <: Matchable: Numeric]: Norm[T, T] = Numeric[T].abs(_)
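Here is a scalar sketch of the Norm refinement, using an explicit absolute-value instance for Int rather than the generic Numeric-based given above. The comment on Int.MinValue points out a runtime caveat of any abs-based norm on fixed-width integers: its absolute value overflows back to a negative number, so the NonNegative wrapper rejects it.

```scala
import scala.compiletime.*

// Definitions from the article, repeated so this sketch is self-contained
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero[T]

  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero[T])

  type Positive[T] = NonZero[T] & NonNegative[T]
  object Positive extends Refinement[Positive]:
    override protected def refined[T](x: T): Positive[T] = x
    override inline def isValid[T](x: T): Boolean =
      NonZero.isValid(x) && NonNegative.isValid(x)
end refinements

import refinements.*

trait Norm[T <: Matchable, N]:
  protected def normOf(x: T): N
  extension (x: T)
    transparent inline def norm: NonNegative[N] = inline x match
      case _: NonZero[T] => Positive(normOf(x))
      case _             => NonNegative(normOf(x))

// Explicit absolute-value instance for Int
given Norm[Int, Int] with
  override protected def normOf(x: Int): Int = math.abs(x)

val five: Positive[Int] = NonZero(-5).norm // known non-zero: positive norm
val unknown: NonNegative[Int] = (-5).norm  // not known non-zero: only non-negative
// Caveat: math.abs(Int.MinValue) is still negative, so Int.MinValue.norm throws at runtime
```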

For the vector instances, we’ll use plain tuples of T to represent the vectors, of 2 and 3 dimensions in this example. First let’s implement their Zero instances, by simply deriving them from the Zero instance of their component type:

given [T: Zero]: Zero[(T, T)] = Zero((zero[T], zero[T]))
given [T: Zero]: Zero[(T, T, T)] = Zero((zero[T], zero[T], zero[T]))

Now let’s implement their Norm instances by leveraging our existing refinement types and operations:

given [T <: Matchable: Addition: Multiplication: SquareRoot: Zero: Ordering]: Norm[(T, T), T] =
  v => (v._1.squared ++ v._2.squared).squareRoot

given [T <: Matchable: Addition: Multiplication: SquareRoot: Zero: Ordering]: Norm[(T, T, T), T] =
  v => (v._1.squared ++ v._2.squared ++ v._3.squared).squareRoot

The composite context bounds of these givens, though they may look a bit daunting at first, exemplify the "principle of least power" by requiring exactly the operations needed to perform the task, no more and no less. (It’s a bit like requiring the "minimum" of a Functor, Applicative, Monad etc. in a library like Cats.) The Zero and Ordering bounds are a bit less obvious, but remember that our operations use them internally to validate their results when wrapping them in refinements.

As for the implementations, we’re doing the standard "square root of the sum of squares" to find the Euclidean norm. Note that we didn’t need to wrap the sum of squares in NonNegative(...) for squareRoot to accept it: that’s because squared always returns a NonNegative, and our transparent-inline ++ operator "knows" that adding two non-negatives also yields a non-negative.
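For reference, here are the two norms in standard notation, evaluated for the vector (1.0, 2.0) used in the test that follows. Each squared term is a NonNegative, so is their sum, and so is its principal square root, which is exactly the chain of refinements the givens rely on.

```latex
\lVert (x_1, x_2) \rVert = \sqrt{x_1^2 + x_2^2}
\qquad
\lVert (x_1, x_2, x_3) \rVert = \sqrt{x_1^2 + x_2^2 + x_3^2}

\lVert (1, 2) \rVert = \sqrt{1^2 + 2^2} = \sqrt{5} \approx 2.236
```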

Remember from the Norm trait that we’re also refining our given’s result to Positive if our value is known to be NonZero. Let’s test this:

val v1 = (1.0, 2.0)
val v1Norm: NonNegative[Double] = v1.norm

val v2 = NonZero((1.0, 2.0))
val v2Norm: Positive[Double] = v2.norm

And some counter-examples:

val v3 = NonZero((1, 2))
val v3Norm: Positive[Int] = v3.norm // fails to compile: no SquareRoot[Int] instance

val v4 = NonZero((0.0, 0.0)) // fails at runtime in the NonZero call
val v4Norm: Positive[Double] = v4.norm // never reached

This concludes our initial foray into generic refinements! You can find the completed Scastie worksheet here.

Conclusion

In this article we’ve only scratched the surface of the possibilities of generic refined types. Libraries exist that provide a full-fledged refinement framework, such as the classic refined library and the newer alternative iron, both of which I encourage you to take a look at. But even a simple homegrown setup as we’ve defined here can provide a basis for interesting developments. Hopefully this can help you imagine how you could leverage opaque refinement types (generic or otherwise) in your real-world projects. And I’d love to hear about it – don’t hesitate to contact me with questions or comments!
