Tensor subclasses and PT2
Episode 76


Tensor subclasses allow you to extend PyTorch with new types of tensors without having to write any C++. They have been used to implement DTensor, FP8, Nested Jagged Tensor and Complex Tensor. Recent work by Brian Hirsh means that we can compile tensor subclasses in PT2, eliminating their overhead. The basic mechanism by which this compilation works is a desugaring process in AOTAutograd. There are some complications involving views, dynamic shapes and tangent metadata mismatch.
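To make the desugaring idea concrete, here is a toy sketch in plain Python that does not use PyTorch at all. Everything in it (`ScaledTensor`, `FIELDS`, `flatten`, `unflatten`, `desugar`, `double`) is a hypothetical name invented for illustration; the methods only loosely mirror the flatten/unflatten hooks that a real subclass exposes to the compiler, and lists of floats stand in for plain dense tensors. It shows the shape of the transformation: a function written against a wrapper type is turned into one whose inputs and outputs are only the wrapper's plain inner values plus some metadata.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List, Tuple

Plain = List[float]  # stand-in for a plain dense tensor (assumption for this toy)
FIELDS = ("data", "scale")  # names of the inner plain values, in a fixed order


@dataclass
class ScaledTensor:
    """Hypothetical wrapper type: a payload plus a one-element scale."""
    data: Plain
    scale: Plain

    def flatten(self) -> Tuple[Dict[str, Plain], dict]:
        # Expose the inner plain values by name, plus non-tensor metadata.
        return {"data": self.data, "scale": self.scale}, {"kind": "scaled"}

    @staticmethod
    def unflatten(inner: Dict[str, Plain], meta: dict) -> "ScaledTensor":
        # Rebuild the wrapper from plain values and metadata.
        assert meta["kind"] == "scaled"
        return ScaledTensor(inner["data"], inner["scale"])


def double(x: ScaledTensor) -> ScaledTensor:
    # User code written against the wrapper: doubling only touches the scale.
    return ScaledTensor(x.data, [x.scale[0] * 2])


def desugar(fn: Callable[[ScaledTensor], ScaledTensor]):
    """Turn a wrapper-level function into one over plain values only."""
    def flat_fn(*plain_inputs: Plain, meta: dict):
        wrapper = ScaledTensor.unflatten(dict(zip(FIELDS, plain_inputs)), meta)
        out = fn(wrapper)
        inner, out_meta = out.flatten()
        return tuple(inner[name] for name in FIELDS), out_meta
    return flat_fn


# Usage: callers of flat_double see only plain values going in and out.
x = ScaledTensor([1.0, 2.0], [0.5])
inner, meta = x.flatten()
flat_double = desugar(double)
outs, out_meta = flat_double(inner["data"], inner["scale"], meta=meta)
rewrapped = ScaledTensor.unflatten(dict(zip(FIELDS, outs)), out_meta)
```

In this toy the wrapper is still rebuilt every time `flat_double` runs. The point of doing the equivalent step during compilation is that the unwrapping and rewrapping can happen at trace time, so the compiled function operates on plain tensors and the per-call wrapper overhead goes away.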

PyTorch Developer Podcast

February 24, 2024 (13m 25s)

