diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py index b0d31f12dace9965191f6fb92ba8c84393ee0c50..18c643c1d76cbbc9822dcb51fd5bbf2cf30b7708 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py @@ -40,8 +40,8 @@ class DistributedSliceImpl(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr in_name = op_desc.input('Input')[0] out_name = op_desc.output('Out')[0] - in_var = dist_op.serial_op.block.var(in_name) - out_var = dist_op.serial_op.block.var(out_name) + in_var = dist_op.serial_op.block._var_recursive(in_name) + out_var = dist_op.serial_op.block._var_recursive(out_name) axes = op_desc.attr('axes') in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name) for axis in axes: @@ -57,8 +57,8 @@ class DistributedSliceImpl(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr in_name = op_desc.input('Input')[0] out_name = op_desc.output('Out')[0] - in_var = dist_op.serial_op.block.var(in_name) - out_var = dist_op.serial_op.block.var(out_name) + in_var = dist_op.serial_op.block._var_recursive(in_name) + out_var = dist_op.serial_op.block._var_recursive(out_name) axes = op_desc.attr('axes') decrease_axis = op_desc.attr('decrease_axis') in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)