diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index d18c05a058eea52550b08cec074ab7c35b085b85..4eb2f45cc1859db6070289a037f3a0c3469a51db 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -1396,7 +1396,7 @@ class Completer:
                             op_dist_attr.set_output_dims_mapping(
                                 input_var.name, [-1])
                         else:
-                            assert "Moment" in input_name
+                            assert "Moment" in input_name or "Velocity" in input_name
                             input_var_attr.dims_mapping = ref_dims_mapping
                             op_dist_attr.set_input_dims_mapping(
                                 input_var.name, ref_dims_mapping)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index 9b288d36e46eb1b3d1d24c05285a8a1b5875f1fa..d0eba355e7bec05ca2ac828b83b0b4eae0bb9cca 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -555,8 +555,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                     rank_id = _get_corresponding_rank(
                         ctx, process_mesh, rank_id)
 
+                # NOTE: consider that the variable's shape is None
                 mesh_shape = process_mesh.topology
-                batch_size_axis = var_dim_mapping[0]
+                batch_size_axis = var_dim_mapping[0] if len(
+                    var_dim_mapping) > 0 else -1
                 if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                     need_gradient_allreduce = True
                     group_ranks = _get_comm_group(process_mesh.processes,
diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py
index c4f9ad8b6bc844d324070743511b8912741d19b0..5d9499d9286f373e2a0b1149a47ce97d877895f2 100644
--- a/python/paddle/distributed/auto_parallel/utils.py
+++ b/python/paddle/distributed/auto_parallel/utils.py
@@ -1058,14 +1058,16 @@ def set_grad_var_shape(program, dist_context):
                 "dropout_grad", "tanh_grad", "slice", "assign",
                 "matmul_v2_triple_grad", "elementwise_add_triple_grad",
                 "fill_constant", "sqrt_grad",
-                "fused_softmax_mask_upper_triangle_grad"
+                "fused_softmax_mask_upper_triangle_grad",
+                "flatten_contiguous_range_grad", "relu_grad"
             ]
             forward_list = [
                 "reshape2", "softmax_with_cross_entropy", "transpose2",
                 "softmax", "cross_entropy2", "dropout", "tanh",
                 ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
                 "elementwise_add_grad_grad", "shape", "sqrt",
-                "fused_softmax_mask_upper_triangle"
+                "fused_softmax_mask_upper_triangle", "flatten_contiguous_range",
+                "relu"
             ]
             if op.type in need_set_shape_list:
                 for forward_op in block.ops: