diff --git a/x2paddle/decoder/caffe_shape.py b/x2paddle/decoder/caffe_shape.py index e87d1f3f0f2ec867b3d99318eb52849c4317a09e..2bca90ad8effe080f055a2c0b3670099485af2ed 100644 --- a/x2paddle/decoder/caffe_shape.py +++ b/x2paddle/decoder/caffe_shape.py @@ -230,3 +230,226 @@ def shape_batchnorm(layer, input_shape): def shape_scale(layer, input_shape): return input_shape + + +def shape_reshape(layer, input_shape): + def count(num_list): + return reduce(lambda a, b: a * b, num_list) + + inshape = input_shape[0] + params = layer.reshape_param + axis = params.axis if hasattr(params, axis) else 0 + num_axes = params.num_axes if hasattr(params, num_axes) else -1 + if inshape[0] == -1: + inshape[0] = 1 + input_count = count(inshape) + + input_num_axes = len(inshape) + + input_start_axis = axis + start_axis = input_start_axis if input_start_axis >= 0 \ + else input_num_axes + input_start_axis + 1 + + assert start_axis >= 0, "[Reshape]axis %d out of range" % (input_start_axis) + assert start_axis <= input_num_axes, "[Reshape]axis %d out of range for %d-D input data"\ + % (input_start_axis, input_num_axes) + + assert num_axes >= -1, "[Reshape]num_axes must be >= 0, or -1 for all" + + end_axis = input_num_axes if num_axes == -1 else start_axis + num_axes + assert end_axis <= input_num_axes, "end_axis[%d] = axis[%d] + num_axes[%d] is out of range"\ + % (end_axis, start_axis, num_axes) + + num_axes_replaced = end_axis - start_axis + num_axes_retained = input_num_axes - num_axes_replaced + num_new_axes = len(shape['dim']) + outshape = [] + + for i in range(start_axis): + outshape.append(inshape[i]) + + for i in range(num_new_axes): + outshape.append(shape['dim'][i]) + + for i in range(end_axis, input_num_axes): + outshape.append(inshape[i]) + + assert len(outshape) == num_axes_retained + num_new_axes,\ + "[Reshape]invalid dims of output shape[%s]" % (str(outshape)) + + inferred_axis = -1 + copy_axes = [] + constant_count = 1 + for i in range(num_new_axes): + top_dim = shape['dim'][i] + if top_dim == 0: + copy_axes.append(i) + copy_axis_index = start_axis + i + outshape[copy_axis_index] = inshape[copy_axis_index] + elif top_dim == -1: + assert inferred_axis == -1, "[Reshape]new shape contains multiple -1 dims" + inferred_axis = i + else: + constant_count *= top_dim + + if inferred_axis >= 0: + explicit_count = constant_count + l = inshape[0:start_axis] + if len(l) > 0: + explicit_count *= count(l) + + l = inshape[end_axis:] + if len(l) > 0: + explicit_count *= count(l) + + for i in range(len(copy_axes)): + explicit_count *= outshape[start_axis + copy_axes[i]] + + assert input_count % explicit_count == 0, "[Reshape]botom count[%d] "\ + "must be divisible by product of the specified dimensions[%d] "\ + % (input_count, explicit_count) + outshape[start_axis + inferred_axis] = input_count / explicit_count + + output_count = count(outshape) + assert output_count == input_count, "[Reshape]output count[%d] must match input count[%d]" % ( + output_count, input_count) + if inshape[0] == -1: + outshape[0] = -1 + return [outshape] + + +def shape_argmax(layer, input_shape): + inshape = input_shape[0] + params = layer.argmax_param + out_max_val = params.out_max_val if hasattr(params, out_max_val) else False + top_k = params.top_k if hasattr(params, top_k) else 1 + axis = parmas.axis if hasattr(params, axis) else -1 + if axis < 0: + axis += len(inshape) + assert (axis + 1 == len(inshape) + ), 'only can be applied on the last dimension[axis:%d, %s] now,'\ + 'make sure you have set axis param in xxx.prototxt file' \ + % (axis, str(inshape)) + + outshape = inshape + outshape[-1] = top_k + if out_max_val is True: + outshape[-1] *= 2 + return [outshape] + + +def shape_axpy(layer, input_shape): + assert len(input_shapes) == 3, "not valid input shape for axpy layer" + assert len(input_shapes[0]) == len(input_shapes[1]), 'should have same dims' + + output_shape = input_shapes[1] + assert (input_shapes[2] == output_shape),\ + "shape not consistent for axpy[%s <--> %s]" \ + % (str(output_shape), str(input_shapes[2])) + return [output_shape] + + +def shape_crop(layer, input_shape): + assert len(input_shape) == 2, "the number of crop's inputs must be 2" + return [input_shape[1]] + + +def shape_detectionoutput(layer, input_shape): + return [[-1, 6]] + + +def shape_flatten(layer, input_shape): + assert len(input_shape) == 1, "the number of flatten's inputs must be 1" + params = layer.flatten_param + start_axis = params.axis + end_axis = params.end_axis + if start_axis < 0: + start_axis += len(input_shape[0]) + if end_axis < 0: + end_axis += len(input_shape[0]) + 1 + assert start_axis <= end_axis, 'invalid axis[%d] or end_axis[%d] params'\ + % (start_axis, end_axis) + output_shape = [0] * (start_axis - 0) + [ + -1 + ] + [0] * (len(input_shape[0]) - end_axis) + return [output_shape] + + +def shape_normalize(layer, input_shape): + return input_shape + + +def shape_permute(layer, input_shape): + params = layer.permute_param + order = list(params.order) + inshape = input_shape[0] + output_shape = [] + for ii in order: + assert ii < len(inshape), "invalid order for permute[%s]" % (name) + output_shape.append(inshape[ii]) + return [output_shape] + + +def shape_power(layer, input_shape): + return input_shape + + +def shape_priorbox(layer, input_shape): + params = layer.prior_box_param + min_size = list(params.min_size) + max_size = list(params.max_size) + aspect_ratio = list(params.aspect_ratio) + assert len(input_shapes[0]) == 2, "invalid inputs for Priorbox[%s]" % (name) + fc_shape = input_shapes[0][0] + N = 1 + if not max_size == None: + N += 1 + if not aspect_ratio == None: + N += 2 * len(aspect_ratio) + + N_bbx = fc_shape[2] * fc_shape[3] * N + output_shape = [[1, 2, 4 * N_bbx]] + return output_shape + + +def shape_reduction(layer, input_shape): + params = layer.reduction_param + axis = params.axis + if axis < 0: + axis += len(input_shape[0]) + 1 + assert axis <= len(input_shape[0]), 'invalid axis[%d] error' % (axis) + return [input_shape[0:axis]] + + +def shape_roipooling(layer, input_shape): + params = layer.roi_pooling_param + pooled_w = params.pooled_w + pooled_h = params.pooled_h + spatial_scale = params.spatial_scale + assert len( + input_shapes[0]) == 2, "not valid input shape for roipooling layer" + base_fea_shape = input_shapes[0][0] + rois_shape = input_shapes[0][1] + output_shape = base_fea_shape + output_shape[0] = rois_shape[0] + output_shape[2] = pooled_h + output_shape[3] = pooled_w + return [output_shape] + + +def shape_select(layer, input_shape): + input_shape = list(input_shape[0]) + params = layer.select_param + axis = params.axis + slice_point = list(params.slice_point) + start = slice_point[0] + if len(slice_point) == 2: + end = slice_point[1] + else: + end = input_shape[axis] + + assert end > start, "invalid slice_point with [start:%d, end:%d]"\ + % (start, end) + output_shape = input_shape + output_shape[axis] = end - start + return [output_shape] diff --git a/x2paddle/op_mapper/caffe_op_mapper.py b/x2paddle/op_mapper/caffe_op_mapper.py index f4e98ae3f0853510c1ac341f29edf49994e7a3df..abb8d241a6fdca02f59a9957eec40ad51dee28cb 100644 --- a/x2paddle/op_mapper/caffe_op_mapper.py +++ b/x2paddle/op_mapper/caffe_op_mapper.py @@ -267,6 +267,7 @@ class CaffeOpMapper(OpMapper): def Pooling(self, node): params = node.layer.pooling_param + ceil_mode = getattr(params, 'ceil_mode', True) global_pool = getattr(params, 'global_pooling', False) kernel_default = [1, 1] channel, kernel, stride, pad, dilation, group = self.get_kernel_parameters( @@ -286,7 +287,7 @@ class CaffeOpMapper(OpMapper): 'pool_size': kernel, 'pool_stride': stride, 'pool_padding': pad, - 'ceil_mode': True, + 'ceil_mode': ceil_mode, 'pool_type': string(pool_type), 'exclusive': True, 'global_pooling': global_pool, @@ -737,7 +738,7 @@ class CaffeOpMapper(OpMapper): else: self.weights[node.layer_name + '_scale'] = np.squeeze(nose.data[0]) self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1]) - params = node.layer.scale_params + params = node.layer.scale_param axis = params.axis num_axes = params.num_axes assert num_axes == 1, "layer scale not support this num_axes[%d] now" % ( @@ -811,3 +812,518 @@ class CaffeOpMapper(OpMapper): node.layer_name, node.layer_name), output=node, param_attr=attr) + + def Reshape(self, node): + assert len(node.inputs) == 1 and len( + node.outputs + ) == 1, 'The count of Reshape node\'s input and output is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + top_count = len(input.layer.top) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + is_inplace, = False if top_count == 1 else True + output_shape = node.output_shape[0] + attr = { + 'shape': output_shape, + 'inplace': is_inplace, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("reshape", + inputs=input, + output=node, + param_attr=attr) + + def ArgMax(self, node): + assert len(node.inputs) == 1 and len( + node.outputs + ) == 1, 'The count of ArgMax node\'s input and output is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + input_shape = node.input_shape[0] + params = node.layer.argmax_param + out_max_val = params.out_max_val if hasattr(params, + out_max_val) else False + top_k = params.top_k if hasattr(params, top_k) else 1 + axis = parmas.axis if hasattr(params, axis) else -1 + if axis < 0: + axis += len(input_shape) + if out_max_val is True: + attr = {'k': top_k, 'name': string(node.layer_name + '_topk')} + node.fluid_code.add_layer("topk", + inputs=input, + output='{}_topk_var, {}_index_var'.format( + node.layer_name, node.layer_name), + param_attr=attr) + attr = {'dtype': '{}_topk_var.dtype'.format(node.layer_name)} + node.fluid_code.add_layer( + "cast", + inputs='{}_index_var'.format(node.layer_name), + output='{}_index_var'.format(node.layer_name), + param_attr=attr) + attr = {'axis': axis, 'name': string(node.layer_name)} + node.fluid_code.add_layer("concat", + inputs='{}_topk_var, {}_index_var'.format( + node.layer_name, node.layer_name), + output=node, + param_attr=attr) + else: + attr = {'k': top_k, 'name': string(node.layer_name)} + node.fluid_code.add_layer("topk", + inputs=input, + output='_, {}'.format(node.layer_name), + param_attr=attr) + + def Axpy(self, node): + assert len( + node.inputs) == 3, 'The count of Axpy node\'s input is not 3.' + alpha = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(alpha): + tmp = self.graph.get_bottom_node(alpha, idx=0, copy=True) + if self.is_BN(tmp): + alpha = tmp + x = self.graph.get_bottom_node(node, idx=1, copy=True) + if self.is_Scale(x): + tmp = self.graph.get_bottom_node(x, idx=0, copy=True) + if self.is_BN(tmp): + x = tmp + y = self.graph.get_bottom_node(node, idx=2, copy=True) + if self.is_Scale(y): + tmp = self.graph.get_bottom_node(y, idx=0, copy=True) + if self.is_BN(tmp): + y = tmp + attr = {'axis': 0, 'name': string(node.layer_name + '_mul')} + node.fluid_code.add_layer("elementwise_mul", + inputs={ + 'x': alpha, + 'y': x + }, + output=node, + param_attr=attr) + attr = {'name': string(node.layer_name + '_add')} + node.fluid_code.add_layer("elementwise_add", + inputs={ + 'x': node, + 'y': y + }, + output=node, + param_attr=attr) + + def Crop(self, node): + assert len( + node.inputs) == 2, 'The count of Crop node\'s input is not 2.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + example = self.graph.get_bottom_node(node, idx=1, copy=True) + if self.is_Scale(example): + tmp = self.graph.get_bottom_node(example, idx=0, copy=True) + if self.is_BN(tmp): + example = tmp + params = node.layer.crop_param + axis = parmas.axis + input_shape = node.input_shape[0] + if axis < 0: + axis += len(input_shape) + offset_real = [0] * len(input_shape) + if hasattr(params, offset): + offset = list(params.offset) + assert (len(input_shape) - axis) == len( + offset), "invalid offset[%s] in crop layer" % (str(offset)) + offset_real = [0] * axis + offset + attr = {'offsets': offset_real, 'name': string(node.layer_name)} + node.fluid_code.add_layer("crop", + inputs={ + 'x': input, + 'y': example + }, + output=node, + param_attr=attr) + + def DetectionOutput(self, node): + assert len( + node.inputs + ) == 3, 'The count of DetectionOutput node\'s input is not 3.' + mbox_loc = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(mbox_loc): + tmp = self.graph.get_bottom_node(mbox_loc, idx=0, copy=True) + if self.is_BN(tmp): + mbox_loc = tmp + mbox_conf_flatten = self.graph.get_bottom_node(node, idx=1, copy=True) + if self.is_Scale(mbox_conf_flatten): + tmp = self.graph.get_bottom_node(mbox_conf_flatten, + idx=0, + copy=True) + if self.is_BN(tmp): + mbox_conf_flatten = tmp + mbox_priorbox = self.graph.get_bottom_node(node, idx=2, copy=True) + if self.is_Scale(mbox_priorbox): + tmp = self.graph.get_bottom_node(mbox_priorbox, idx=0, copy=True) + if self.is_BN(tmp): + mbox_priorbox = tmp + params = node.layer.detection_output_param + nms_threshold = 0.3 + top_k = 10 + eta = 1.0 + if hasattr(params, 'nms_param'): + nms_threshold = getattr(params.nms_param, 'nms_threshold', 0.3) + top_k = getattr(params.nms_param, 'top_k', 10) + eta = getattr(params.nms_param, 'eta', 1.0) + background_label = getattr(params, 'background_label_id', 0) + share_location = getattr(params, 'share_location', True) + keep_top_k = getattr(params, 'keep_top_k', 100) + confidence_threshold = getattr(params, 'confidence_threshold', 0.1) + attr = { + 'num_or_sections': 2, + 'dim': 1, + 'name': string(node.layer_name + '_split') + } + node.fluid_code.add_layer("split", + inputs=mbox_priorbox, + output='mbox_priorbox_list', + param_attr=attr) + node.fluid_code.add_note('pb = mbox_priorbox_list[0]') + node.fluid_code.add_note('pbv = mbox_priorbox_list[1]') + attr = {'shape': [-1, 4], 'name': string(node.layer_name + '_reshape1')} + node.fluid_code.add_layer("reshape", + inputs='pb', + output='pb', + param_attr=attr) + attr = {'shape': [-1, 4], 'name': string(node.layer_name + '_reshape2')} + node.fluid_code.add_layer("reshape", + inputs='pbv', + output='pbv', + param_attr=attr) + # TODO(syf): need chaeck + attr = { + 'shape': [-1, node.input_shape[1][1], 4], + 'name': string(node.layer_name + '_reshape3') + } + node.fluid_code.add_layer("reshape", + inputs=mbox_loc, + output='mbox_loc', + param_attr=attr) + attr = { + 'background_label': background_label, + 'nms_threshold': nms_threshold, + 'nms_top_k': top_k, + 'keep_top_k': keep_top_k, + 'score_threshold': confidence_threshold, + 'nms_eta': eta + } + inputs_str = get_input_name(mbox_conf_flatten) + ', mbox_loc, pb, pbv' + node.fluid_code.add_layer("detection_output", + inputs=inputs_str, + output=node, + param_attr=attr) + + def Flatten(self, noed): + assert len( + node.inputs + ) == 1, 'The count of DetectionOutput node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + shape = node.output_shape[0] + attr = {'shape': shape, 'name': string(node.layer_name)} + node.fluid_code.add_layer("reshape", + inputs=input, + output=node, + param_attr=attr) + + def Normalize(self, node): + assert len( + node.inputs) == 1, 'The count of Normalize node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + params = node.layer.norm_param + across_spatial = params.across_spatial + channel_shared = params.channel_shared + assert across_spatial == False, "Only support across_spatial == False for Normalize" + attr = {'axis': 1, 'name': string(node.layer_name + '_l2')} + node.fluid_code.add_layer("l2_normalize", + inputs=input, + output=node.layer_name + '_l2', + param_attr=attr) + input_name = self.get_input_name(input) + data = node.data + data = self.adjust_parameters(node, data) + self.weights[node.layer_name + '_scale'] = data[0] + node.fluid_code.add_note( + '{}_scale_attr = ParamAttr(name=\'{}\')'.format( + node.layer_name, node.layer_name + '_scale')) + attr = { + 'shape': [1] if channel_shared else [node.input_shape[0][1]], + 'dtype': '{}.dtype'.format(input_name), + 'attr': '{}_scale_attr'.format(node.layer_name), + 'name': string(node.layer_name + '_param') + } + node.fluid_code.add_layer("create_parameter", + inputs=None, + output=node.layer_name + '_scale_param', + param_attr=attr) + attr = { + 'axis': -1 if channel_shared else 1, + 'name': string(node.layer_name + '_mul') + } + node.fluid_code.add_layer("elementwise_mul", + inputs=node.layer_name + '_l2, ' + + node.layer_name + '_scale_param', + output=node, + param_attr=attr) + + def Permute(self, node): + assert len( + node.inputs) == 1, 'The count of Permute node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + params = node.layer.permute_param + order = list(params.order) + attr = {'order': order, 'name': string(node.layer_name)} + node.fluid_code.add_layer("transpose", + inputs=input, + output=node, + param_attr=attr) + + def Power(self, node): + assert len( + node.inputs) == 1, 'The count of Permute node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + params = node.layer.power_param + power = params.power + scale = params.scale + shift = params.shift + attr = { + 'scale': scale, + 'bias': shift, + 'bias_after_scale': True, + 'name': string(node.layer_name + '_scale') + } + node.fluid_code.add_layer("scale", + inputs=input, + output=node, + param_attr=attr) + attr = {'factor': power, 'name': string(node.layer_name)} + node.fluid_code.add_layer("pow", + inputs=node, + output=node, + param_attr=attr) + + def PriorBox(self, node): + assert len( + node.inputs) == 2, 'The count of PriorBox node\'s input is not 2.' + input1 = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input1): + tmp = self.graph.get_bottom_node(input1, idx=0, copy=True) + if self.is_BN(tmp): + input1 = tmp + input2 = self.graph.get_bottom_node(node, idx=1, copy=True) + if self.is_Scale(input2): + tmp = self.graph.get_bottom_node(input2, idx=0, copy=True) + if self.is_BN(tmp): + input2 = tmp + input_dict = {'input': input1, 'image': input2} + params = node.layer.prior_box_param + step = getattr(params, 'step', 0.0) + offset = getattr(params, 'offset', 0.5) + min_size = list(params.min_size) + max_size = list(params.max_size) + aspect_ratio = list(params.aspect_ratio) + flip = getattr(params, 'flip', False) + clip = getattr(params, 'clip', False) + variance = list(getattr(params, 'variance', [0.1, 0.1, 0.2, 0.2])) + steps = tuple(step) if type(step) is list or type(step) is tuple else ( + step, step) + attr = { + 'min_sizes': min_size, + 'max_sizes': max_size, + 'aspect_ratios': aspect_ratio, + 'variance': variance, + 'flip': flip, + 'clip': clip, + 'step': steps, + 'offset': offset, + 'min_max_aspect_ratios_order': True, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("prior_box", + inputs=input_dict, + output='{}_box, {}_var'.format( + node.layer_name, node.layer_name), + param_attr=attr) + attr = { + 'shape': [1, 1, -1], + } + node.fluid_code.add_layer("reshape", + inputs='{}_box'.format(node.layer_name), + output='{}_box'.format(node.layer_name), + param_attr=attr) + attr = { + 'shape': [1, 1, -1], + } + node.fluid_code.add_layer("reshape", + inputs='{}_var'.format(node.layer_name), + output='{}_var'.format(node.layer_name), + param_attr=attr) + attr = {'axis': 1, 'name': string(node.layer_name + '_concat')} + node.fluid_code.add_layer("concat", + inputs='[{}_box, {}_var]'.format( + node.layer_name, node.layer_name), + output=node, + param_attr=attr) + + def Reduction(self, node): + assert len( + node.inputs) == 1, 'The count of Reduction node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + params = node.layer.reduction_param + operation = params.operation + axis = params.axis + coeff = params.coeff + assert operation >= 1 and operation <= 4, "reduction reduction [%s] error" % ( + operation) + input_len = len(node.input_shape[0]) + if axis < 0: + axis += input_len + 1 + dim = list(range(input_len)) + if operation == 1: ## operation = SUM + attr = { + 'dim': dim[axis:], + 'keep_dim': False, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("reduce_sum", + inputs=input, + output=node, + param_attr=attr) + elif operation == 2: ## operation = ASUM + attr = {'name': string(node.layer_name + '_abs')} + node.fluid_code.add_layer("abs", + inputs=input, + output=node, + param_attr=attr) + attr = { + 'dim': dim[axis:], + 'keep_dim': False, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("reduce_sum", + inputs=node, + output=node, + param_attr=attr) + elif operation == 3: ## operation = SUMSQ + attr = {'factor': 2.0, 'name': string(node.layer_name + '_pow')} + node.fluid_code.add_layer("pow", + inputs=input, + output=node, + param_attr=attr) + attr = { + 'dim': dim[axis:], + 'keep_dim': False, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("reduce_sum", + inputs=node, + output=node, + param_attr=attr) + else: ## operation = MEAN + attr = { + 'dim': dim[axis:], + 'keep_dim': False, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("reduce_mean", + inputs=node, + output=node, + param_attr=attr) + attr = {'scale': coeff} + node.fluid_code.add_layer("scale", + inputs=node, + output=node, + param_attr=attr) + + def ROIPooling(self, node): + assert len( + node.inputs) == 2, 'The count of ROIPooling node\'s input is not 2.' + input1 = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input1): + tmp = self.graph.get_bottom_node(input1, idx=0, copy=True) + if self.is_BN(tmp): + input1 = tmp + input2 = self.graph.get_bottom_node(node, idx=1, copy=True) + if self.is_Scale(input2): + tmp = self.graph.get_bottom_node(input2, idx=0, copy=True) + if self.is_BN(tmp): + input2 = tmp + attr = {'axes': [1], 'starts': [1], 'ends': [5]} + node.fluid_code.add_layer("slice", + inputs=input2, + output=input2, + param_attr=attr) + input_dict = {'input': input1, 'rois': input2} + params = node.layer.roi_pooling_param + attr = { + 'pooled_w': params.pooled_w, + 'pooled_h': params.pooled_h, + 'spatial_scale': params.spatial_scale, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("roi_pool", + inputs=input_dict, + output=node, + param_attr=attr) + + def Select(self, node): + assert len( + node.inputs) == 1, 'The count of Select node\'s input is not 2.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + if self.is_Scale(input): + tmp = self.graph.get_bottom_node(input, idx=0, copy=True) + if self.is_BN(tmp): + input = tmp + params = node.layer.select_param + slice_point = list(params.slice_point) + axis = params.axis + maxint32 = 2147483647 + slice_point = [0] + slice_point + slice_point.append(maxint32) + i = 0 + node.fluid_code.add_note('{} = []'.format(node.layer_name)) + for i in range(len(slice_point)): + attr = { + 'axes': [axis], + 'starts': [slice_point[i]], + 'ends': [slice_point[i + 1]], + 'name': string(node.layer_name + '_' + str(i)) + } + node.fluid_code.add_layer("slice", + inputs=input, + output=string(node.layer_name + '_' + + str(i)), + param_attr=attr) + node.fluid_code.add_note('{}.append({})'.format( + node.layer_name, node.layer_name + '_' + str(i))) + if i == len(slice_point) - 2: + break