未验证 提交 096b2f5a 编写于 作者: L liym27 提交者: GitHub

Polish code for _getitem_impl_ (#32868)

上级 a8625aaf
...@@ -792,29 +792,16 @@ def _getitem_impl_(var, item): ...@@ -792,29 +792,16 @@ def _getitem_impl_(var, item):
if not isinstance(item, tuple): if not isinstance(item, tuple):
item = [item] item = [item]
decrease_axis = [] decrease_axes = []
slice_axis = [] axes = []
slice_start = [] starts = []
slice_end = [] ends = []
slice_step = [] steps = []
use_strided_slice = False use_strided_slice = False
reverse_axis = [] reverse_axis = []
target_block = default_main_program().current_block()
def fill_constant(shape, value, force_cpu=False, out=None):
var.block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [out]},
attrs={
'shape': shape,
'dtype': out.dtype,
'value': float(value),
'force_cpu': force_cpu
})
out.stop_gradient = True
return out
max_integer = 2**31 - 1
for dim, slice_item in enumerate(item): for dim, slice_item in enumerate(item):
if isinstance(slice_item, slice): if isinstance(slice_item, slice):
start = slice_item.start start = slice_item.start
...@@ -824,8 +811,7 @@ def _getitem_impl_(var, item): ...@@ -824,8 +811,7 @@ def _getitem_impl_(var, item):
if start is None and end is None and step is None: if start is None and end is None and step is None:
continue continue
if step is None: step = 1 if step is None else step
step = 1
if start is None and end is None: if start is None and end is None:
assert (step == -1) assert (step == -1)
...@@ -836,106 +822,56 @@ def _getitem_impl_(var, item): ...@@ -836,106 +822,56 @@ def _getitem_impl_(var, item):
start = 0 start = 0
if end is None: if end is None:
end = 10000000 end = max_integer
if step != 1:
use_strided_slice = True
slice_axis.append(dim)
slice_start.append(start)
slice_end.append(end)
slice_step.append(step)
else: else:
decrease_axis.append(dim) decrease_axes.append(dim)
slice_axis.append(dim) start = slice_item
slice_start.append(slice_item) step = 1
slice_step.append(1) end = slice_item + 1 if slice_item != -1 else max_integer
if isinstance(slice_item, Variable):
temp_1 = var.block.create_var(dtype=slice_item.dtype)
fill_constant([1], 1, force_cpu=True, out=temp_1)
temp_end = target_block.create_var(dtype=slice_item.dtype)
target_block.append_op(
type='elementwise_add',
inputs={'X': slice_item,
'Y': temp_1},
outputs={'Out': temp_end},
attrs={'axis': -1})
slice_end.append(temp_end)
else:
slice_end.append(slice_item + 1
if slice_item != -1 else 10000000)
def contain_var(one_list): axes.append(dim)
for ele in one_list: starts.append(start)
if isinstance(ele, Variable): ends.append(end)
return True steps.append(step)
return False use_strided_slice = True if step != 1 else use_strided_slice
def get_new_list_tensor(old_list):
new_list_tensor = []
for dim in old_list:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_list_tensor.append(dim)
else:
assert (isinstance(dim, int))
temp_out = var.block.create_var(dtype='int64')
fill_constant([1], dim, force_cpu=True, out=temp_out)
new_list_tensor.append(temp_out)
return new_list_tensor
inputs = {'Input': [var]} inputs = {'Input': [var]}
attrs = { attrs = {
'axes': slice_axis, 'axes': axes,
'starts': [], 'starts': [],
'ends': [], 'ends': [],
'decrease_axis': decrease_axis 'decrease_axis': decrease_axes
} }
if (use_strided_slice == True): if use_strided_slice == True:
attrs['strides'] = [] attrs['strides'] = []
infer_flags = list(1 for i in range(len(slice_axis)))
# starts
if contain_var(slice_start):
inputs['StartsTensorList'] = get_new_list_tensor(slice_start)
for i, dim in enumerate(slice_start):
if isinstance(dim, Variable):
attrs['starts'].append(-1)
infer_flags[i] = -1
else:
attrs['starts'].append(dim)
else:
attrs['starts'] = slice_start
# ends
if contain_var(slice_end):
inputs['EndsTensorList'] = get_new_list_tensor(slice_end)
for i, dim in enumerate(slice_end):
if isinstance(dim, Variable):
attrs['ends'].append(-1)
infer_flags[i] = -1
else:
attrs['ends'].append(dim)
else:
attrs['ends'] = slice_end
# strides infer_flags = list(1 for i in range(len(axes)))
if use_strided_slice == True: from .layers import utils
if contain_var(slice_step):
inputs['StridesTensorList'] = get_new_list_tensor(slice_step) def deal_attrs(attr, attr_name, tensor_attr_name, inputs, infer_flags):
for i, dim in enumerate(slice_step): if utils._contain_var(attr):
inputs[tensor_attr_name] = utils._convert_to_tensor_list(
attr, dtype="int64")
for i, dim in enumerate(attr):
if isinstance(dim, Variable): if isinstance(dim, Variable):
attrs['strides'].append(-1) attrs[attr_name].append(-1)
infer_flags[i] = -1 infer_flags[i] = -1
else: else:
attrs['strides'].append(dim) attrs[attr_name].append(dim)
else: else:
attrs['strides'] = slice_step attrs[attr_name] = attr
deal_attrs(starts, "starts", "StartsTensorList", inputs, infer_flags)
deal_attrs(ends, "ends", "EndsTensorList", inputs, infer_flags)
deal_attrs(steps, "strides", "StridesTensorList", inputs, infer_flags)
# infer_flags # infer_flags
attrs['infer_flags'] = infer_flags attrs['infer_flags'] = infer_flags
out = var out = var
if use_strided_slice == False and len(slice_axis) > 0: target_block = default_main_program().current_block()
if use_strided_slice == False and len(axes) > 0:
# append slice_op here # append slice_op here
slice_out_var = target_block.create_var( slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + "_slice"), name=unique_name.generate_with_ignorable_key(var.name + "_slice"),
...@@ -948,7 +884,7 @@ def _getitem_impl_(var, item): ...@@ -948,7 +884,7 @@ def _getitem_impl_(var, item):
attrs=attrs) attrs=attrs)
out = slice_out_var out = slice_out_var
elif use_strided_slice == True and len(slice_axis) > 0: elif use_strided_slice == True and len(axes) > 0:
strided_slice_out_var = target_block.create_var( strided_slice_out_var = target_block.create_var(
name=unique_name.generate_with_ignorable_key(var.name + name=unique_name.generate_with_ignorable_key(var.name +
"_strided_slice"), "_strided_slice"),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册