Skip to content

justinchuby/torch-code-tracing

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

33 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

torch-code-tracing

PyPI - Version PyPI - Downloads

Code trace your PyTorch model to understand its exexution and intermediate tensor shapes.

Install

pip install git+https://github.com/justinchuby/torch-code-tracing.git

Usage

from torch_code_tracing import TracingMode

with TracingMode(succinct=True, color=True):
    out = model(*args, **kwargs)
out = model(**example_kwargs)  # test.py:41 in <module>:
│ ⬇️
│ output = func(self, *args, **kwargs)  # site-packages/transformers/utils/generic.py:969 in wrapper:
│ │ ⬇️
│ │ inputs_embeds = self.get_input_embeddings()(llm_input_ids)  # site-packages/transformers/models/gemma3/modeling_gemma3.py:1175 in forward:
│ │ │ ⬇️
│ │ │ return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)  # site-packages/transformers/models/gemma3/modeling_gemma3.py:144 in forward:
│ │ │ │ # embedding(bf16[262208, 2560], i64[2, 3], 0) -> bf16[2, 3, 2560];
│ │ │ return super().forward(input_ids) * self.embed_scale.to(self.weight.dtype)  # site-packages/transformers/models/gemma3/modeling_gemma3.py:144 in forward:
│ │ │ │ # mul.Tensor(bf16[2, 3, 2560], bf16[]) -> bf16[2, 3, 2560];
│ │ cache_position = torch.arange(  # site-packages/transformers/models/gemma3/modeling_gemma3.py:1179 in forward:
│ │ │ # arange.start(30, 33, device=meta, pin_memory=False) -> i64[3];
│ │ causal_mask = self._update_causal_mask(  # site-packages/transformers/models/gemma3/modeling_gemma3.py:1205 in forward:
│ │ │ ⬇️
│ │ │ causal_mask = torch.full(  # site-packages/transformers/models/gemma3/modeling_gemma3.py:1050 in _update_causal_mask:
│ │ │ │ # full([3, 33], -3.3895313892515355e+38, dtype=torch.bfloat16, device=meta, pin_memory=False) -> bf16[3, 33];
│ │ │ causal_mask = torch.triu(causal_mask, diagonal=1)  # site-packages/transformers/models/gemma3/modeling_gemma3.py:1056 in _update_causal_mask:
│ │ │ │ # triu(bf16[3, 33], 1) -> bf16[3, 33];
...

About

Code trace your PyTorch model

Topics

Resources

License

Stars

Watchers

Forks

Packages

No packages published

Languages