Today I came across a nice problem in a piece of code that one of our colleagues sent around, asking for improvements. The problem consists of grouping a large List of lines according to some characteristics. Once the List grew large, the function that took care of the grouping needed over 2 minutes to group the dataset.
The problem
Let’s first give a fictitious example of the domain and function that was wreaking havoc:
[code lang="scala"]
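case class Container(lines: List[Contained])
case class Contained(key: String, value: String)

// Keyed lines start a new Container; keyless lines are added to the last one.
def group(ts: List[_]): List[Container] = ts.foldLeft(List[Container]()) {
  case (l, r: Contained) =>
    r match {
      // Lines without key and value are skipped
      case Contained("", "") => l
      // Keyless line: rebuild the last container with x appended (init and last walk the whole list)
      case x @ Contained("", _) =>
        l.init ++ List(Container(l.last.lines ::: x :: Nil))
      // Keyed line: append a new container at the end (++ copies the whole list)
      case x: Contained => l ++ List(Container(List[Contained](x)))
    }
  // Anything that is not a Contained is skipped instead of throwing a MatchError
  case (l, _) => l
}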
[/code]
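To make the intended grouping concrete, here is a small usage sketch. It repeats the group function so that it runs on its own, and the sample lines are made up for illustration. A keyed line opens a new Container, a keyless line joins the preceding Container, and a line whose key and value are both empty is dropped.

[code lang="scala"]
case class Container(lines: List[Contained])
case class Contained(key: String, value: String)

// Same logic as group above: keyed lines open a new container,
// keyless lines are appended to the last container.
def group(ts: List[_]): List[Container] = ts.foldLeft(List[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      case x @ Contained("", _) => l.init ++ List(Container(l.last.lines ::: x :: Nil))
      case x: Contained => l ++ List(Container(List(x)))
    }
  case (l, _) => l
}

// Hypothetical sample input
val sample = List(
  Contained("a", "1"),
  Contained("", "1b"), // continues group "a"
  Contained("", ""),   // empty line, dropped
  Contained("b", "2"),
  Contained("", "2b"), // continues group "b"
  Contained("", "2c")  // continues group "b"
)

val grouped = group(sample)
// Prints "1, 1b" and then "2, 2b, 2c"
grouped.foreach(c => println(c.lines.map(_.value).mkString(", ")))
[/code]

Like the original, this assumes the first Contained carries a key: a leading keyless line would make l.last throw on the still-empty accumulator.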
If the function was called on a suitably large list (100,000 items), execution time jumped to 130 seconds. Though the code is functionally sound, it relies on List operations that take linear time, and it calls them once for every element. That leaves us with a number of possible optimizations:
- Replace each List with ListBuffer, making it effectively a mutable collection of mutable objects
- Make the Container object mutable by using a ListBuffer
- Optimize the function for Lists
Because the first and second options don’t differ much, I’ll only present the first and the last one here.
Fully mutable solution
[code lang="scala"]
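import scala.collection.mutable.ListBuffer

case class Container(lines: ListBuffer[Contained])
case class Contained(key: String, value: String)

def group(ts: List[_]): ListBuffer[Container] = ts.foldLeft(ListBuffer[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      // Keyless line: append it in place to the last container's buffer
      case x @ Contained("", _) =>
        l.last.lines.append(x)
        l
      // Keyed line: append a new container in place
      case x: Contained =>
        l.append(Container(ListBuffer[Contained](x)))
        l
    }
  // Anything that is not a Contained is skipped
  case (l, _) => l
}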
[/code]
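The price of this speed-up is that the groups stay mutable after grouping. The sketch below, using made-up sample lines, repeats the mutable group so it runs standalone, copies the result into immutable Lists with toList, and shows that a later append to one of the buffers changes the result but not the copy.

[code lang="scala"]
import scala.collection.mutable.ListBuffer

case class Container(lines: ListBuffer[Contained])
case class Contained(key: String, value: String)

// The fully mutable grouping from above
def group(ts: List[_]): ListBuffer[Container] = ts.foldLeft(ListBuffer[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      case x @ Contained("", _) =>
        l.last.lines.append(x)
        l
      case x: Contained =>
        l.append(Container(ListBuffer(x)))
        l
    }
  case (l, _) => l
}

val result = group(List(Contained("a", "1"), Contained("", "1b"), Contained("b", "2")))

// Copy each buffer into an immutable List before handing the groups out
val snapshot: List[List[Contained]] = result.toList.map(_.lines.toList)

// The buffers inside result remain mutable: any holder of result can still
// change a group after the fact, while snapshot is unaffected.
result.head.lines.append(Contained("", "late"))
println(result.head.lines.size) // 3
println(snapshot.head.size)     // 2
[/code]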
The execution time over 100,000 items for this solution is around 350ms, quite an improvement achieved just by switching to mutable collections.
Optimizing the immutable function
The first two options improve the runtime considerably, but both make the objects mutable. This is an evil we can avoid if we take a closer look at the group function and the List operations it uses. Its main operations are List.init and List.last, plus appending to the end of the list with ++. Let’s look at these List operations and their complexity:
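| Operation | Complexity |
|---|---|
| List.head | O(1) |
| List.tail | O(1) |
| x :: list (prepend) | O(1) |
| List.last | O(n) |
| List.init | O(n) |
| List.reverse | O(n) |
| list ++ other, list :+ x (append) | O(n), n = length of list |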
We can easily rewrite the group function to use head and tail instead of last and init, because last and init are simply head and tail applied to the reversed list, i.e.:
[code lang="scala"]
val list = List(1, 2, 3, 4)
// init and last are the reversed list's tail and head
assert(list.init == list.reverse.tail.reverse)
assert(list.last == list.reverse.head)
[/code]
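The same reversal trick applies whenever a List is built one element at a time. A List is singly linked, so prepending with :: takes constant time, while appending with :+ copies the whole list built so far. The two hypothetical helpers below build the same list both ways: the appending one is quadratic in n, the prepending one is linear including the final reverse.

[code lang="scala"]
// Appending: each :+ copies the list built so far, so n appends cost O(n^2)
def buildByAppending(n: Int): List[Int] =
  (1 to n).foldLeft(List.empty[Int])((acc, i) => acc :+ i)

// Prepending: each :: is O(1); a single O(n) reverse restores the order
def buildByPrepending(n: Int): List[Int] =
  (1 to n).foldLeft(List.empty[Int])((acc, i) => i :: acc).reverse

println(buildByAppending(5))  // List(1, 2, 3, 4, 5)
println(buildByPrepending(5)) // List(1, 2, 3, 4, 5)
[/code]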
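If we build the lists in reverse order from the start, we only need to reverse them once at the end. Each step of the original fold costs time proportional to the length of the result built so far, which makes the whole grouping quadratic; prepending and taking head and tail are constant-time, so the reversed version does constant work per element plus one linear reversal at the end. Let’s have a look at the optimized version, without the final reversal: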
[code lang="scala"]
def improvedGroup(ts: List[_]): List[Container] = ts.foldLeft(List[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      // Keyless line: prepend it to the first (most recent) container's lines
      case x @ Contained("", _) => Container(x :: l.head.lines) :: l.tail
      // Keyed line: prepend a new container
      case x: Contained => Container(List(x)) :: l
    }
  // Anything that is not a Contained is skipped
  case (l, _) => l
}
[/code]
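Running the reversed version on a few made-up lines shows what still needs fixing. The sketch repeats improvedGroup so it compiles on its own; both the order of the containers and the order of the lines inside each container come out reversed.

[code lang="scala"]
case class Container(lines: List[Contained])
case class Contained(key: String, value: String)

// Reversed accumulator: new containers and new lines are prepended with ::
def improvedGroup(ts: List[_]): List[Container] = ts.foldLeft(List[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      case x @ Contained("", _) => Container(x :: l.head.lines) :: l.tail
      case x: Contained => Container(List(x)) :: l
    }
  case (l, _) => l
}

// Hypothetical sample input
val sample = List(Contained("a", "1"), Contained("", "1b"), Contained("b", "2"), Contained("", "2b"))

// Prints "2b, 2" and then "1b, 1"
improvedGroup(sample).foreach(c => println(c.lines.map(_.value).mkString(", ")))
[/code]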
The only remaining problem is that both the List[Container] and the lines inside each Container come out in reverse order, so to get the same result as group we reverse both at the end:
[code lang="scala"]
def fastGroup(ts: List[_]) = improvedGroup(ts).reverse.map(x => Container(x.lines.reverse))
[/code]
Now the fastGroup function delivers the same results as group, and its runtime over 100,000 items is around 450ms.
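To check the equivalence and reproduce the measurements, here is a rough sketch that runs both versions on the same generated input and compares their results. The input generator sampleLines and the timing helper time are hypothetical, written only for this example; the timing uses System.nanoTime without JIT warm-up, so the numbers are indicative at best.

[code lang="scala"]
case class Container(lines: List[Contained])
case class Contained(key: String, value: String)

// Original version: O(n) init/last/++ on every step
def group(ts: List[_]): List[Container] = ts.foldLeft(List[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      case x @ Contained("", _) => l.init ++ List(Container(l.last.lines ::: x :: Nil))
      case x: Contained => l ++ List(Container(List(x)))
    }
  case (l, _) => l
}

// Reversed version: O(1) head/tail/:: on every step
def improvedGroup(ts: List[_]): List[Container] = ts.foldLeft(List[Container]()) {
  case (l, r: Contained) =>
    r match {
      case Contained("", "") => l
      case x @ Contained("", _) => Container(x :: l.head.lines) :: l.tail
      case x: Contained => Container(List(x)) :: l
    }
  case (l, _) => l
}

// Restore the original order of containers and of their lines
def fastGroup(ts: List[_]): List[Container] =
  improvedGroup(ts).reverse.map(c => Container(c.lines.reverse))

// Hypothetical test data: every third line starts a new group
def sampleLines(n: Int): List[Contained] =
  List.tabulate(n)(i => if (i % 3 == 0) Contained(s"k$i", s"v$i") else Contained("", s"v$i"))

// Crude wall-clock timing of a single run
def time[A](label: String)(body: => A): A = {
  val start = System.nanoTime()
  val result = body
  println(f"$label%-10s ${(System.nanoTime() - start) / 1e6}%.1f ms")
  result
}

val lines = sampleLines(10000)
val slow = time("group")(group(lines))
val fast = time("fastGroup")(fastGroup(lines))
println(slow == fast) // true
[/code]

The size of 10,000 keeps the quadratic group fast enough to run casually; raising it to 100,000 matches the setup above but makes group very slow. For trustworthy numbers, a benchmarking harness such as JMH is a better choice than a single timed run.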
I don’t know about you, but I prefer the last solution. Though you need to remember to put the elements in the right order again after the grouping, you don’t sacrifice immutability. It sometimes pays off to work towards a solution one step at a time, instead of trying to craft it all at once.