
Multiple dispatch in __torch_function__
Python is a single-dispatch OO language, but some operations, such as binary magic methods, implement a simple form of multiple dispatch. __torch_function__ (following its NumPy predecessor __array_function__) generalizes this mechanism so that invocations of torch.add with different subclasses work properly. This podcast describes how the mechanism works and how it can be used (in an unconventional way) to build composable subclasses à la JAX in functorch.
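To make the mechanism concrete, here is a minimal pure-Python sketch. The first part shows the dispatch Python already performs for binary magic methods. The second part is a toy generalization to arbitrary functions; the hook name `__my_function__` and the helper `dispatch` are hypothetical stand-ins for illustration, not PyTorch's actual implementation, and the ordering rule (subclasses before superclasses, otherwise left to right) is an assumption modeled on the resolution rules described for __array_function__ and __torch_function__.

```python
# Part 1: Python's built-in dispatch for binary operators.
class Base:
    def __add__(self, other):
        return "Base.__add__"

    def __radd__(self, other):
        return "Base.__radd__"


class Derived(Base):
    # A subclass that overrides the reflected method gets the first try,
    # even when it is the right-hand operand.
    def __radd__(self, other):
        return "Derived.__radd__"


class Unrelated:
    def __add__(self, other):
        return NotImplemented  # defer to the other operand


left_wins = Base() + Base()          # left operand's __add__ runs
subclass_wins = Base() + Derived()   # right operand is a subclass, so it goes first
fallback = Unrelated() + Base()      # left declined, so Base.__radd__ runs


# Part 2: a toy n-ary analogue (hypothetical; not PyTorch's real code).
def dispatch(func, *args):
    # Keep one representative argument per distinct type that defines the hook.
    candidates = []
    for arg in args:
        has_hook = hasattr(type(arg), "__my_function__")
        if has_hook and all(type(arg) is not type(c) for c in candidates):
            candidates.append(arg)

    # Order them: subclasses before their superclasses, otherwise left to right.
    ordered = []
    for arg in candidates:
        index = len(ordered)
        for i, prior in enumerate(ordered):
            if issubclass(type(arg), type(prior)):
                index = i
                break
        ordered.insert(index, arg)

    if not ordered:
        return func(*args)  # no overrides: run the ordinary implementation

    # Try each override in turn until one accepts the call.
    for arg in ordered:
        result = arg.__my_function__(func, args)
        if result is not NotImplemented:
            return result
    raise TypeError("no override handled " + func.__name__)


def add(x, y):
    return x + y


class Meters:
    def __my_function__(self, func, args):
        return "Meters handled " + func.__name__


class Feet(Meters):
    def __my_function__(self, func, args):
        return "Feet handled " + func.__name__


class Picky(Meters):
    def __my_function__(self, func, args):
        return NotImplemented  # decline, letting the next candidate try


plain = dispatch(add, 1, 2)
subclass_first = dispatch(add, Meters(), Feet())
declined = dispatch(add, Meters(), Picky())
```

In both parts a handler can decline by returning NotImplemented, which hands the call to the next candidate; the toy `dispatch` simply extends that idea from two operands to any number of arguments.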
Show Notes
Further reading:
- This podcast in written form https://dev-discuss.pytorch.org/t/functorch-levels-as-dynamically-allocated-classes/294
- Multiple dispatch resolution rules in the RFC https://github.com/pytorch/rfcs/blob/master/RFC-0001-torch-function-for-methods.md#process-followed-during-a-functionmethod-call