Unverified commit 6992170e, authored by zhaoyingli, committed by GitHub

fix_var_recursive (#48206)

Parent: 3c0bd3af
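The change is mechanical but easy to misread from the hunks alone: every variable lookup in these auto-parallel operator implementations now goes through Block._var_recursive / Block._find_var_recursive instead of Block.var / Block.has_var. The non-recursive calls only see variables owned by the block itself, so they fail (a ValueError or a failed assert) when an operator sits in a sub-block, for example a control-flow block, while its inputs and outputs were created in an ancestor block; the recursive variants walk up the parent-block chain. The membership test "fwd_name not in main_block.vars" is replaced for the same reason, since Block.vars only lists the block's own variables, and several vars = ... .vars locals that are no longer needed are deleted along the way. The snippet below is not part of this PR; it is a minimal sketch of the lookup difference using Paddle's static-graph Block API (the underscore-prefixed methods are framework internals), with made-up names.

import paddle

paddle.enable_static()

prog = paddle.static.Program()
global_block = prog.global_block()
global_block.create_var(name="x", shape=[2, 3], dtype="float32")

# Enter a sub-block the way control-flow ops (cond/while) do, then return
# to the parent so the program's current-block cursor is restored.
sub_block = prog._create_block()
prog._rollback()

# Non-recursive lookups only consult the sub-block's own variable table.
print(sub_block.has_var("x"))       # False
print("x" in sub_block.vars)        # False
# sub_block.var("x") would raise ValueError here.

# Recursive lookups walk up the parent chain and find "x" in the global block.
print(sub_block._find_var_recursive("x") is not None)  # True (returns the var, or None)
print(sub_block._var_recursive("x").shape)             # (2, 3)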
@@ -266,13 +266,13 @@ def is_parameter_related(varname, block):
        varname = varname[: varname.index(".cast_fp")]
    if ".quantized" in varname:
        varname = varname[: varname.index(".quantized")]
-    assert block.has_var(varname)
-    var = block.var(varname)
+    assert block._find_var_recursive(varname)
+    var = block._var_recursive(varname)
    return var.is_parameter


def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
-    var_shape = block.var(src_var.name).shape
+    var_shape = block._var_recursive(src_var.name).shape
    var_topoloy = src_var_dist_attr.process_mesh.topology
    var_dims_mapping = src_var_dist_attr.dims_mapping
...
@@ -117,7 +117,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
            if (
                rank_id
                in ctx.get_tensor_dist_attr_for_program(
-                    main_block.var(varname)
+                    main_block._var_recursive(varname)
                ).process_mesh.processes
            ):
                filter_vars.append(varname)
@@ -132,7 +132,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
        # sync result
        group = new_process_group(world_process_group.ranks)
-        inf_var = main_block.var(kwargs['FoundInfinite'][0])
+        inf_var = main_block._var_recursive(kwargs['FoundInfinite'][0])
        inf_var_int32 = main_block.create_var(
            name=inf_var.name + "@cast_int32",
            shape=inf_var.shape,
@@ -179,7 +179,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
        new_op_dist_attr = OperatorDistributedAttribute()
        for varname in op.input_arg_names:
            var_dist_attr = ctx.get_tensor_dist_attr_for_program(
-                main_block.var(varname)
+                main_block._var_recursive(varname)
            )
            assert var_dist_attr is not None
            new_op_dist_attr.set_input_dims_mapping(
@@ -187,7 +187,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
            )
        for varname in op.output_arg_names:
            var_dist_attr = ctx.get_tensor_dist_attr_for_program(
-                main_block.var(varname)
+                main_block._var_recursive(varname)
            )
            new_op_dist_attr.set_output_dims_mapping(
                varname, var_dist_attr.dims_mapping
...
@@ -69,7 +69,7 @@ def prim_operator_data_parallel_functor(ctx, src_op):
            },
        )

-        grad_var = main_block.var(var_name)
+        grad_var = main_block._var_recursive(var_name)
        dims_mapping = ctx.get_tensor_dist_attr_for_program(
            grad_var
        ).dims_mapping
@@ -140,7 +140,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
        res.append(cost_mapping)

        main_block = backward_op.block
-        vars = main_block.vars
        need_gradient_allreduce = False
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
@@ -588,7 +587,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
            for varname in backward_op.desc.output(output_name):
                if varname in kwargs["grad_var_to_var"]:
                    fwd_name = kwargs["grad_var_to_var"][varname]
-                    if fwd_name not in main_block.vars:
+                    if not main_block._find_var_recursive(fwd_name):
                        continue
                    if is_parameter_related(fwd_name, main_block):
                        out_grad_names.append(varname)
...
@@ -84,7 +84,6 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
        res.append(cost_mapping)

        main_block = backward_op.block
-        vars = main_block.vars
        need_gradient_allreduce = False
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
...
@@ -370,9 +370,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
            kwargs['Out']
        )

-        Ids_var = main_block.var(kwargs['Ids'][0])
+        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
        Weight_var = main_block._var_recursive(kwargs['W'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # support lookup_table_v1
        if src_op.type == 'lookup_table':
@@ -507,7 +507,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
@@ -607,10 +607,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
            kwargs['W@GRAD']
        )

-        Ids_var = main_block.var(kwargs['Ids'][0])
-        Weight_var = main_block.var(kwargs['W'][0])
-        Out_grad = main_block.var(kwargs['Out@GRAD'][0])
-        Weight_grad = main_block.var(kwargs['W@GRAD'][0])
+        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
+        Weight_var = main_block._var_recursive(kwargs['W'][0])
+        Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
+        Weight_grad = main_block._var_recursive(kwargs['W@GRAD'][0])

        embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
            Weight_var.name
...
@@ -316,10 +316,10 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
        kwargs['Y@GRAD']
    )

-    X_var = main_block.var(kwargs['X'][0])
+    X_var = main_block._var_recursive(kwargs['X'][0])
    Y_var = main_block._var_recursive(kwargs['Y'][0])
-    Out_grad = main_block.var(kwargs['Out@GRAD'][0])
-    Y_grad = main_block.var(kwargs['Y@GRAD'][0])
+    Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
+    Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0])

    assert not is_parameter_related(
        X_var.name, main_block
@@ -433,7 +433,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
        has_x_grad = len(kwargs['X@GRAD']) > 0
        if has_x_grad:
            assert len(kwargs['X@GRAD']) == 1
-            X_grad = main_block.var(kwargs['X@GRAD'][0])
+            X_grad = main_block._var_recursive(kwargs['X@GRAD'][0])
            intermediate_var_0 = main_block.create_var(
                name=unique_name.generate_with_ignorable_key(
                    ".".join(["c_identity", 'tmp'])
@@ -572,7 +572,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
@@ -647,7 +646,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
        # calc comm op cost
        serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
@@ -762,9 +760,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr("transpose_X")
        trans_y = src_op.attr("transpose_Y")
@@ -906,7 +904,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                    input_varname, input_dist_attr
                )
            else:
-                input_var = main_block.var(input_varname)
+                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
@@ -958,7 +956,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
@@ -1023,8 +1020,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        # calc comm op cost
        serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
@@ -1147,9 +1142,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('transpose_X')
        trans_y = src_op.attr('transpose_Y')
@@ -1268,7 +1263,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
@@ -1316,7 +1311,6 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
@@ -1469,7 +1463,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
@@ -1549,8 +1542,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
        # calc comm op cost
        serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
@@ -1665,9 +1656,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')
@@ -1808,7 +1799,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                    input_varname, input_dist_attr
                )
            else:
-                input_var = main_block.var(input_varname)
+                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
@@ -1858,7 +1849,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
@@ -1924,8 +1915,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        # calc comm op cost
        serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
@@ -2047,9 +2036,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
        trans_x = src_op.attr('trans_x')
        trans_y = src_op.attr('trans_y')
@@ -2167,7 +2156,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
@@ -2215,7 +2204,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars
        process_mesh = dist_attr.process_mesh

        # calc comp op cost
@@ -2370,7 +2358,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
@@ -2445,7 +2432,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
        # calc comm op cost
        serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-1]
@@ -2555,9 +2541,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
@@ -2712,7 +2698,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                    input_varname, input_dist_attr
                )
            else:
-                input_var = main_block.var(input_varname)
+                input_var = main_block._var_recursive(input_varname)
                tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                    input_var
                )
@@ -2763,7 +2749,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        dist_attr = dist_op.dist_attr
        process_mesh = dist_attr.process_mesh
        main_block = backward_op.block
-        vars = main_block.vars
        Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
            backward_op.input("Y")[0]
        )
@@ -2827,8 +2812,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        # calc comm op cost
        serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
        parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
            serial_op.input("Y")[0]
        )[-2]
@@ -2947,9 +2930,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
        Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])

        # TODO infer logic comm presentation
        matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
@@ -3082,7 +3065,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
        allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
        allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
        for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
            tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
            assert tensor_dist_attr is not None
            allreduce_op_dist_attr.set_input_dist_attr(
@@ -3130,7 +3113,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        dist_attr = dist_op.dist_attr
        main_block = backward_op.block
-        vars = main_block.vars

        # calc comp op cost
        desc_mapping = build_comp_desc_from_dist_op(
...
@@ -155,7 +155,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
                ctx, op_dist_attr.process_mesh, rank_id
            )

-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
        in_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
        for axis in range(len(in_dims_mapping)):
            if in_dims_mapping[axis] != -1:
@@ -260,13 +260,13 @@ class DistributedPNormImpl(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
-        X_grad_var = main_block.var(kwargs['X@GRAD'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        X_grad_var = main_block._var_recursive(kwargs['X@GRAD'][0])

        # 1. copy p_norm_grad op and reset input name and output name
        new_kwargs = copy.deepcopy(kwargs)
        new_kwargs['X'] = [".".join(["c_allgather", X_var.name])]
-        new_X_var = main_block.var(new_kwargs['X'][0])
+        new_X_var = main_block._var_recursive(new_kwargs['X'][0])
        new_X_grad = main_block.create_var(
            name=".".join(["c_allgather", X_grad_var.name]),
            dtype=X_grad_var.dtype,
...
@@ -54,7 +54,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
            return False

        output_name = outputs[0]
-        output_var = dist_op.serial_op.block.var(output_name)
+        output_var = dist_op.serial_op.block._var_recursive(output_name)
        if output_var.shape != (1,):
            return False
@@ -124,7 +124,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
        )

        # dist attr
-        var = main_block.var(var_name)
+        var = main_block._var_recursive(var_name)
        tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
        op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
        new_op_attr = OperatorDistributedAttribute()
...
@@ -53,7 +53,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
    def calc_fwd_cost(self, dist_op, ctx, cluster):
        res = []
        op = dist_op.serial_op
-        vars = op.block.vars
        dist_attr = dist_op.dist_attr
        shape_list = op.desc.attr("shape")
@@ -103,7 +102,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        need_gradient_allreduce = False
-        vars = main_block.vars
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and is_parameter_related(
@@ -246,9 +244,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
-        XShape_var = main_block.var(kwargs['XShape'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
+        XShape_var = main_block._var_recursive(kwargs['XShape'][0])
        shape_list = src_op.desc.attr("shape")
        ShapeTensor_var_list = []
        for name in kwargs['ShapeTensor']:
@@ -303,7 +301,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
    def calc_fwd_cost(self, dist_op, ctx, cluster):
        res = []
        op = dist_op.serial_op
-        vars = op.block.vars
        dist_attr = dist_op.dist_attr
        shape_list = op.desc.attr("shape")
@@ -353,7 +350,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        need_gradient_allreduce = False
-        vars = main_block.vars
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and not is_parameter_related(
@@ -499,9 +495,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
-        XShape_var = main_block.var(kwargs['XShape'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
+        XShape_var = main_block._var_recursive(kwargs['XShape'][0])
        shape_list = src_op.desc.attr("shape")
        ShapeTensor_var_list = []
        for name in kwargs['ShapeTensor']:
@@ -556,7 +552,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
    def calc_fwd_cost(self, dist_op, ctx, cluster):
        res = []
        op = dist_op.serial_op
-        vars = op.block.vars
        dist_attr = dist_op.dist_attr
        shape_list = op.desc.attr("shape")
@@ -606,7 +601,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        need_gradient_allreduce = False
-        vars = main_block.vars
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and not is_parameter_related(
@@ -745,9 +739,9 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
                output_name
            )

-        X_var = main_block.var(kwargs['X'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
-        XShape_var = main_block.var(kwargs['XShape'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
+        XShape_var = main_block._var_recursive(kwargs['XShape'][0])
        shape_list = src_op.desc.attr("shape")
        ShapeTensor_var_list = []
        for name in kwargs['ShapeTensor']:
...
@@ -79,7 +79,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        need_gradient_allreduce = False
-        vars = main_block.vars
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and is_parameter_related(
...
@@ -160,7 +160,6 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
        backward_op = dist_op.serial_op
        main_block = backward_op.block
        need_gradient_allreduce = False
-        vars = main_block.vars
        for input_name in backward_op.desc.input_names():
            for varname in backward_op.desc.input(input_name):
                if "@GRAD" not in varname and is_parameter_related(
...
@@ -151,7 +151,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
            if (
                rank_id
                in ctx.get_tensor_dist_attr_for_program(
-                    main_block.var(varname)
+                    main_block._var_recursive(varname)
                ).process_mesh.processes
            ):
                filter_vars.append(varname)