未验证 提交 8fece428 编写于 作者: Z zhaoyingli 提交者: GitHub

fix dist slice op (#47980)

上级 e65bac28
......@@ -40,8 +40,8 @@ class DistributedSliceImpl(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr
in_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
in_var = dist_op.serial_op.block.var(in_name)
out_var = dist_op.serial_op.block.var(out_name)
in_var = dist_op.serial_op.block._var_recursive(in_name)
out_var = dist_op.serial_op.block._var_recursive(out_name)
axes = op_desc.attr('axes')
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
for axis in axes:
......@@ -57,8 +57,8 @@ class DistributedSliceImpl(DistributedOperatorImpl):
op_dist_attr = dist_op.dist_attr
in_name = op_desc.input('Input')[0]
out_name = op_desc.output('Out')[0]
in_var = dist_op.serial_op.block.var(in_name)
out_var = dist_op.serial_op.block.var(out_name)
in_var = dist_op.serial_op.block._var_recursive(in_name)
out_var = dist_op.serial_op.block._var_recursive(out_name)
axes = op_desc.attr('axes')
decrease_axis = op_desc.attr('decrease_axis')
in_dims_mapping = op_dist_attr.get_input_dims_mapping(in_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册