Mutual Tail Recursion in Python, fully mypy-strict type checked

As one does, I was thinking about how Python is often criticized for lacking tail call optimization.

I came up with a way to implement it without new language features: a decorator around the tail-recursive function catches a special Recur exception, then turns around and calls the same function with the new arguments:

class Recur(BaseException, Generic[P]):
    args: P.args
    kwargs: P.kwargs

    def __init__(self, *args: P.args, **kwargs: P.kwargs):
        super().__init__(*args)
        self.kwargs = kwargs


def recurrent(f: Callable[P, T]) -> Callable[P, T]:
    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        while True:
            try:
                return f(*args, **kwargs)
            except Recur as r:
                args = r.args
                kwargs = r.kwargs

    return wrapper

It can be used like so:

@recurrent
def gcd(a: int, b: int) -> int:
    print(f"gcd({a}, {b})")
    if b == 0:
        return a
    raise Recur(b, a % b)


print(gcd(1071, 462))

This is not an original idea; it has been documented before, e.g. by Chris Penner. This code all type-checks properly under mypy --strict (some context not shown). However, it doesn't allow mutual tail recursion.

I'll be honest: I didn't find any well-motivated examples of mutual tail recursion! Everyone uses the same awful, poorly motivated is-odd/is-even example. But because placating the mypy type checker was a challenge, I wanted to implement it anyway.
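For comparison, here is that example written as ordinary mutual recursion, as a minimal sketch: every call consumes a stack frame, so an input around twice sys.getrecursionlimit() raises RecursionError in CPython.

```python
import sys


def is_even(n: int) -> bool:
    # Each call adds a frame; CPython does not reuse it for the tail call.
    return True if n == 0 else is_odd(n - 1)


def is_odd(n: int) -> bool:
    return False if n == 0 else is_even(n - 1)


print(is_even(10))  # prints True

try:
    is_even(sys.getrecursionlimit() * 2)
except RecursionError:
    print("RecursionError")  # the stack ran out before reaching 0
```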

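The problem lies in the implementation of the wrapper function: args and kwargs have the parameter types (P) of the function passed to recurrent, and those types can't change just because the next call should go to a different function. Nor can the wrapper call a different function at all, since a Recur carries only arguments, never a target.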

The solution, which I realized a few weeks later, was to move the responsibility for actually dispatching the recurrent call into the Recur instance. There can be many Recur instances, each holding its own function and arguments, but inside the wrapper function they all simply have the same type: Recur!
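Stripped of the generics, the idea is small enough to show on its own. The sketch below is a simplified, untyped illustration, not the implementation that follows: Bounce and trampoline are hypothetical names, and the starting function is passed explicitly rather than through a decorator. Each Bounce carries the function to call next, so a single loop can drive any number of mutually recursive functions.

```python
from typing import Any, Callable


class Bounce(BaseException):
    """Hypothetical: a pending tail call that knows which function to run."""

    def __init__(self, f: Callable[..., Any], *args: Any) -> None:
        super().__init__()
        self.f = f
        self.call_args = args

    def __call__(self) -> Any:
        # The exception, not the loop, decides which function runs next.
        return self.f(*self.call_args)


def trampoline(f: Callable[..., Any], *args: Any) -> Any:
    pending = Bounce(f, *args)
    while True:
        try:
            return pending()
        except Bounce as nxt:
            # Clearing the traceback stops earlier bounces from being
            # kept alive through the frames it references.
            pending = nxt.with_traceback(None)


def is_even(n: int) -> bool:
    if n == 0:
        return True
    raise Bounce(is_odd, n - 1)


def is_odd(n: int) -> bool:
    if n == 0:
        return False
    raise Bounce(is_even, n - 1)


print(trampoline(is_even, 10_000))  # prints True
print(trampoline(is_odd, 10_001))  # prints True
```

Because the loop only ever sees Bounce, it never needs to know the parameter types of is_even or is_odd; the typed version below keeps those types checked at each call site through ParamSpec.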

Here's the full implementation, which type-checks cleanly with mypy --strict (1.11.2) and runs in Python 3.11.2:

from __future__ import annotations
import functools
from typing import Callable, ParamSpec, TypeVar, Generic, NoReturn

P = ParamSpec("P")
T = TypeVar("T")


class Recur(BaseException, Generic[P, T]):
    f: Callable[P, T]
    args: P.args
    kwargs: P.kwargs

    def __init__(self, f: Callable[P, T], args: P.args, kwargs: P.kwargs):
        super().__init__()
        self.f = f.f if isinstance(f, Recurrent) else f
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> T:
        return self.f(*self.args, **self.kwargs)

    def __repr__(self) -> str:
        if self.kwargs:
            return f"<Recur {self.f.__name__}(*{self.args}, **{self.kwargs})>"
        # self.args is a tuple, so it already supplies the parentheses.
        return f"<Recur {self.f.__name__}{self.args}>"
    __str__ = __repr__

class Recurrent(Generic[P, T]):
    f: Callable[P, T]

    def __init__(self, f: Callable[P, T]) -> None:
        self.f = f

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        r = Recur(self.f, args, kwargs)
        while True:
            try:
                return r()
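            except Recur as exc:
                # Clear the traceback before keeping exc: it references the
                # finished frames of the previous iteration (including the
                # previous Recur), which would otherwise chain every
                # iteration together in memory.
                r = exc.with_traceback(None)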

    def recur(self, *args: P.args, **kwargs: P.kwargs) -> NoReturn:
        raise Recur(self.f, args, kwargs)

    def __repr__(self) -> str:
        return f"<Recurrent {self.f.__name__}>"

And here's an example use, with the implementation above saved as recur.py:

import sys
from recur import Recurrent

@Recurrent
def gcd(a: int, b: int) -> int:
    print(f"gcd({a}, {b})")
    if b == 0:
        return a
    gcd.recur(b, a % b)


@Recurrent
def is_even(a: int) -> bool:
    assert a >= 0
    if a == 0:
        return True
    is_sum_odd.recur(a, -1)


@Recurrent
def is_sum_odd(a: int, b: int) -> bool:
    c = a + b
    assert c >= 0
    if c == 0:
        return False
    is_even.recur(c - 1)


print(gcd)
print(gcd(1071, 462))
print(is_even(sys.getrecursionlimit() * 2))
print(is_even(sys.getrecursionlimit() * 2 + 1))


Entry first conceived on 30 December 2024, 19:06 UTC, last modified on 30 December 2024, 19:33 UTC
Website Copyright © 2004-2024 Jeff Epler