
[Feature request] Support bfloat16/float8 inputs in session.run() #20578

@justinchuby

Description

session.run() currently does not support bfloat16 inputs because NumPy has no bfloat16 dtype. Would it be possible to accept these inputs as np.uint16 arrays instead? The caller would pass a uint16 array holding the bit representation of each bfloat16 value, and ORT could interpret those bits correctly because it already knows the expected element type of each graph input. The same approach could work for the float8* types, with an 8-bit unsigned integer array carrying the float8 bit patterns.
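For illustration, here is a minimal sketch of how a caller could produce the uint16 bit patterns the proposal describes, using only NumPy. The helper names are hypothetical, and the rounding behaviour (round-to-nearest-even, with every NaN collapsed to the quiet NaN pattern 0x7FC0) is an assumption of this sketch, not something ORT specifies. The final session.run() call is shown only as a comment because it describes the proposed behaviour, not what ORT accepts today.

```python
import numpy as np


def float32_to_bfloat16_bits(x: np.ndarray) -> np.ndarray:
    """Return bfloat16 bit patterns of x as uint16 (hypothetical helper).

    bfloat16 keeps the upper 16 bits of an IEEE float32. This sketch rounds
    to nearest, ties to even, instead of plainly truncating the low bits.
    """
    x32 = np.asarray(x, dtype=np.float32)
    bits = x32.view(np.uint32)
    # Lowest bit that survives the truncation; adding it makes ties round to even.
    lsb = (bits >> 16) & 1
    rounded = (bits + np.uint32(0x7FFF) + lsb) >> 16
    # NaNs with payload only in the low bits could round to inf, so map every
    # NaN to a canonical quiet NaN pattern instead (assumption of this sketch).
    out = np.where(np.isnan(x32), np.uint16(0x7FC0), rounded.astype(np.uint16))
    return out.astype(np.uint16)


def bfloat16_bits_to_float32(b: np.ndarray) -> np.ndarray:
    """Widen uint16 bfloat16 bit patterns back to float32 (exact, no rounding)."""
    return (np.asarray(b, dtype=np.uint16).astype(np.uint32) << 16).view(np.float32)


if __name__ == "__main__":
    values = np.array([1.0, -2.0, 0.5, 3.14159], dtype=np.float32)
    bits = float32_to_bfloat16_bits(values)
    print([hex(v) for v in bits.tolist()])
    print(bfloat16_bits_to_float32(bits))
    # Proposed usage, NOT current ORT behaviour: pass the uint16 array for a
    # graph input whose declared element type is bfloat16, e.g.
    # session.run(None, {"input_name": bits})
```

Rounding rather than truncating matters when the float32 data came from a model that was trained or calibrated in bfloat16; a plain `bits >> 16` would bias every value toward zero.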

Metadata


Assignees

No one assigned

    Labels

    api (issues related to all other APIs: C, C++, Python, etc.)
    converter:dynamo (issues related supporting the PyTorch Dynamo exporter)
    feature request (request for unsupported feature or enhancement)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
