2 changes: 0 additions & 2 deletions src/transformers/loss/loss_utils.py
@@ -62,7 +62,6 @@ def ForCausalLMLoss(
     # Flatten the tokens
     logits = logits.view(-1, vocab_size)
     shift_labels = shift_labels.view(-1)
-    # Enable model parallelism
     shift_labels = shift_labels.to(logits.device)
     loss = fixed_cross_entropy(logits, shift_labels, num_items_in_batch, ignore_index, **kwargs)
     return loss
@@ -82,7 +81,6 @@ def ForMaskedLMLoss(
     # Flatten the tokens
     logits = logits.view(-1, vocab_size)
     labels = labels.view(-1)
-    # Enable model parallelism

     labels = labels.to(logits.device)
     loss = fixed_cross_entropy(logits, labels, num_items_in_batch, ignore_index, **kwargs)
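The removed comments only explained why the `.to(logits.device)` call exists; the call itself stays, because logits and labels can still live on different devices (for example when a model is dispatched across several GPUs). A minimal runnable sketch of the pattern ForCausalLMLoss relies on, using simple_cross_entropy as a hypothetical stand-in for fixed_cross_entropy:

import torch
import torch.nn.functional as F

def simple_cross_entropy(logits, labels, ignore_index=-100):
    # Hypothetical stand-in for fixed_cross_entropy: plain token-level cross-entropy.
    return F.cross_entropy(logits, labels, ignore_index=ignore_index)

vocab_size = 11
logits = torch.randn(2, 5, vocab_size)               # produced wherever the lm_head lives
shift_labels = torch.randint(0, vocab_size, (2, 5))  # may still sit on the input device

# Flatten the tokens, then align devices before computing the loss.
logits = logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1)
shift_labels = shift_labels.to(logits.device)        # still required, with or without the comment
loss = simple_cross_entropy(logits, shift_labels)
print(loss)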
2 changes: 0 additions & 2 deletions src/transformers/modeling_utils.py
@@ -1932,7 +1932,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
           for this model architecture.
         - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
           classes of the same architecture adding modules on top of the base model.
-        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
         - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
           models, `pixel_values` for vision models and `input_values` for speech models).
         - **can_record_outputs** (dict):
@@ -1967,7 +1966,6 @@ class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToH
     # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
     _tied_weights_keys = None

-    is_parallelizable = False
     supports_gradient_checkpointing = False
     _is_stateful = False

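With the class-level flag gone from PreTrainedModel, downstream code that branched on model.is_parallelizable will now hit an AttributeError unless it guards the lookup. A small sketch of such a guard (the "gpt2" checkpoint and the getattr fallback are illustrative, not part of this diff):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")

# Guarded lookup: returns False once neither the base class nor the subclass defines the flag.
supports_naive_mp = getattr(model, "is_parallelizable", False)
print(supports_naive_mp)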
4 changes: 2 additions & 2 deletions src/transformers/models/bloom/modeling_bloom.py
@@ -888,7 +888,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(lm_logits.device)
             # Flatten the tokens
             loss = self.loss_function(
@@ -1133,7 +1133,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             batch_size, seq_length = labels.shape
             loss_fct = CrossEntropyLoss()
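Only the comment text changes in these two hunks; the device move itself is still needed whenever lm_logits ends up on a different device than the incoming labels, for instance with a checkpoint dispatched over several GPUs. A hedged sketch of that situation (the bigscience/bloom-560m checkpoint and device_map="auto", which requires accelerate, are illustrative and not part of this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative: shard the model across whatever devices are available.
model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")

inputs = tokenizer("Hello world", return_tensors="pt")
labels = inputs["input_ids"]       # created on CPU
logits = model(**inputs).logits    # may land on a GPU
labels = labels.to(logits.device)  # the same move the modeling code performs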
10 changes: 5 additions & 5 deletions src/transformers/models/camembert/modeling_camembert.py
@@ -1012,7 +1012,7 @@ def forward(

         masked_lm_loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(prediction_scores.device)
             loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
@@ -1108,7 +1108,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
@@ -1225,7 +1225,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(reshaped_logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
@@ -1298,7 +1298,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -1491,7 +1491,7 @@ def forward(

         lm_loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(prediction_scores.device)
             lm_loss = self.loss_function(
                 prediction_scores,
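The sequence-classification hunk above is cut off right after the problem_type check. For orientation only, here is a free-standing paraphrase of the standard head logic that follows in these models (classification_loss is a hypothetical helper, not code from this file), with the labels-to-logits device move kept in place:

import torch
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

def classification_loss(logits, labels, num_labels, problem_type=None):
    labels = labels.to(logits.device)  # the move these hunks annotate
    if problem_type is None:
        if num_labels == 1:
            problem_type = "regression"
        elif labels.dtype in (torch.long, torch.int):
            problem_type = "single_label_classification"
        else:
            problem_type = "multi_label_classification"
    if problem_type == "regression":
        return MSELoss()(logits.squeeze(), labels.squeeze())
    if problem_type == "single_label_classification":
        return CrossEntropyLoss()(logits.view(-1, num_labels), labels.view(-1))
    return BCEWithLogitsLoss()(logits, labels)

print(classification_loss(torch.randn(4, 3), torch.tensor([0, 2, 1, 0]), num_labels=3))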
10 changes: 5 additions & 5 deletions src/transformers/models/camembert/modular_camembert.py
@@ -105,7 +105,7 @@ def forward(

         masked_lm_loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(prediction_scores.device)
             loss_fct = CrossEntropyLoss()
             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
@@ -168,7 +168,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             if self.config.problem_type is None:
                 if self.num_labels == 1:
@@ -280,7 +280,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(reshaped_logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(reshaped_logits, labels)
@@ -344,7 +344,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             loss_fct = CrossEntropyLoss()
             loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
@@ -513,7 +513,7 @@ def forward(

         lm_loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(prediction_scores.device)
             lm_loss = self.loss_function(
                 prediction_scores,
2 changes: 1 addition & 1 deletion src/transformers/models/codegen/modeling_codegen.py
@@ -638,7 +638,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(lm_logits.device)
             # Flatten the tokens
             loss = self.loss_function(
src/transformers/models/decision_transformer/modeling_decision_transformer.py
@@ -399,7 +399,6 @@ def forward(
 class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
     config: DecisionTransformerConfig
     base_model_prefix = "transformer"
-    is_parallelizable = True
     supports_gradient_checkpointing = True

     _can_compile_fullgraph = False
@@ -448,9 +447,6 @@ def __init__(self, config):
         )
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

-        # Model parallel
-        self.model_parallel = False
-        self.device_map = None
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
@@ -581,17 +577,6 @@ def forward(
         all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
         all_hidden_states = () if output_hidden_states else None
         for i, block in enumerate(self.h):
-            # Model parallel
-            if self.model_parallel:
-                torch.cuda.set_device(hidden_states.device)
-                # Ensure that attention_mask is always on the same device as hidden_states
-                if attention_mask is not None:
-                    attention_mask = attention_mask.to(hidden_states.device)
-                if isinstance(head_mask, torch.Tensor):
-                    head_mask = head_mask.to(hidden_states.device)
-            if output_hidden_states:
-                all_hidden_states = all_hidden_states + (hidden_states,)
-
             outputs = block(
                 hidden_states,
                 past_key_values if not (self.gradient_checkpointing and self.training) else None,
@@ -611,12 +596,6 @@ def forward(
             if self.config.add_cross_attention:
                 all_cross_attentions = all_cross_attentions + (outputs[2],)

-            # Model Parallel: If it's the last layer for that device, put things on the next device
-            if self.model_parallel:
-                for k, v in self.device_map.items():
-                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
-                        hidden_states = hidden_states.to("cuda:" + str(k + 1))
-
         hidden_states = self.ln_f(hidden_states)

         hidden_states = hidden_states.view(output_shape)
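The two forward-pass hunks above delete the hand-written device hopping that model_parallel performed: set the CUDA device per block, keep attention_mask and head_mask next to hidden_states, then push activations to the next GPU after a device's last block. If a checkpoint still needs to be spread over several GPUs, the accelerate-backed device_map path is the usual route today; the snippet below is an illustration of that alternative, not something this diff adds (it assumes accelerate is installed):

from transformers import AutoModelForCausalLM

# Place layers across the available devices up front; activations are moved by hooks,
# replacing the manual self.model_parallel loop deleted above.
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map="auto")
print(model.hf_device_map)  # e.g. {"transformer.wte": 0, ..., "lm_head": 0}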
@@ -1214,7 +1214,7 @@ def forward(
         router_probs = None
         aux_loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(lm_logits.device)

             loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
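Separate from the comment change, this head builds its loss with an explicit ignore_index=-100, so padded or otherwise masked label positions contribute nothing. A tiny self-contained example of that behaviour (toy tensors, not code from this file):

import torch
from torch import nn

loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
logits = torch.randn(3, 7)           # 3 label positions, vocabulary of 7
labels = torch.tensor([4, -100, 2])  # the -100 position is skipped
print(loss_fct(logits, labels))      # mean over the two real labels only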
@@ -746,7 +746,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()
@@ -755,7 +755,6 @@
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             shift_labels = shift_labels.view(-1)
-            # Enable model parallelism
             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)

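As in the files above, only the comment wording changes here; the shift-by-one bookkeeping around it ("tokens < n predict n") is untouched. A compact sketch of that bookkeeping on toy tensors, for reference:

import torch
from torch.nn import CrossEntropyLoss

vocab_size = 13
logits = torch.randn(2, 6, vocab_size)
labels = torch.randint(0, vocab_size, (2, 6))

# Shift so that tokens < n predict n.
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()

loss_fct = CrossEntropyLoss()
shift_logits = shift_logits.view(-1, vocab_size)
shift_labels = shift_labels.view(-1).to(shift_logits.device)
print(loss_fct(shift_logits, shift_labels))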
@@ -913,7 +913,7 @@ def forward(

         loss = None
         if labels is not None:
-            # move labels to correct device to enable model parallelism
+            # move labels to correct device
             labels = labels.to(logits.device)
             # Shift so that tokens < n predict n
             shift_logits = logits[..., :-1, :].contiguous()