scala · advanced

Trampolining for Stack Safety

Use trampolining to make recursive algorithms stack-safe without tail recursion.

scala
/** A description of a computation that eventually yields an `A`.
  *
  * The computation is represented as data ([[Done]], [[More]], [[FlatMap]])
  * and evaluated by the loop in [[Trampoline.run]] rather than by nested
  * method calls. Code that builds a `Trampoline` can therefore stay naturally
  * recursive without being `@tailrec` itself; only the interpreter is.
  */
sealed trait Trampoline[+A]:
  /** Sequences this computation with `f`, which receives its result.
    *
    * Nothing is evaluated here: the continuation is only recorded in a
    * [[FlatMap]] node for the interpreter to process later.
    */
  def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
    FlatMap(this, f)

  /** Transforms the eventual result with `f`, wrapping the output in [[Done]]. */
  def map[B](f: A => B): Trampoline[B] =
    flatMap(a => Done(f(a)))

/** A finished computation holding its result. */
final case class Done[A](value: A) extends Trampoline[A]

/** A suspended step: calling `thunk` produces the rest of the computation. */
final case class More[A](thunk: () => Trampoline[A]) extends Trampoline[A]

/** Runs `sub`, then passes its result to `f` to obtain the next computation. */
final case class FlatMap[A, B](sub: Trampoline[A], f: A => Trampoline[B]) extends Trampoline[B]

object Trampoline:
  /** Evaluates `trampoline` to its final value.
    *
    * Every branch either returns or calls `run` in tail position, which the
    * `@tailrec` annotation asks the compiler to verify.
    *
    * @param trampoline the computation to evaluate
    * @return the value the computation finishes with
    */
  @annotation.tailrec
  def run[A](trampoline: Trampoline[A]): A = trampoline match
    case Done(v) => v
    case More(t) => run(t())
    case FlatMap(sub, f) => sub match
      case Done(v) => run(f(v))
      case More(t) => run(FlatMap(t(), f))
      case FlatMap(sub2, g) =>
        // Re-associate (sub2 >>= g) >>= f into sub2 >>= (x => g(x) >>= f) so the
        // innermost step is exposed at the top of the next iteration. Going
        // through flatMap lets the element type be inferred instead of being
        // widened to Any.
        run(sub2.flatMap(x => g(x).flatMap(f)))

  /** Lifts an already-computed value into a finished computation. */
  def done[A](a: A): Trampoline[A] = Done(a)

  /** Suspends a computation; the by-name argument is not evaluated until the thunk is called. */
  def more[A](a: => Trampoline[A]): Trampoline[A] = More(() => a)

  /** Suspends a plain value; the by-name argument is not evaluated until the thunk is called. */
  def delay[A](a: => A): Trampoline[A] = More(() => Done(a))

// Mutual recursion: isEven and isOdd call each other, so neither can be
// self-tail-recursive; each hop is suspended with Trampoline.more instead.

/** Decides whether `n` is even by alternating with [[isOdd]] until zero is reached.
  *
  * The argument is stepped toward zero from either side, so a negative `n`
  * takes the same number of hops as its positive counterpart.
  *
  * @param n the number to test
  * @return a suspended computation yielding `true` when `n` is even
  */
def isEven(n: Long): Trampoline[Boolean] =
  if n == 0 then Trampoline.done(true)
  else Trampoline.more(isOdd(if n > 0 then n - 1 else n + 1))

/** Decides whether `n` is odd by alternating with [[isEven]] until zero is reached.
  *
  * The argument is stepped toward zero from either side, so a negative `n`
  * takes the same number of hops as its positive counterpart.
  *
  * @param n the number to test
  * @return a suspended computation yielding `true` when `n` is odd
  */
def isOdd(n: Long): Trampoline[Boolean] =
  if n == 0 then Trampoline.done(false)
  else Trampoline.more(isEven(if n > 0 then n - 1 else n + 1))

/** Naive doubly-recursive Fibonacci, expressed as a trampolined computation.
  *
  * For `n <= 1` the result is `BigInt(n)` itself. Otherwise both recursive
  * calls are suspended with `Trampoline.more` and their results added. The
  * number of steps still grows exponentially with `n`.
  *
  * @param n index into the sequence
  * @return a suspended computation yielding the `n`-th Fibonacci number
  */
def fib(n: Int): Trampoline[BigInt] =
  if n > 1 then
    for
      prev <- Trampoline.more(fib(n - 1))
      prevPrev <- Trampoline.more(fib(n - 2))
    yield prev + prevPrev
  else Trampoline.done(BigInt(n))

/** Factorial as a trampolined computation.
  *
  * Any `n <= 1` yields 1. Otherwise the recursive call for `n - 1` is
  * suspended and its result multiplied by `n`.
  *
  * @param n the number whose factorial is wanted
  * @return a suspended computation yielding `n!`
  */
def factorial(n: Int): Trampoline[BigInt] =
  if n > 1 then
    for
      smaller <- Trampoline.more(factorial(n - 1))
    yield smaller * n
  else Trampoline.done(BigInt(1))

// Tree traversal (stack-safe)
/** Binary tree that stores its values only in the leaves. */
sealed trait Tree[+A]

/** Terminal node carrying a single `value`. */
final case class Leaf[A](value: A) extends Tree[A]

/** Inner node joining a `left` and a `right` subtree. */
final case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

/** Adds up every leaf value of `tree` as a trampolined computation.
  *
  * Both subtree sums are suspended with `Trampoline.more`, so descending into
  * a branch does not nest a call to `treeSum` inside the current one.
  *
  * @param tree the tree to total
  * @return a suspended computation yielding the sum of all leaves
  */
def treeSum(tree: Tree[Int]): Trampoline[Int] = tree match
  case Branch(l, r) =>
    for
      leftTotal <- Trampoline.more(treeSum(l))
      rightTotal <- Trampoline.more(treeSum(r))
    yield leftTotal + rightTotal
  case Leaf(v) => Trampoline.done(v)

/** Demo entry point: runs each trampolined example and prints its result. */
@main def run(): Unit =
  // Mutual recursion: a million alternating hops, evaluated by the interpreter loop
  val n = 1_000_000L
  println(s"isEven($n): ${Trampoline.run(isEven(n))}")
  println(s"isOdd($n): ${Trampoline.run(isOdd(n))}")

  // Factorial of large numbers
  println(s"20! = ${Trampoline.run(factorial(20))}")
  println(s"100! digits: ${Trampoline.run(factorial(100)).toString.length}")

  // Fibonacci (small n, because the naive algorithm is exponential)
  println(s"fib(20) = ${Trampoline.run(fib(20))}")

  // Deep tree traversal.
  // The left-leaning tree is built with a fold rather than recursion so that
  // constructing the test data does not itself depend on call-stack depth.
  def deepTree(depth: Int): Tree[Int] =
    (1 to depth).foldLeft[Tree[Int]](Leaf(1))((acc, _) => Branch(acc, Leaf(1)))

  val deep = deepTree(10000)
  println(s"Deep tree sum: ${Trampoline.run(treeSum(deep))}")

Use Cases

  • Stack-safe mutual recursion
  • Deep recursive data structures
  • CPS-style transformations

Tags

Related Snippets

Similar patterns you can reuse in the same workflow.