未验证 提交 096b2f5a 编写于 作者: L liym27 提交者: GitHub

Polish code for _getitem_impl_ (#32868)

上级 a8625aaf
......@@ -792,29 +792,16 @@ def _getitem_impl_(var, item):
if not isinstance(item, tuple):
item = [item]
decrease_axis = []
slice_axis = []
slice_start = []
slice_end = []
slice_step = []
decrease_axes = []
axes = []
starts = []
ends = []
steps = []
use_strided_slice = False
reverse_axis = []
target_block = default_main_program().current_block()
def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = True
return out
max_integer = 2**31 - 1
for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice):
start = slice_item.start
......@@ -824,8 +811,7 @@ def _getitem_impl_(var, item):
if start is None and end is None and step is None:
continue
if step is None:
step = 1
step = 1 if step is None else step
if start is None and end is None:
assert (step == -1)
......@@ -836,106 +822,56 @@ def _getitem_impl_(var, item):
start = 0
if end is None:
end = 10000000
end = max_integer
if step != 1:
use_strided_slice = True
slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else:
decrease_axis.append(dim)
slice_axis.append(dim)
slice_start.append(slice_item)
slice_step.append(1)
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype=slice_item.dtype)
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = target_block.create_var(dtype=slice_item.dtype)
target_block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
def contain_var(one_list):
for ele in one_list:
if isinstance(ele, Variable):
return True
return False
decrease_axes.append(dim)
start = slice_item
step = 1
end = slice_item + 1 if slice_item != -1 else max_integer
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = var.block.create_var(dtype='int64')
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
axes.append(dim)
starts.append(start)
ends.append(end)
steps.append(step)
use_strided_slice = True if step != 1 else use_strided_slice
inputs = {'Input': [var]}
attrs = {
'axes': slice_axis,
'axes': axes,
'starts': [],
'ends': [],
'decrease_axis': decrease_axis
'decrease_axis': decrease_axes
}
if (use_strided_slice == True):
if use_strided_slice == True:
attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if contain_var(slice_start):
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
else:
attrs['starts'] = slice_start
infer_flags = list(1 for i in range(len(axes)))
from .layers import utils
# ends
if contain_var(slice_end):
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
def deal_attrs(attr, attr_name, tensor_attr_name, inputs, infer_flags):
if utils._contain_var(attr):
inputs[tensor_attr_name] = utils._convert_to_tensor_list(
attr, dtype="int64")
for i, dim in enumerate(attr):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
attrs[attr_name].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
attrs[attr_name].append(dim)
else:
attrs['ends'] = slice_end
attrs[attr_name] = attr
deal_attrs(starts, "starts", "StartsTensorList", inputs, infer_flags)
deal_attrs(ends, "ends", "EndsTensorList", inputs, infer_flags)
deal_attrs(steps, "strides", "StridesTensorList", inputs, infer_flags)
# strides
if use_strided_slice == True:
if contain_var(slice_step):
inputs['StridesTensorList'] = get_new_list_tensor(slice_step)
for i, dim in enumerate(slice_step):
if isinstance(dim, Variable):
attrs['strides'].append(-1)
infer_flags[i] = -1
else:
attrs['strides'].append(dim)
else:
attrs['strides'] = slice_step
# infer_flags
attrs['infer_flags'] = infer_flags
out = var
if use_strided_slice == False and len(slice_axis) > 0:
target_block = default_main_program().current_block()
if use_strided_slice == False and len(axes) > 0:
# append slice_op here
slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
......@@ -948,7 +884,7 @@ def _getitem_impl_(var, item):
attrs=attrs)
out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0:
elif use_strided_slice == True and len(axes) > 0:
strided_slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册