Unverified commit 2cec4c88, authored by zhaoyingli, committed by GitHub

adapt for resnet (#44685)

Parent a9f76d07
@@ -1396,7 +1396,7 @@ class Completer:
                     op_dist_attr.set_output_dims_mapping(
                         input_var.name, [-1])
                 else:
-                    assert "Moment" in input_name
+                    assert "Moment" in input_name or "Velocity" in input_name
                     input_var_attr.dims_mapping = ref_dims_mapping
                     op_dist_attr.set_input_dims_mapping(
                         input_var.name, ref_dims_mapping)
......
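The widened assertion accepts optimizer state inputs from Momentum as well as Adam. A minimal sketch of the naming convention it appears to rely on (the accumulator names follow Paddle's Adam/Momentum ops; the helper function is hypothetical, for illustration only):

```python
# Hypothetical helper mirroring the widened assertion in Completer.
# Adam's accumulators are named Moment1/Moment2; Momentum (the optimizer
# typically used for ResNet, which this commit adapts for) names its
# accumulator Velocity, which the old assert rejected.
def is_known_optimizer_state(input_name):
    return "Moment" in input_name or "Velocity" in input_name

assert is_known_optimizer_state("Moment1")   # Adam
assert is_known_optimizer_state("Velocity")  # Momentum
```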
@@ -555,8 +555,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
                 rank_id = _get_corresponding_rank(
                     ctx, process_mesh, rank_id)
+            # NOTE: consider that the variable's shape is None
             mesh_shape = process_mesh.topology
-            batch_size_axis = var_dim_mapping[0]
+            batch_size_axis = var_dim_mapping[0] if len(
+                var_dim_mapping) > 0 else -1
             if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1:
                 need_gradient_allreduce = True
                 group_ranks = _get_comm_group(process_mesh.processes,
......
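The guard addresses the case called out in the NOTE: a variable with no shape carries an empty dims_mapping, so the old `var_dim_mapping[0]` raised an IndexError before the sharding check could run. A minimal standalone sketch of the guarded lookup (the function name is illustrative, not Paddle's API):

```python
# Illustrative sketch of the guarded lookup. -1 means "not sharded along
# the batch dimension", so the gradient allreduce is skipped.
def batch_size_axis_of(var_dim_mapping):
    return var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1

assert batch_size_axis_of([0, -1]) == 0  # batch dim sharded on mesh axis 0
assert batch_size_axis_of([]) == -1      # shape-less variable: no allreduce
```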
@@ -1058,14 +1058,16 @@ def set_grad_var_shape(program, dist_context):
             "dropout_grad", "tanh_grad", "slice", "assign",
             "matmul_v2_triple_grad", "elementwise_add_triple_grad",
             "fill_constant", "sqrt_grad",
-            "fused_softmax_mask_upper_triangle_grad"
+            "fused_softmax_mask_upper_triangle_grad",
+            "flatten_contiguous_range_grad", "relu_grad"
         ]
         forward_list = [
             "reshape2", "softmax_with_cross_entropy", "transpose2",
             "softmax", "cross_entropy2", "dropout", "tanh",
             ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
             "elementwise_add_grad_grad", "shape", "sqrt",
-            "fused_softmax_mask_upper_triangle"
+            "fused_softmax_mask_upper_triangle", "flatten_contiguous_range",
+            "relu"
         ]
         if op.type in need_set_shape_list:
             for forward_op in block.ops:
......
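The two lists read as positionally paired: the grad op type at index i in need_set_shape_list maps to its forward op type at index i in forward_list, which is why "flatten_contiguous_range_grad"/"relu_grad" and "flatten_contiguous_range"/"relu" are appended together (ResNet exercises relu and flatten layers that earlier models did not). A small sketch of that pairing, assuming the index-based lookup (lists trimmed for brevity):

```python
# Assumed pairing: grad op types and forward op types are index-aligned,
# so new entries must be appended to both lists at matching positions.
need_set_shape_list = ["flatten_contiguous_range_grad", "relu_grad"]
forward_list = ["flatten_contiguous_range", "relu"]

def forward_type_for(grad_op_type):
    return forward_list[need_set_shape_list.index(grad_op_type)]

assert forward_type_for("relu_grad") == "relu"
assert forward_type_for("flatten_contiguous_range_grad") == \
    "flatten_contiguous_range"
```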