This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Float8 + DTensor Integration #194

@drisspg

Summary

This issue tracks progress and updates on the integration of Float8Tensor with DTensor.

Background

DTensor is the PyTorch-native solution for tensor parallelism and sequence parallelism (TP/SP) and is designed to work with torch.compile. It uses tensor subclasses and module hooks to extend existing models.
Documentation: https://github.com/pytorch/pytorch/blob/main/torch/distributed/_tensor/README.md
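For orientation, a minimal sketch of basic DTensor usage (an illustrative example, not from this issue: it assumes torchrun has already initialized the default process group, and uses the DeviceMesh, Shard, and distribute_tensor APIs from torch.distributed._tensor):

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, Shard, distribute_tensor

# Assumes the default process group is already initialized (e.g. via torchrun).
mesh = DeviceMesh("cuda", torch.arange(dist.get_world_size()))

# Shard a plain tensor along dim 0; each rank materializes only its shard.
weight = torch.randn(1024, 1024)
dweight = distribute_tensor(weight, mesh, placements=[Shard(0)])

# DTensor is a tensor subclass, so ordinary ops dispatch through it;
# elementwise ops run directly on the local shard.
doubled = dweight * 2
local = doubled.to_local()  # recover this rank's shard as a plain tensor
```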

What is needed:

  • refactor activation/grad casts into module hooks: [wip] add option to do activation/grad cast from hooks #170 (see the hook sketch after this list)
  • torch.compile support for Float8Linear + DTensor
  • support hook reordering
  • allgather/reduce_scatter: the Float8Tensor subclass needs to implement these collectives
  • cast_to_float8(DTensor(fp32/fp16 shard)) should produce DTensor(Float8Tensor) (see the cast sketch after this list)
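To make the first and last items concrete, here is a hedged sketch of how the hook-based cast and the subclass layering could compose. cast_to_float8 below is a hypothetical per-tensor scaled cast standing in for the repo's Float8Tensor constructor (the real one also carries the scale for the backward pass); the point being illustrated is the desired nesting order, DTensor(Float8Tensor), obtained by casting the local shard and rewrapping:

```python
import torch
import torch.nn as nn
from torch.distributed._tensor import DTensor

FP8_MAX = 448.0  # largest finite value representable in float8_e4m3fn

def cast_to_float8(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical per-tensor scaled cast; stands in for the Float8Tensor ctor."""
    if isinstance(x, DTensor):
        # Desired composition: cast the local shard, then rewrap, so the
        # result is DTensor(Float8Tensor) rather than Float8Tensor(DTensor).
        return DTensor.from_local(
            cast_to_float8(x.to_local()), x.device_mesh, x.placements
        )
    scale = FP8_MAX / x.abs().max().clamp(min=1e-12)
    return (x * scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)

def _cast_activation_hook(module: nn.Module, args):
    # Forward pre-hook (first list item): cast floating-point inputs on the
    # way into the module, instead of casting inside Float8Linear.forward.
    return tuple(
        cast_to_float8(a) if torch.is_floating_point(a) else a for a in args
    )

linear = nn.Linear(16, 16)
linear.register_forward_pre_hook(_cast_activation_hook)
```

The "support hook reordering" item presumably matters here: the float8 cast hook has to run in a well-defined order relative to the module hooks DTensor itself installs.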

Related Issues:

Pull Requests:
