From 8fece428e96d22ec843d776233579b4a3fc254f7 Mon Sep 17 00:00:00 2001 From: zhaoyingli <86812880+zhaoyinglia@users.noreply.github.com> Date: Tue, 15 Nov 2022 14:43:11 +0800 Subject: [PATCH] fix dist slice op (#47980) --- .../distributed/auto_parallel/operators/dist_slice.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_slice.py b/python/paddle/distributed/auto_parallel/operators/dist_slice.py index b0d31f12dac..18c643c1d76 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_slice.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_slice.py @@ -40,8 +40,8 @@ class DistributedSliceImpl(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr in_name = op_desc.input('Input')[0] out_name = op_desc.output('Out')[0] - in_var = dist_op.serial_op.block.var(in_name) - out_var = dist_op.serial_op.block.var(out_name) + in_var = dist_op.serial_op.block._var_recursive(in_name) + out_var = dist_op.serial_op.block._var_recursive(out_name) axes = op_desc.attr('axes') in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name) for axis in axes: @@ -57,8 +57,8 @@ class DistributedSliceImpl(DistributedOperatorImpl): op_dist_attr = dist_op.dist_attr in_name = op_desc.input('Input')[0] out_name = op_desc.output('Out')[0] - in_var = dist_op.serial_op.block.var(in_name) - out_var = dist_op.serial_op.block.var(out_name) + in_var = dist_op.serial_op.block._var_recursive(in_name) + out_var = dist_op.serial_op.block._var_recursive(out_name) axes = op_desc.attr('axes') decrease_axis = op_desc.attr('decrease_axis') in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name) -- GitLab