Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions transformer_engine/jax/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,22 +41,32 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh):
return mesh.shape[resource], resource


def _validate_mesh_resource_configuration(mesh_resource):
    """Validate that the mesh resource configuration is consistent and conflict-free.

    A resource counts as "enabled" when its axis name is not ``None`` and
    ``get_mesh_axis_size`` reports a size greater than 1 for that axis.

    Args:
        mesh_resource: Object exposing ``dp_resource``, ``tp_resource``,
            ``tpsp_resource`` and ``fsdp_resource`` attributes (each either
            ``None`` or a mesh axis name accepted by ``get_mesh_axis_size``).

    Raises:
        AssertionError: If ``dp_resource`` and ``fsdp_resource`` are both
            enabled, or if ``tp_resource`` and ``tpsp_resource`` are both
            enabled.
    """
    # The `is not None` guard short-circuits, so the axis size is only looked
    # up for resources that were actually configured.
    is_dp_enabled = (
        mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1
    )
    is_tp_enabled = (
        mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1
    )
    is_tpsp_enabled = (
        mesh_resource.tpsp_resource is not None
        and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1
    )
    is_fsdp_enabled = (
        mesh_resource.fsdp_resource is not None
        and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1
    )

    assert not (is_dp_enabled and is_fsdp_enabled), (
        "Data parallelism and full-sharded data parallelism cannot be enabled at the same time."
        f" Got dp_resource={mesh_resource.dp_resource} and"
        f" fsdp_resource={mesh_resource.fsdp_resource}"
    )
    assert not (is_tp_enabled and is_tpsp_enabled), (
        "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time."
        f" Got tp_resource={mesh_resource.tp_resource} and"
        f" tpsp_resource={mesh_resource.tpsp_resource}"
    )


Expand Down Expand Up @@ -305,7 +315,6 @@ def global_shard_guard(resource: MeshResource):
old_resources = _GLOBAL_MESH_RESOURCE
try:
_GLOBAL_MESH_RESOURCE = resource
_validate_mesh_resource_configuration()
yield
finally:
_GLOBAL_MESH_RESOURCE = old_resources
Expand All @@ -322,6 +331,7 @@ def global_mesh_resource() -> MeshResource:
" context. If you are not using multiple GPUs, you can use an empty MeshResource by"
" wrapping your program in 'with global_shard_guard(MeshResource()):'"
)
_validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE)
return _GLOBAL_MESH_RESOURCE


Expand Down