Generic Refinement Types in Scala 3
A major goal of strong typing in any programming language is to "make illegal states unrepresentable", as remarkably put in Alexis King’s article "Parse, don’t validate". To this end, refinement types have become an increasingly popular way of constraining existing types, for example by restricting an integer type to only natural numbers.
What can refinement types do for us in Scala? As it turns out, a lot! Just looking at my own personal projects, I’ve been experimenting with:
- A small hierarchy of generic numeric types (NonZero, NonNegative and Positive) and a set of arithmetic operations that use them, such as "safe" division, square root and calculation of vector lengths.
- A custom Seq that uses NonNegative[Int] as its index and count types, with a NonEmptySeq subclass that uses a Positive[Int] count type.
- A Real refinement of Double that excludes values that are not real numbers (i.e. positive/negative infinity and NaN), along with a set of "safe" trigonometry functions with no corner cases for the user to worry about.
- Distinct types for position and motion vectors in a game world, with operations that enforce how they can combine (for example we can add two motions, or shift a position by a motion, but not add two positions by mistake).
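To make the last item of this list more concrete, here is a minimal sketch of how distinct position and motion types could be modelled with plain (non-generic) opaque type aliases. The names geometry, Position, Motion, shiftedBy, plus and coordinates are illustrative assumptions, not code from my projects.

```scala
object geometry:
  // Both types are plain pairs of Doubles at runtime (hypothetical names).
  opaque type Position = (Double, Double)
  opaque type Motion = (Double, Double)

  object Position:
    def apply(x: Double, y: Double): Position = (x, y)

  object Motion:
    def apply(dx: Double, dy: Double): Motion = (dx, dy)

  extension (p: Position)
    // Shifting a position by a motion yields a new position.
    def shiftedBy(m: Motion): Position = (p._1 + m._1, p._2 + m._2)
    def coordinates: (Double, Double) = p

  extension (m: Motion)
    // Adding two motions yields a motion.
    def plus(n: Motion): Motion = (m._1 + n._1, m._2 + n._2)

import geometry.*

val start = Position(1.0, 2.0)
val step = Motion(0.25, 0.25)
val moved = start.shiftedBy(step.plus(step))
// start.plus(start) would not compile: plus is only defined on Motion
```

Outside the geometry object the two aliases are unrelated types, so the compiler rejects any attempt to add two positions.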
In this article, we’re going to implement the first of these examples, using Scala 3’s opaque type aliases, as well as inlining and other compile-time operations, to build our numeric types and type-safe operations. We’ll make these opaque types generic and higher-kinded so that they can constrain a variety of base types.
Our focus will be on maximizing the type safety of non-constant values as they move through our various operations. But refinement types are also great for handling compile-time constants; a great tutorial on this topic and more is "Crafting Concise Constructors with Opaque Types in Scala 3" by my Xebia colleague David A. Gil Méndez.
Using the Scastie online editor
In the following code samples, I’ll be using Scala 3.5.2 and the Scastie online editor in Worksheet mode. This mode (the default) allows for top-level declarations as if we were coding inside a main method, and shows inlay type hints for all declarations, similarly to how the Scala REPL prints out types after each line. I’ll link to the completed worksheet after the last code sample, just before the conclusion section.
Note that since package declarations aren’t allowed in a Scastie worksheet, when working with opaque type aliases I’ll use a Scala object to scope their transparency. In a normal project we would use a package instead, or even just top-level declarations in a separate source file.
Laying the groundwork
Let’s get started! Since we’ll be creating several generic refinement types, let’s define a convenience base trait for their companion objects:
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  def isValid[T](x: T): Boolean
  def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None
In this trait we declared a higher-kinded R[_] type, an isValid method to check that a value of T conforms to our refinement, and a protected refined method to turn a validated value of T into a refined R[T]. We then wrapped these into public apply and option methods to ensure that all values of R[T] are valid.
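As a quick illustration of how apply and option behave, here is a sketch with a deliberately simple refinement that needs no typeclass at all. The NonNull type and the nullSafety object are hypothetical and only serve this example; the trait is repeated so that the snippet is self-contained.

```scala
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  def isValid[T](x: T): Boolean
  def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

object nullSafety:
  // Hypothetical refinement: any value except null.
  opaque type NonNull[T] <: T = T
  object NonNull extends Refinement[NonNull]:
    override protected def refined[T](x: T): NonNull[T] = x
    override def isValid[T](x: T): Boolean = x != null

import nullSafety.*

// option reports invalid input as None...
val present: Option[NonNull[String]] = NonNull.option("hello")
val absent: Option[NonNull[String]] = NonNull.option[String](null)
// ...while apply throws an IllegalArgumentException.
val rejected: Boolean = scala.util.Try(NonNull[String](null)).isFailure
```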
Our first refinement type
Let’s define our first refinement type, for non-zero values. Since our types are generic, we must first define a typeclass to hold the "zero" element of T to validate against:
final case class Zero[T](e: T)
def zero[T: Zero]: T = summon[Zero[T]].e
Let’s declare a generic given instance of Zero for Scala’s Numeric types:
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)
This will give us instances for common types such as Int, Double etc.
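The following sketch shows the typeclass in use. The Zero[String] instance is an assumption added purely to show that types without a Numeric instance can opt in by hand; it is not needed by the rest of the article.

```scala
final case class Zero[T](e: T)
def zero[T: Zero]: T = summon[Zero[T]].e

// Generic instance derived from Numeric, as in the article.
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

// Hypothetical hand-written instance for a non-numeric type.
given Zero[String] = Zero("")

val intZero: Int = zero[Int]
val doubleZero: Double = zero[Double]
val bigZero: BigDecimal = zero[BigDecimal]
val emptyText: String = zero[String]
```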
Defining our type
With that done, here is our first draft of the NonZero type:
object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override def isValid[T](x: T): Boolean = x != zero
Here we used a (parameterized) opaque type alias that is internally represented by a value of T at runtime, without wrapping or boxing. This identity is transparent inside the refinements object, which allows us to simply return x itself in our refined implementation. Outside the refinements scope, NonZero[T] is viewed as a distinct type; but since it has T as a type bound, we’ll be able to use its values everywhere a T is expected.
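The effect of the upper bound can be seen in isolation with a small non-generic sketch. The Percent type and the remaining function are hypothetical names chosen for this illustration.

```scala
object percentages:
  // Opaque alias with an upper bound: a Percent is usable as an Int.
  opaque type Percent <: Int = Int
  object Percent:
    // Inside this object Percent and Int are interchangeable.
    def option(n: Int): Option[Percent] =
      if n >= 0 && n <= 100 then Some(n) else None

import percentages.*

// Thanks to the bound, Int arithmetic works directly on a Percent.
def remaining(used: Percent): Int = 100 - used

val left: Option[Int] = Percent.option(30).map(remaining)
val tooMuch: Option[Percent] = Percent.option(130)
// val bad: Percent = 30   // would not compile outside percentages
```

Without the bound, remaining would need an explicit accessor to get back to the underlying Int.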
Inlining our methods
You may have noticed a problem in our isValid implementation: it needs an implicit parameter for the Zero instance, and there’s no such parameter in the inherited signature. The code won’t compile if it’s missing (no Zero instance found) and won’t compile if it’s added (as the parent’s abstract method would remain unimplemented). We don’t want to add it to that parent signature either, since that would tie all our refinement types to Zero.
So how can we resolve this conundrum? Turns out we can mark the override inline and use scala.compiletime.summonInline so that an instance of Zero is required at the call site even if it’s not in the method signature. (It uses the same implicit resolution rules as a regular summon, in this case resolving to our Numeric-derived instance.) So after propagating the inlining where required by the compiler, our code becomes:
import scala.compiletime.*
trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def apply[T](x: T): R[T] =
    require(isValid(x), s"invalid $x")
    refined(x)
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None
end Refinement
final case class Zero[T](e: T)
inline def zero[T]: T = summonInline[Zero[T]].e
object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    override inline def isValid[T](x: T): Boolean = x != zero
end refinements
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)
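The summonInline mechanism used above can also be seen in isolation. In this minimal sketch, the hypothetical largest function has no Ordering in its signature; the instance is only looked up where the call is inlined, once T is known.

```scala
import scala.compiletime.summonInline

// No Ordering[T] parameter here: the instance is resolved at each call site.
inline def largest[T](x: T, y: T): T =
  val ord = summonInline[Ordering[T]]
  if ord.gteq(x, y) then x else y

val biggestInt = largest(3, 7)             // resolves Ordering[Int]
val biggestText = largest("pear", "apple") // resolves Ordering[String]
```

Calling largest with a type that has no Ordering in scope is reported as a compile error at that call, not at the definition.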
Testing our type
Now let’s test our NonZero type by using it as the divisor parameter of a "safe division" function:
import refinements.*
def quotient(a: Int, b: NonZero[Int]): Int = a / b
quotient(4, 0)
Since we are outside the refinements object, the call quotient(4, 0) won’t compile because 0 is an Int and we require a NonZero[Int].
Now let’s see if we can trick the method into accepting 0 as a NonZero value:
quotient(4, NonZero(0))
Here the apply call fails with an ExceptionInInitializerError caused by "IllegalArgumentException: invalid 0". (In Scastie the initialization error might not be obvious at first; after hitting Run, look for a red squiggle at the very bottom of the worksheet.) Note that in a real application the developer won’t normally call this exception-throwing apply() method, instead using option() for parsing/validation at the edges of the system and then having the type properly carried and transformed through the operations we’ll define.
Now let’s call our method with a valid NonZero value:
quotient(4, NonZero(2))
This returns the expected Int value of 2.
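To illustrate the remark about using option() at the edges of the system, here is a sketch of a wrapper that validates a raw divisor before handing it to quotient. The quotientOf name is an assumption made for this example. The definitions are repeated in trimmed form so that the snippet stands alone, and isValid calls summonInline directly instead of going through the zero helper.

```scala
import scala.compiletime.summonInline

trait Refinement[R[_]]:
  protected def refined[T](x: T): R[T]
  inline def isValid[T](x: T): Boolean
  inline def option[T](x: T): Option[R[T]] =
    if isValid(x) then Some(refined(x)) else None

final case class Zero[T](e: T)
given [T: Numeric]: Zero[T] = Zero(Numeric[T].zero)

object refinements:
  opaque type NonZero[T] <: T = T
  object NonZero extends Refinement[NonZero]:
    override protected def refined[T](x: T): NonZero[T] = x
    // The Zero instance is summoned wherever this method gets inlined.
    override inline def isValid[T](x: T): Boolean =
      x != summonInline[Zero[T]].e

import refinements.*

def quotient(a: Int, b: NonZero[Int]): Int = a / b

// Hypothetical edge function: validate once, then stay in refined types.
def quotientOf(a: Int, rawDivisor: Int): Option[Int] =
  NonZero.option(rawDivisor).map(quotient(a, _))

val ok = quotientOf(4, 2)
val rejected = quotientOf(4, 0)
```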
Adding more types
Now let’s create another refinement type in our refinements object, for non-negative (zero or greater) values:
  opaque type NonNegative[T] <: T = T
  object NonNegative extends Refinement[NonNegative]:
    override protected def refined[T](x: T): NonNegative[T] = x
    override inline def isValid[T](x: T): Boolean =
      summonInline[Ordering[T]].gteq(x, zero)
Here our isValid method requires typeclass instances of both our Zero and Scala’s own Ordering, so we use the summonInline trick once again and check that our value is greater than or equal to zero.
We can use this new refinement to give a more precise type to our zero function itself, if we move it inside the refinements object:
object refinements:
  inline def zero[T]: NonNegative[T] = summonInline[Zero[T]].e
Finally, let’s create a refinement for positive (strictly greater than zero) values. This one is interesting: it’s the combination of non-zero and non-negative… could we represent this combination in a more expressive way than using an opaque type? Turns out Scala 3 does offer another way: intersection types!
  type Positive[T] = NonZero[T] & NonNegative[T]
  object Positive extends Refinement[Positive]:
    override protected def refined[T](x: T): Positive[T] = x
    override inline def isValid[T](x: T): Boolean =
      NonZero.isValid(x) && NonNegative.isValid(x)
Here our Positive type isn’t opaque, it is simply an intersection type of NonZero and NonNegative. We’re still able to return the T value as-is in our refined implementation because we’re inside the refinements object, where both NonZero[T] and NonNegative[T] are aliases of T and thus their intersection reduces to T. And just like with our bounded opaque types, we’ll be able to pass a Positive[T] everywhere a T, NonZero[T] or NonNegative[T] is required.
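The substitutability of the intersection type can be sketched with a simplified, Int-only variant that leaves out the Refinement trait and the inlining. The intRefinements object and the reciprocal and label functions are hypothetical names for this illustration only.

```scala
object intRefinements:
  // Simplified, non-generic stand-ins for the article's types.
  opaque type NonZero <: Int = Int
  opaque type NonNegative <: Int = Int
  type Positive = NonZero & NonNegative

  object Positive:
    // Inside this object both aliases are Int, so x conforms to Positive.
    def option(x: Int): Option[Positive] =
      if x > 0 then Some(x) else None

import intRefinements.*

def reciprocal(d: NonZero): Double =
  val n: Int = d // allowed by the upper bound
  1.0 / n

def label(n: NonNegative): String = s"count=$n"

// A Positive is accepted as a NonZero, a NonNegative and a plain Int.
val results = Positive.option(4).map(p => (reciprocal(p), label(p), p + 1))
```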
Defining operations
Using opaque types that encapsulate and enforce validation goes a long way towards "making illegal states unrepresentable" as we quoted in the introduction. However, since our refinements are numeric-focused we’d like to do arithmetic with them without having to constantly re-wrap the results. Let’s see if we can define some operations that can achieve this.
Our first "refined" operation
Let’s start with a simple binary operation: addition.
What should be the return type of this function? Let’s see:
- Adding a positive and a non-negative (or vice-versa, commutatively) should yield a positive;
- Adding two non-negatives (that aren’t known to be positive) should yield a non-negative.
- For other input types such as non-zero numbers, we can’t guarantee anything about the result; for example adding 2 and -2 yields 0 which is outside of our refinements.
So here’s our first draft:
trait Addition[T <: Matchable]:
  protected def sumOf(x: T, y: T): T
  extension(x: T)
    def ++(y: T): T = (x, y) match
      case (_: Positive[T], _: NonNegative[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _: Positive[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(sumOf(x, y))
      case _ => sumOf(x, y)
As with our Refinement trait, we defined a protected "unsafe" method for the user to implement in terms of plain T, and wrapped it inside a public (extension) infix operator method named ++ (more on the choice of name below). That method implements our refinement logic using a pattern match, taking care to handle specific cases before falling back to the more general ones. Following modern Scala 3 best practices, we require the T-typed argument to be Matchable, which most value and reference types will satisfy.
This gives us the correct runtime type for our return value, but it doesn’t change its compile-time type. For example, let’s declare an Addition[Int] instance and then try to feed its result to our quotient function:
given Addition[Int] with
  override protected def sumOf(x: Int, y: Int): Int = x + y
val x = Positive(1)
val y = Positive(2)
val q = quotient(9, x ++ y)
We find that this fails because x ++ y is of plain type Int instead of the NonZero[Int] expected by quotient.
So what do we do? More inlining! This time we’ll use the transparent inline modifier along with an inline match, so that the inlined code of ++ has the specific return type of the selected branch.
trait Addition[T <: Matchable]:
  protected def sumOf(x: T, y: T): T
  extension(x: T)
    transparent inline def ++(y: T): T = inline (x, y) match
      case (_: Positive[T], _: NonNegative[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _: Positive[T]) => Positive(sumOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(sumOf(x, y))
      case _ => sumOf(x, y)
With this syntax, we can retry the test code above and see that it works.
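The way transparent inline narrows a declared return type can be seen in isolation with a much smaller sketch. The lengthOrText function below is a hypothetical example unrelated to our refinements: its declared result is Any, yet each call site gets the type of the branch that was selected at compile time.

```scala
// Declared as Any, but each call is typed by the branch selected at compile time.
transparent inline def lengthOrText(x: Any): Any = inline x match
  case s: String => s.length
  case d: Double => d.toString

val len: Int = lengthOrText("hello") // the String branch yields an Int
val text: String = lengthOrText(2.5) // the Double branch yields a String
```

Without the transparent modifier, both val declarations would be rejected because the call would keep its declared type Any.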
Choosing our operator symbols
Note that we’re defining the addition operator using two plus symbols (++). Why not override the existing + operator? We’d love to, but unfortunately Scala’s built-in operators take precedence for primitive types, even when those are aliased as (locally non-transparent) opaque types. So a refinement of Int such as Positive[Int] will use the built-in + operator instead of our override, and return a plain Int in all cases. The override would be used in contexts where the refinement wraps a generic T type not known at compile time, but the distinction between these two scenarios would be easily lost on the user, causing confusion.
A follow-up question could be: why choose ++ instead of something else like the |+| used for semigroups in Cats? We do so in order to benefit from Scala’s existing operator precedence rules, which are based on the first character of the operator.
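Here is a small sketch of that precedence rule, using plain Int extension methods rather than our refinement-aware operators (the ** operator is included only for the purpose of this illustration).

```scala
// Plain extension operators on Int, only to observe how expressions parse.
extension (x: Int)
  def ++(y: Int): Int = x + y
  def **(y: Int): Int = x * y

// Operators starting with * bind tighter than those starting with +,
// so this parses as 2 ++ (3 ** 4).
val mixed = 2 ++ 3 ** 4

// Operators starting with + bind tighter than comparisons,
// so this parses as (1 ++ 2) < 4.
val compared = 1 ++ 2 < 4
```

An operator starting with | would sit below the comparison operators, so an expression like a |+| b < c would parse as a |+| (b < c) and need explicit parentheses.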
Adding more operations
Let’s follow this up with multiplication, which follows the same pattern as addition but with its own refinement rules:
trait Multiplication[T <: Matchable]:
  protected def productOf(x: T, y: T): T
  extension(x: T)
    transparent inline def **(y: T): T = inline (x, y) match
      case (_: Positive[T], _: Positive[T]) => Positive(productOf(x, y))
      case (_: NonNegative[T], _: NonNegative[T]) => NonNegative(productOf(x, y))
      case (_: NonZero[T], _: NonZero[T]) => NonZero(productOf(x, y))
      case _ => productOf(x, y)
In this typeclass, let’s also add a squared operation that leverages the assumption that all squares are non-negative (i.e. we’re not dealing with complex numbers here):
    transparent inline def squared: NonNegative[T] = inline x match
      case _: Positive[T] => Positive(productOf(x, x))
      case _ => NonNegative(productOf(x, x))
And speaking of squares, let’s throw in a square root operation:
trait SquareRoot[T <: Matchable]:
  protected def squareRootOf(x: T): T
  extension(x: NonNegative[T])
    transparent inline def squareRoot: NonNegative[T] = inline x match
      case _: Positive[T] => Positive(squareRootOf(x))
      case _ => NonNegative(squareRootOf(x))
Note that this operation returns the principal (non-negative) square root and thus uses the NonNegative refinement in both its input and output.
Providing given instances
As we did with our Zero typeclass, we can leverage Scala’s existing Numeric type to define generic typeclass instances for the common primitive types. We’ll use the single abstract method (SAM) shorthand notation:
given [T <: Matchable: Numeric]: Addition[T] = Numeric[T].plus(_, _)
given [T <: Matchable: Numeric]: Multiplication[T] = Numeric[T].times(_, _)
given SquareRoot[Double] = math.sqrt(_)
As we can see, our SquareRoot given is Double-specific, since there’s no such operation for integral types like Int. (Scala does provide a Fractional typeclass that defines a division operation, but not a square root.)
Putting it all together
Let’s see how we can use our types and operations to define and use an operation that returns the "length" of a value in the sense of its "distance from zero". This means the absolute value for scalars, the Euclidean length for vectors etc. To avoid confusion with other notions of length such as the number of elements in a collection, we’ll call it Norm after the mathematical term.
trait Norm[T <: Matchable, N]:
  protected def normOf(x: T): N
  extension(x: T)
    transparent inline def norm: NonNegative[N] = inline x match
      case _: NonZero[T] => Positive(normOf(x))
      case _ => NonNegative(normOf(x))
Note that this is our first operation with two different types, T for the value and N for the norm. But the basic principle is the same, and our refinements here simply express that a value known (at compile time) to be non-zero will have a positive "length"; otherwise, if it’s only possibly non-zero, we’ll return a non-negative instead.
Here’s our absolute-value instance for scalar Numeric values, where the value and the norm have the same type:
given [T <: Matchable: Numeric]: Norm[T, T] = Numeric[T].abs(_)
For the vector instances, we’ll use plain tuples of T to represent the vectors, of 2 and 3 dimensions in this example. First let’s implement their Zero instances, by simply deriving them from their components’:
given [T: Zero]: Zero[(T, T)] = Zero(zero, zero)
given [T: Zero]: Zero[(T, T, T)] = Zero(zero, zero, zero)
Now let’s implement their Norm instances by leveraging our existing refinement types and operations:
given [T <: Matchable: Addition: Multiplication: SquareRoot: Zero: Ordering]: Norm[(T, T), T] =
  v => (v._1.squared ++ v._2.squared).squareRoot
given [T <: Matchable: Addition: Multiplication: SquareRoot: Zero: Ordering]: Norm[(T, T, T), T] =
  v => (v._1.squared ++ v._2.squared ++ v._3.squared).squareRoot
Our givens’ composite context bound, though it may look a bit daunting at first, exemplifies the "principle of least power" by requiring exactly the operations it will need to perform its task, no more and no less. (It’s a bit like requiring the "minimum" of a Functor, Applicative, Monad etc. in a library like Cats.) The Zero and Ordering bounds are a bit less obvious, but remember that our operations use them internally to validate their results when wrapping them in refinements.
As for the implementations, we’re doing the standard "square root of the sum of squares" to find the Euclidean norm. Note that we didn’t need to pass the sum-of-squares through the NonNegative wrapper so that squareRoot will accept it: that’s because squared always returns a NonNegative, and our transparent-inline ++ operator "knows" that adding two non-negatives will also yield a non-negative.
Remember from the Norm trait that we’re also refining our given’s result to Positive if our value is known to be NonZero. Let’s test this:
val v1 = (1.0, 2.0)
val v1Norm: NonNegative[Double] = v1.norm
val v2 = NonZero((1.0, 2.0))
val v2Norm: Positive[Double] = v2.norm
And some counter-examples:
val v3 = NonZero((1, 2))
val v3Norm: Positive[Int] = v3.norm // fails to compile: no SquareRoot[Int] instance
val v4 = NonZero((0.0, 0.0)) // fails at runtime in the NonZero call
val v4Norm: Positive[Double] = v4.norm // never reached
This concludes our initial foray into generic refinements! You can find the completed Scastie worksheet here.
Conclusion
In this article we’ve only scratched the surface of the possibilities of generic refined types. Libraries exist that provide a full-fledged refinement framework, such as the classic refined library and the newer alternative iron, both of which I encourage you to take a look at. But even a simple homegrown setup as we’ve defined here can provide a basis for interesting developments. Hopefully this can help you imagine how you could leverage opaque refinement types (generic or otherwise) in your real-world projects. And I’d love to hear about it – don’t hesitate to contact me with questions or comments!