'Requiring the gradient of Ids of lookup_table(v1)dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
'Requiring the gradient of Ids of lookup_table(v1) dist op is not currently supported. Please open an issue with details on your use case so that we can prioritize adding this (for instance, adversarial training for language model).'
)
target_shape=list(Ids_var.shape[:-1])
...
...
@@ -405,7 +405,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
ctx,op_dist_attr.process_mesh,rank_id
)
# A generalized method to caculate embedding offset using cartisian product
# A generalized method to calculate embedding offset using Cartesian product
relative_idx=_get_idx_in_axis(
process_mesh_group,
process_mesh_shape,
...
...
@@ -416,7 +416,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):