diff --git a/python/paddle/distributed/auto_parallel/operators/dist_scale.py b/python/paddle/distributed/auto_parallel/operators/dist_scale.py index 940866f01cdbe23b50f5a7554bdbba2905cb1477..800dbc7673fbfb8a5e7bc237de86cf3c80448cfc 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_scale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_scale.py @@ -74,6 +74,8 @@ class DistributedScaleImpl(DistributedOperatorImpl): [x_dims_mapping, out_dims_mapping], [i, i] ) if dim_changed: + op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping) + op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping) changed = True if changed: