diff --git a/python/paddle/distributed/auto_parallel/operators/dist_scale.py b/python/paddle/distributed/auto_parallel/operators/dist_scale.py index 9fc28d05a2077540b6b979fabbbe158703b0316d..940866f01cdbe23b50f5a7554bdbba2905cb1477 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_scale.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_scale.py @@ -76,6 +76,10 @@ class DistributedScaleImpl(DistributedOperatorImpl): if dim_changed: changed = True + if changed: + op_dist_attr.set_input_dims_mapping(x_name, x_dims_mapping) + op_dist_attr.set_output_dims_mapping(out_name, out_dims_mapping) + return changed @staticmethod