diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py
index af9c53a88ea869cb99e0327f8b83df681c7727f9..72ed66f3e41a0c6338aa06c98b81b5028643d38f 100644
--- a/python/paddle/distributed/auto_parallel/operators/common.py
+++ b/python/paddle/distributed/auto_parallel/operators/common.py
@@ -266,13 +266,13 @@ def is_parameter_related(varname, block):
         varname = varname[: varname.index(".cast_fp")]
     if ".quantized" in varname:
         varname = varname[: varname.index(".quantized")]
-    assert block.has_var(varname)
-    var = block.var(varname)
+    assert block._find_var_recursive(varname)
+    var = block._var_recursive(varname)
     return var.is_parameter
 
 
 def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr):
-    var_shape = block.var(src_var.name).shape
+    var_shape = block._var_recursive(src_var.name).shape
     var_topoloy = src_var_dist_attr.process_mesh.topology
     var_dims_mapping = src_var_dist_attr.dims_mapping
 
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
index 4fa689f5fa1ae7985ad765d36d91872687bfbbc1..c1834bde1136c704d02ba065db2b39094439963f 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_check_finite_and_unscale.py
@@ -117,7 +117,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
             if (
                 rank_id
                 in ctx.get_tensor_dist_attr_for_program(
-                    main_block.var(varname)
+                    main_block._var_recursive(varname)
                 ).process_mesh.processes
             ):
                 filter_vars.append(varname)
@@ -132,7 +132,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         # sync result
         group = new_process_group(world_process_group.ranks)
 
-        inf_var = main_block.var(kwargs['FoundInfinite'][0])
+        inf_var = main_block._var_recursive(kwargs['FoundInfinite'][0])
         inf_var_int32 = main_block.create_var(
             name=inf_var.name + "@cast_int32",
             shape=inf_var.shape,
@@ -179,7 +179,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         new_op_dist_attr = OperatorDistributedAttribute()
         for varname in op.input_arg_names:
             var_dist_attr = ctx.get_tensor_dist_attr_for_program(
-                main_block.var(varname)
+                main_block._var_recursive(varname)
             )
             assert var_dist_attr is not None
             new_op_dist_attr.set_input_dims_mapping(
@@ -187,7 +187,7 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
             )
         for varname in op.output_arg_names:
             var_dist_attr = ctx.get_tensor_dist_attr_for_program(
-                main_block.var(varname)
+                main_block._var_recursive(varname)
             )
             new_op_dist_attr.set_output_dims_mapping(
                 varname, var_dist_attr.dims_mapping
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py
index b1c0014045b957d3c26ff24cea5f7d22332e0612..85ffb77d97b5255acaa2c631693ee9e5c14b37d4 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_default.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py
@@ -69,7 +69,7 @@ def prim_operator_data_parallel_functor(ctx, src_op):
            },
        )
 
-        grad_var = main_block.var(var_name)
+        grad_var = main_block._var_recursive(var_name)
         dims_mapping = ctx.get_tensor_dist_attr_for_program(
             grad_var
         ).dims_mapping
@@ -140,7 +140,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
             res.append(cost_mapping)
 
         main_block = backward_op.block
-        vars = main_block.vars
         need_gradient_allreduce = False
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
@@ -588,7 +587,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
             for varname in backward_op.desc.output(output_name):
                 if varname in kwargs["grad_var_to_var"]:
                     fwd_name = kwargs["grad_var_to_var"][varname]
-                    if fwd_name not in main_block.vars:
+                    if not main_block._find_var_recursive(fwd_name):
                         continue
                     if is_parameter_related(fwd_name, main_block):
                         out_grad_names.append(varname)
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
index 50d4f138dcaa6fe4df531f323461b0601259518b..75dcc98faa130ce34522cf230951bf83bb0d474f 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_eltwise.py
@@ -84,7 +84,6 @@ class DistributedElementwiseImpl0(DistributedOperatorImpl):
             res.append(cost_mapping)
 
         main_block = backward_op.block
-        vars = main_block.vars
         need_gradient_allreduce = False
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
index af3514c85f1a56ac86da24d0bd1139e647355203..683236cadd14f812d26c3373d9fca79a171634ef 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_embedding.py
@@ -370,9 +370,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
             kwargs['Out']
         )
 
-        Ids_var = main_block.var(kwargs['Ids'][0])
+        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
         Weight_var = main_block._var_recursive(kwargs['W'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
 
         # support lookup_table_v1
         if src_op.type == 'lookup_table':
@@ -507,7 +507,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
         allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
         for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
             tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
             assert tensor_dist_attr is not None
             allreduce_op_dist_attr.set_input_dist_attr(
@@ -607,10 +607,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
             kwargs['W@GRAD']
         )
 
-        Ids_var = main_block.var(kwargs['Ids'][0])
-        Weight_var = main_block.var(kwargs['W'][0])
-        Out_grad = main_block.var(kwargs['Out@GRAD'][0])
-        Weight_grad = main_block.var(kwargs['W@GRAD'][0])
+        Ids_var = main_block._var_recursive(kwargs['Ids'][0])
+        Weight_var = main_block._var_recursive(kwargs['W'][0])
+        Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
+        Weight_grad = main_block._var_recursive(kwargs['W@GRAD'][0])
 
         embedding_row_dim_mapping = dist_attr.get_input_dims_mapping(
             Weight_var.name
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
index ace72f0a2162a2def32e9b0ae857a2388366a2f7..fa6557f497bb28adaf24d0b31ac32d0ebeab66a4 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_matmul.py
@@ -316,10 +316,10 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         kwargs['Y@GRAD']
     )
 
-    X_var = main_block.var(kwargs['X'][0])
+    X_var = main_block._var_recursive(kwargs['X'][0])
     Y_var = main_block._var_recursive(kwargs['Y'][0])
-    Out_grad = main_block.var(kwargs['Out@GRAD'][0])
-    Y_grad = main_block.var(kwargs['Y@GRAD'][0])
+    Out_grad = main_block._var_recursive(kwargs['Out@GRAD'][0])
+    Y_grad = main_block._var_recursive(kwargs['Y@GRAD'][0])
 
     assert not is_parameter_related(
         X_var.name, main_block
@@ -433,7 +433,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         has_x_grad = len(kwargs['X@GRAD']) > 0
         if has_x_grad:
             assert len(kwargs['X@GRAD']) == 1
-            X_grad = main_block.var(kwargs['X@GRAD'][0])
+            X_grad = main_block._var_recursive(kwargs['X@GRAD'][0])
             intermediate_var_0 = main_block.create_var(
                 name=unique_name.generate_with_ignorable_key(
                     ".".join(["c_identity", 'tmp'])
@@ -572,7 +571,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
         Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Y")[0]
         )
@@ -647,7 +646,6 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
 
         # calc comm op cost
         serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("Y")[0]
         )[-1]
@@ -762,9 +760,9 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
         trans_x = src_op.attr("transpose_X")
         trans_y = src_op.attr("transpose_Y")
 
@@ -906,7 +904,7 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
                     input_varname, input_dist_attr
                 )
             else:
-                input_var = main_block.var(input_varname)
+                input_var = main_block._var_recursive(input_varname)
                 tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                     input_var
                 )
@@ -958,7 +956,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
         Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Y")[0]
         )
@@ -1023,8 +1020,6 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
 
         # calc comm op cost
         serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
-
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("Y")[0]
         )[-2]
@@ -1147,9 +1142,9 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
         trans_x = src_op.attr('transpose_X')
         trans_y = src_op.attr('transpose_Y')
 
@@ -1268,7 +1263,7 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
         allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
         for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
             tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
             assert tensor_dist_attr is not None
             allreduce_op_dist_attr.set_input_dist_attr(
@@ -1316,7 +1311,6 @@ class DistributedMatmulImpl2(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
 
         # calc comp op cost
         desc_mapping = build_comp_desc_from_dist_op(
@@ -1469,7 +1463,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
         Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Y")[0]
         )
@@ -1549,8 +1542,6 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
 
         # calc comm op cost
         serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
-
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("Y")[0]
         )[-1]
@@ -1665,9 +1656,9 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
         Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
         trans_x = src_op.attr('trans_x')
         trans_y = src_op.attr('trans_y')
 
@@ -1808,7 +1799,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                     input_varname, input_dist_attr
                 )
             else:
-                input_var = main_block.var(input_varname)
+                input_var = main_block._var_recursive(input_varname)
                 tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                     input_var
                 )
@@ -1858,7 +1849,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
+
         Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Y")[0]
         )
@@ -1924,8 +1915,6 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
 
         # calc comm op cost
         serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
-
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("Y")[0]
         )[-2]
@@ -2047,9 +2036,9 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                 output_name
            )
 
-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
         Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
         trans_x = src_op.attr('trans_x')
         trans_y = src_op.attr('trans_y')
 
@@ -2167,7 +2156,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
         allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
         for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
             tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
             assert tensor_dist_attr is not None
             allreduce_op_dist_attr.set_input_dist_attr(
@@ -2215,7 +2204,6 @@ class DistributedMatmulV2Impl2(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
         process_mesh = dist_attr.process_mesh
 
         # calc comp op cost
@@ -2370,7 +2358,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
         Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Y")[0]
         )
@@ -2445,7 +2432,6 @@ class DistributedMulImpl0(DistributedOperatorImpl):
 
         # calc comm op cost
         serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("Y")[0]
         )[-1]
@@ -2555,9 +2541,9 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
         Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
 
         # TODO infer logic comm presentation
         matmul_col_dim_mapping = op_dist_attr.get_input_dims_mapping(
@@ -2712,7 +2698,7 @@ class DistributedMulImpl0(DistributedOperatorImpl):
                     input_varname, input_dist_attr
                 )
             else:
-                input_var = main_block.var(input_varname)
+                input_var = main_block._var_recursive(input_varname)
                 tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(
                     input_var
                 )
@@ -2763,7 +2749,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
         dist_attr = dist_op.dist_attr
         process_mesh = dist_attr.process_mesh
         main_block = backward_op.block
-        vars = main_block.vars
         Y_var_dim_mapping = dist_attr.get_input_dims_mapping(
             backward_op.input("Y")[0]
         )
@@ -2827,8 +2812,6 @@ class DistributedMulImpl1(DistributedOperatorImpl):
 
         # calc comm op cost
         serial_op = dist_op.serial_op
-        vars = serial_op.block.vars
-
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("Y")[0]
         )[-2]
@@ -2947,9 +2930,9 @@ class DistributedMulImpl1(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
         Weight_var = main_block._var_recursive(kwargs['Y'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
 
         # TODO infer logic comm presentation
         matmul_row_dim_mapping = op_dist_attr.get_input_dims_mapping(
@@ -3082,7 +3065,7 @@ class DistributedMulImpl1(DistributedOperatorImpl):
         allreduce_op_dist_attr.impl_type = op_dist_attr.impl_type
         allreduce_op_dist_attr.impl_idx = op_dist_attr.impl_idx
         for input_varname in c_allreduce_sum_op.desc.input_arg_names():
-            input_var = main_block.var(input_varname)
+            input_var = main_block._var_recursive(input_varname)
             tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(input_var)
             assert tensor_dist_attr is not None
             allreduce_op_dist_attr.set_input_dist_attr(
@@ -3130,7 +3113,6 @@ class DistributedMulImpl2(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         dist_attr = dist_op.dist_attr
         main_block = backward_op.block
-        vars = main_block.vars
 
         # calc comp op cost
         desc_mapping = build_comp_desc_from_dist_op(
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
index 662b4d666fdc4ba7cc3dbd6da0c49aedcffb993f..99cc63a7b93dd59369393f25e17f990faca8a135 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_pnorm.py
@@ -155,7 +155,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
             ctx, op_dist_attr.process_mesh, rank_id
         )
 
-        X_var = main_block.var(kwargs['X'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
         in_dims_mapping = op_dist_attr.get_input_dims_mapping(X_var.name)
         for axis in range(len(in_dims_mapping)):
             if in_dims_mapping[axis] != -1:
@@ -260,13 +260,13 @@ class DistributedPNormImpl(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
-        X_grad_var = main_block.var(kwargs['X@GRAD'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        X_grad_var = main_block._var_recursive(kwargs['X@GRAD'][0])
 
         # 1. copy p_norm_grad op and reset input name and output name
         new_kwargs = copy.deepcopy(kwargs)
         new_kwargs['X'] = [".".join(["c_allgather", X_var.name])]
-        new_X_var = main_block.var(new_kwargs['X'][0])
+        new_X_var = main_block._var_recursive(new_kwargs['X'][0])
         new_X_grad = main_block.create_var(
             name=".".join(["c_allgather", X_grad_var.name]),
             dtype=X_grad_var.dtype,
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py b/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
index 01b326d3a562c75a7f122d677b406bf8aa687de5..75dbb7f9c0dcba15acb6e7e927dfd3cc994b6a7a 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reduce_sum_p.py
@@ -54,7 +54,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
             return False
 
         output_name = outputs[0]
-        output_var = dist_op.serial_op.block.var(output_name)
+        output_var = dist_op.serial_op.block._var_recursive(output_name)
         if output_var.shape != (1,):
             return False
 
@@ -124,7 +124,7 @@ class DistributedReduceSumPrimtiveImpl0(DistributedOperatorImpl):
         )
 
         # dist attr
-        var = main_block.var(var_name)
+        var = main_block._var_recursive(var_name)
         tensor_dist_attr = ctx.get_tensor_dist_attr_for_program(var)
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         new_op_attr = OperatorDistributedAttribute()
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
index b305d88d7df1b270b352b74c1d58198c2f47d25b..7d4aa3f517be86d4f6f07f86d11f13259637de5a 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_reshape.py
@@ -53,7 +53,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
     def calc_fwd_cost(self, dist_op, ctx, cluster):
         res = []
         op = dist_op.serial_op
-        vars = op.block.vars
         dist_attr = dist_op.dist_attr
 
         shape_list = op.desc.attr("shape")
@@ -103,7 +102,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         main_block = backward_op.block
         need_gradient_allreduce = False
-        vars = main_block.vars
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
                 if "@GRAD" not in varname and is_parameter_related(
@@ -246,9 +244,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
-        XShape_var = main_block.var(kwargs['XShape'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
+        XShape_var = main_block._var_recursive(kwargs['XShape'][0])
         shape_list = src_op.desc.attr("shape")
         ShapeTensor_var_list = []
         for name in kwargs['ShapeTensor']:
@@ -303,7 +301,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
     def calc_fwd_cost(self, dist_op, ctx, cluster):
         res = []
         op = dist_op.serial_op
-        vars = op.block.vars
         dist_attr = dist_op.dist_attr
 
         shape_list = op.desc.attr("shape")
@@ -353,7 +350,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         main_block = backward_op.block
         need_gradient_allreduce = False
-        vars = main_block.vars
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
                 if "@GRAD" not in varname and not is_parameter_related(
@@ -499,9 +495,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
                output_name
            )
 
-        X_var = main_block.var(kwargs['X'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
-        XShape_var = main_block.var(kwargs['XShape'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
+        XShape_var = main_block._var_recursive(kwargs['XShape'][0])
         shape_list = src_op.desc.attr("shape")
         ShapeTensor_var_list = []
         for name in kwargs['ShapeTensor']:
@@ -556,7 +552,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
     def calc_fwd_cost(self, dist_op, ctx, cluster):
         res = []
         op = dist_op.serial_op
-        vars = op.block.vars
         dist_attr = dist_op.dist_attr
 
         shape_list = op.desc.attr("shape")
@@ -606,7 +601,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         main_block = backward_op.block
         need_gradient_allreduce = False
-        vars = main_block.vars
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
                 if "@GRAD" not in varname and not is_parameter_related(
@@ -745,9 +739,9 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
                 output_name
             )
 
-        X_var = main_block.var(kwargs['X'][0])
-        Out_var = main_block.var(kwargs['Out'][0])
-        XShape_var = main_block.var(kwargs['XShape'][0])
+        X_var = main_block._var_recursive(kwargs['X'][0])
+        Out_var = main_block._var_recursive(kwargs['Out'][0])
+        XShape_var = main_block._var_recursive(kwargs['XShape'][0])
         shape_list = src_op.desc.attr("shape")
         ShapeTensor_var_list = []
         for name in kwargs['ShapeTensor']:
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py
index c480a634b097cfc500dccd7c8cdf44033ed15adf..0059d0e1bb4592f2a69418062542bf9a6de93a50 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_softmax.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_softmax.py
@@ -79,7 +79,6 @@ class DistributedSoftmaxImpl(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         main_block = backward_op.block
         need_gradient_allreduce = False
-        vars = main_block.vars
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
                 if "@GRAD" not in varname and is_parameter_related(
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py
index 6153b4a7406e385981bbed5a269a86bdd577abe2..c5ce7628dc7d4ef4ee7d9a1010e27ce7646ac2df 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_transpose.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_transpose.py
@@ -160,7 +160,6 @@ class DistributedTranspose2Impl(DistributedOperatorImpl):
         backward_op = dist_op.serial_op
         main_block = backward_op.block
         need_gradient_allreduce = False
-        vars = main_block.vars
         for input_name in backward_op.desc.input_names():
             for varname in backward_op.desc.input(input_name):
                 if "@GRAD" not in varname and is_parameter_related(
diff --git a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
index 26c530250e15a33ba9dc8cbf37e4e74e0aa64268..048d06791bbfe9ec1483ee444dd1d7b299164510 100644
--- a/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
+++ b/python/paddle/distributed/auto_parallel/operators/dist_update_loss_scaling.py
@@ -151,7 +151,7 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
             if (
                 rank_id
                 in ctx.get_tensor_dist_attr_for_program(
-                    main_block.var(varname)
+                    main_block._var_recursive(varname)
                 ).process_mesh.processes
             ):
                 filter_vars.append(varname)
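
Reviewer note (not part of the patch): every hunk above swaps a local-block lookup -- block.var(name), block.has_var(name), or a `name in block.vars` membership test -- for its recursive counterpart, block._var_recursive(name) / block._find_var_recursive(name), and drops the now-unused `vars = main_block.vars` aliases. The usual motivation for such a change is that a distributed operator may sit inside a control-flow sub-block while the variables it references (inputs, outputs, parameters, gradients) were created in an ancestor block; a lookup restricted to the current block then fails even though the variable is reachable through the parent chain. The sketch below is only a toy model of that lookup difference under those assumptions; ToyBlock, ToyVar, and the example variable name are invented for illustration and do not reproduce Paddle's actual Block implementation.

# Toy model of local vs. recursive variable lookup -- NOT Paddle's code.
class ToyVar:
    def __init__(self, name, is_parameter=False):
        self.name = name
        self.is_parameter = is_parameter


class ToyBlock:
    def __init__(self, parent=None):
        self.parent = parent  # enclosing block; None for the global block
        self.vars = {}        # variables created directly in this block

    def create_var(self, name, **kwargs):
        var = ToyVar(name, **kwargs)
        self.vars[name] = var
        return var

    def var(self, name):
        # Local-only lookup: the behaviour being replaced in the diff.
        if name not in self.vars:
            raise ValueError(f"var {name} not found in this block")
        return self.vars[name]

    def _find_var_recursive(self, name):
        # Walk up the parent chain; return None if the name exists nowhere.
        block = self
        while block is not None:
            if name in block.vars:
                return block.vars[name]
            block = block.parent
        return None

    def _var_recursive(self, name):
        var = self._find_var_recursive(name)
        if var is None:
            raise ValueError(f"var {name} not found in this or any parent block")
        return var


# A parameter created in the global block but referenced from a sub-block
# (e.g. the body of a while/cond construct).
global_block = ToyBlock()
global_block.create_var("linear_0.w_0", is_parameter=True)
sub_block = ToyBlock(parent=global_block)

assert sub_block._var_recursive("linear_0.w_0").is_parameter  # found via parent
try:
    sub_block.var("linear_0.w_0")                             # local lookup fails
except ValueError as err:
    print("local lookup failed as expected:", err)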