As one does, I was thinking about how Python is criticized for lacking tail recursion optimization.
I came up with an idea of how to implement this without new language features, by using a decorator around the tail recursive function that catches a special Recur
exception and then turns around and calls the same function with the new arguments:
class Recur(BaseException, Generic[P]):
    """Control-flow exception carrying the arguments for the next tail call.

    Positional arguments are passed on to ``BaseException`` and are therefore
    available as ``args``; keyword arguments are kept in ``kwargs``.

    NOTE(review): deriving from ``BaseException`` rather than ``Exception``
    presumably keeps ``except Exception`` handlers from intercepting it.
    """

    args: P.args
    kwargs: P.kwargs

    def __init__(self, *args: P.args, **kwargs: P.kwargs):
        # Only the keyword arguments need explicit storage; the base class
        # records the positional ones as ``self.args``.
        self.kwargs = kwargs
        super().__init__(*args)
def recurrent(f: Callable[P, T]) -> Callable[P, T]:
    """Decorate *f* so that raising ``Recur`` inside it acts as a tail call.

    The returned wrapper calls *f* in a loop. Whenever *f* raises ``Recur``,
    the wrapper calls *f* again with the arguments stored on the exception.
    The first normal return value of *f* is returned to the caller. Any
    other exception propagates unchanged.
    """

    @functools.wraps(f)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
        while True:
            try:
                result = f(*args, **kwargs)
            except Recur as r:
                # Rebind the arguments and loop instead of nesting a call.
                args = r.args
                kwargs = r.kwargs
            else:
                return result

    return wrapper
It can be used like so:
@recurrent
def gcd(a: int, b: int) -> int:
    """Return the greatest common divisor of *a* and *b*.

    Each step is traced with ``print``; the tail call is expressed by
    raising ``Recur`` with the next pair of arguments.
    """
    print(f"gcd({a}, {b})")
    if b != 0:
        raise Recur(b, a % b)
    return a


# Demonstration call.
print(gcd(1071, 462))
This is not an original idea. It has been documented before e.g., by Chris Penner. This code all properly type-checks under mypy --strict
(some context not shown). However, it doesn't allow mutual tail recursion.
I'll be honest: I didn't find any well-motivated examples for mutual tail recursion! Everyone uses the same awful, poorly-motivated example of `is-odd` and `is-even`. But, because it was a challenge to placate the mypy type checker, I wanted to implement it anyway.
The problem lies in the implementation of the `wrapper` function: `args` and `kwargs` have the types given in the initial `recurrent` call, and the types can't change just because `f` changes.
The solution, which I realized a few weeks later, was to move the responsibility for actually dispatching the recurrent call into the `Recur` instance. There can be many `Recur` instances, but inside the wrapper function they all simply have the same type: `Recur`!
Here's the full implementation, which type-checks cleanly with mypy --strict (1.11.2) and runs in Python 3.11.2:
from __future__ import annotations
import functools
from typing import Callable, ParamSpec, TypeVar, Generic, NoReturn
P = ParamSpec("P")
T = TypeVar("T")
class Recur(BaseException, Generic[P, T]):
    """Exception that carries a pending tail call: a callable and its arguments.

    Calling the instance performs the stored call exactly once. The stored
    call may itself raise another ``Recur``.

    Attributes:
        f: The plain callable to invoke (never a ``Recurrent`` wrapper).
        args: Positional arguments for the call.
        kwargs: Keyword arguments for the call.
    """

    f: Callable[P, T]
    args: P.args
    kwargs: P.kwargs

    def __init__(self, f: Callable[P, T], args: P.args, kwargs: P.kwargs):
        super().__init__()
        # Unwrap a Recurrent so that calling this instance runs the plain
        # function once instead of entering Recurrent.__call__'s own loop.
        self.f = f.f if isinstance(f, Recurrent) else f
        self.args = args
        self.kwargs = kwargs

    def __call__(self) -> T:
        """Invoke the stored callable with the stored arguments."""
        return self.f(*self.args, **self.kwargs)

    def __repr__(self) -> str:
        """Return ``<Recur name(args...)>`` describing the pending call."""
        # Not every callable has __name__ (e.g. functools.partial); fall back
        # to its repr so that formatting this exception never raises.
        name = getattr(self.f, "__name__", None) or repr(self.f)
        if self.kwargs:
            return f"<Recur {name}(*{self.args}, **{self.kwargs})>"
        # The tuple repr of args already supplies both parentheses.
        return f"<Recur {name}{self.args}>"

    __str__ = __repr__
class Recurrent(Generic[P, T]):
    """Decorator class that runs tail calls in a loop instead of nesting them.

    Calling the instance runs the wrapped function. Inside the function,
    ``some_recurrent.recur(...)`` raises a ``Recur`` describing the next
    call. The running loop catches it and performs that call in its place.
    The target may be a different ``Recurrent`` than the one being run.

    Attributes:
        f: The undecorated function.
    """

    f: Callable[P, T]

    def __init__(self, f: Callable[P, T]) -> None:
        # Copy __name__, __qualname__, __doc__, __module__ etc. from f and set
        # __wrapped__, so the decorated object still introspects like f.
        # Done before assigning self.f because update_wrapper also merges
        # f.__dict__ into this instance's __dict__.
        functools.update_wrapper(self, f)
        self.f = f

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
        """Call ``f`` and keep following ``Recur`` exceptions until a call returns.

        Exceptions other than ``Recur`` propagate to the caller unchanged.
        """
        r = Recur(self.f, args, kwargs)
        while True:
            try:
                return r()
            except Recur as exc:
                # The raised exception is the next call to perform.
                r = exc

    def recur(self, *args: P.args, **kwargs: P.kwargs) -> NoReturn:
        """Request a tail call of ``f`` with the given arguments.

        Raises:
            Recur: always; it carries ``f`` and the arguments.
        """
        raise Recur(self.f, args, kwargs)

    def __repr__(self) -> str:
        """Return ``<Recurrent name>`` for the wrapped function."""
        # Fall back to repr for callables without __name__.
        name = getattr(self.f, "__name__", None) or repr(self.f)
        return f"<Recurrent {name}>"
And here's an example use:
import sys
from recur import Recurrent
@Recurrent
def gcd(a: int, b: int) -> int:
    """Return the greatest common divisor of *a* and *b*.

    Each step is traced with ``print``; the tail call goes through
    ``gcd.recur``, which never returns.
    """
    print(f"gcd({a}, {b})")
    if b != 0:
        gcd.recur(b, a % b)
    return a
@Recurrent
def is_even(a: int) -> bool:
    """Return whether the non-negative integer *a* is even.

    Zero is even; any other value is handed to ``is_sum_odd(a, -1)`` as a
    tail call.
    """
    assert a >= 0
    if a != 0:
        is_sum_odd.recur(a, -1)
    return True
@Recurrent
def is_sum_odd(a: int, b: int) -> bool:
    """Return whether ``a + b`` (which must be non-negative) is odd.

    A sum of zero is not odd; any other sum is decremented and handed to
    ``is_even`` as a tail call.
    """
    total = a + b
    assert total >= 0
    if total != 0:
        is_even.recur(total - 1)
    return False
# Show the decorated object's repr, then a traced gcd run.
print(gcd)
print(gcd(1071, 462))

# Mutual recursion well past the interpreter's recursion limit: one even
# input and one odd input.
_limit = sys.getrecursionlimit()
for _value in (_limit * 2, _limit * 2 + 1):
    print(is_even(_value))
Entry first conceived on 30 December 2024, 19:06 UTC, last modified on 30 December 2024, 19:33 UTC
Website Copyright © 2004-2024 Jeff Epler