From 096b2f5af14cc511a4e045d41c7f1b17d1983cd1 Mon Sep 17 00:00:00 2001 From: liym27 <33742067+liym27@users.noreply.github.com> Date: Fri, 14 May 2021 11:17:45 +0800 Subject: [PATCH] Polish code for _getitem_impl_ (#32868) --- python/paddle/fluid/framework.py | 144 +++++++++---------------------- 1 file changed, 40 insertions(+), 104 deletions(-) diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index f4cad7894a..e9a114b3d5 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -792,29 +792,16 @@ def _getitem_impl_(var, item): if not isinstance(item, tuple): item = [item] - decrease_axis = [] - slice_axis = [] - slice_start = [] - slice_end = [] - slice_step = [] + decrease_axes = [] + axes = [] + starts = [] + ends = [] + steps = [] + use_strided_slice = False reverse_axis = [] - target_block = default_main_program().current_block() - - def fill_constant(shape, value, force_cpu=False, out=None): - var.block.append_op( - type='fill_constant', - inputs={}, - outputs={'Out': [out]}, - attrs={ - 'shape': shape, - 'dtype': out.dtype, - 'value': float(value), - 'force_cpu': force_cpu - }) - out.stop_gradient = True - return out + max_integer = 2**31 - 1 for dim, slice_item in enumerate(item): if isinstance(slice_item, slice): start = slice_item.start @@ -824,8 +811,7 @@ def _getitem_impl_(var, item): if start is None and end is None and step is None: continue - if step is None: - step = 1 + step = 1 if step is None else step if start is None and end is None: assert (step == -1) @@ -836,106 +822,56 @@ def _getitem_impl_(var, item): start = 0 if end is None: - end = 10000000 - - if step != 1: - use_strided_slice = True + end = max_integer - slice_axis.append(dim) - slice_start.append(start) - slice_end.append(end) - slice_step.append(step) else: - decrease_axis.append(dim) - slice_axis.append(dim) - slice_start.append(slice_item) - slice_step.append(1) - if isinstance(slice_item, Variable): - temp_1 = var.block.create_var(dtype=slice_item.dtype) - fill_constant([1], 1, force_cpu=True, out=temp_1) - temp_end = target_block.create_var(dtype=slice_item.dtype) - target_block.append_op( - type='elementwise_add', - inputs={'X': slice_item, - 'Y': temp_1}, - outputs={'Out': temp_end}, - attrs={'axis': -1}) - slice_end.append(temp_end) - else: - slice_end.append(slice_item + 1 - if slice_item != -1 else 10000000) + decrease_axes.append(dim) + start = slice_item + step = 1 + end = slice_item + 1 if slice_item != -1 else max_integer - def contain_var(one_list): - for ele in one_list: - if isinstance(ele, Variable): - return True - return False - - def get_new_list_tensor(old_list): - new_list_tensor = [] - for dim in old_list: - if isinstance(dim, Variable): - dim.stop_gradient = True - new_list_tensor.append(dim) - else: - assert (isinstance(dim, int)) - temp_out = var.block.create_var(dtype='int64') - fill_constant([1], dim, force_cpu=True, out=temp_out) - new_list_tensor.append(temp_out) - return new_list_tensor + axes.append(dim) + starts.append(start) + ends.append(end) + steps.append(step) + use_strided_slice = True if step != 1 else use_strided_slice inputs = {'Input': [var]} attrs = { - 'axes': slice_axis, + 'axes': axes, 'starts': [], 'ends': [], - 'decrease_axis': decrease_axis + 'decrease_axis': decrease_axes } - if (use_strided_slice == True): + if use_strided_slice == True: attrs['strides'] = [] - infer_flags = list(1 for i in range(len(slice_axis))) - - # starts - if contain_var(slice_start): - inputs['StartsTensorList'] = get_new_list_tensor(slice_start) - for i, dim in enumerate(slice_start): - if isinstance(dim, Variable): - attrs['starts'].append(-1) - infer_flags[i] = -1 - else: - attrs['starts'].append(dim) - else: - attrs['starts'] = slice_start - - # ends - if contain_var(slice_end): - inputs['EndsTensorList'] = get_new_list_tensor(slice_end) - for i, dim in enumerate(slice_end): - if isinstance(dim, Variable): - attrs['ends'].append(-1) - infer_flags[i] = -1 - else: - attrs['ends'].append(dim) - else: - attrs['ends'] = slice_end - # strides - if use_strided_slice == True: - if contain_var(slice_step): - inputs['StridesTensorList'] = get_new_list_tensor(slice_step) - for i, dim in enumerate(slice_step): + infer_flags = list(1 for i in range(len(axes))) + from .layers import utils + + def deal_attrs(attr, attr_name, tensor_attr_name, inputs, infer_flags): + if utils._contain_var(attr): + inputs[tensor_attr_name] = utils._convert_to_tensor_list( + attr, dtype="int64") + for i, dim in enumerate(attr): if isinstance(dim, Variable): - attrs['strides'].append(-1) + attrs[attr_name].append(-1) infer_flags[i] = -1 else: - attrs['strides'].append(dim) + attrs[attr_name].append(dim) else: - attrs['strides'] = slice_step + attrs[attr_name] = attr + + deal_attrs(starts, "starts", "StartsTensorList", inputs, infer_flags) + deal_attrs(ends, "ends", "EndsTensorList", inputs, infer_flags) + deal_attrs(steps, "strides", "StridesTensorList", inputs, infer_flags) + # infer_flags attrs['infer_flags'] = infer_flags out = var - if use_strided_slice == False and len(slice_axis) > 0: + target_block = default_main_program().current_block() + if use_strided_slice == False and len(axes) > 0: # append slice_op here slice_out_var = target_block.create_var( name=unique_name.generate_with_ignorable_key(var.name + "_slice"), @@ -948,7 +884,7 @@ def _getitem_impl_(var, item): attrs=attrs) out = slice_out_var - elif use_strided_slice == True and len(slice_axis) > 0: + elif use_strided_slice == True and len(axes) > 0: strided_slice_out_var = target_block.create_var( name=unique_name.generate_with_ignorable_key(var.name + "_strided_slice"), -- GitLab