From 3021f600d9bc73e84e6fccab4618859949497dfa Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Tue, 23 Jul 2019 20:34:56 +0800 Subject: [PATCH] add caffe op part1 --- x2paddle/decoder/caffe_decoder.py | 7 +- x2paddle/decoder/caffe_shape.py | 55 +++- x2paddle/op_mapper/caffe_op_mapper.py | 351 +++++++++++++++++++++++--- 3 files changed, 372 insertions(+), 41 deletions(-) diff --git a/x2paddle/decoder/caffe_decoder.py b/x2paddle/decoder/caffe_decoder.py index 101d445..7b9f8a0 100644 --- a/x2paddle/decoder/caffe_decoder.py +++ b/x2paddle/decoder/caffe_decoder.py @@ -60,11 +60,14 @@ class CaffeResolver(object): class CaffeGraphNode(GraphNode): def __init__(self, layer, layer_name=None): if layer_name is None: - super(CaffeGraphNode, self).__init__(layer, layer.name.replace('/', '_')) + super(CaffeGraphNode, self).__init__(layer, + layer.name.replace('/', '_')) else: - super(CaffeGraphNode, self).__init__(layer, layer_name.replace('/', '_')) + super(CaffeGraphNode, self).__init__(layer, + layer_name.replace('/', '_')) self.layer_type = layer.type self.fluid_code = FluidCode() + self.data = None def set_params(self, params): self.data = params diff --git a/x2paddle/decoder/caffe_shape.py b/x2paddle/decoder/caffe_shape.py index ca6d514..e87d1f3 100644 --- a/x2paddle/decoder/caffe_shape.py +++ b/x2paddle/decoder/caffe_shape.py @@ -110,7 +110,7 @@ def get_strided_kernel_output_shape(params, input_shape, round_func): round_func) has_c_o = hasattr(params, 'num_output') c = params.num_output if has_c_o else input_shape[1] - + return [[input_shape[0], c, o_h, o_w]] @@ -169,6 +169,7 @@ def shape_softmax(layer, input_shape): def shape_input(layer, input_shape): return [list(layer.input_param.shape[0].dim)] + def shape_concat(layer, input_shape): params = layer.concat_param axis = params.axis @@ -178,4 +179,54 @@ def shape_concat(layer, input_shape): output_shape = shape else: output_shape[axis] += shape[axis] - return [output_shape] \ No newline at end of file + return [output_shape] + + +def shape_slice(layer, input_shape): + inshape = input_shape[0] + params = layer.slice_param + axis = params.axis + count = inshape[axis] + points = list(params.slice_point) + points = [0] + points + [count] + output_shape = [] + for i in range(len(points)): + shape = inshape + size = points[i + 1] - points[i] + shape[axis] = size + output_shape.append(shape) + if i == len(points) - 2: + break + return output_shape + + +def shape_prelu(layer, input_shape): + return input_shape + + +def shape_sigmoid(layer, input_shape): + return input_shape + + +def shape_absval(layer, input_shape): + return input_shape + + +def shape_accuracy(layer, input_shape): + return [[1]] + + +def shape_tanh(layer, input_shape): + return input_shape + + +def shape_eltwise(layer, input_shape): + return [input_shape[0]] + + +def shape_batchnorm(layer, input_shape): + return input_shape + + +def shape_scale(layer, input_shape): + return input_shape diff --git a/x2paddle/op_mapper/caffe_op_mapper.py b/x2paddle/op_mapper/caffe_op_mapper.py index 9161588..cf8d3b7 100644 --- a/x2paddle/op_mapper/caffe_op_mapper.py +++ b/x2paddle/op_mapper/caffe_op_mapper.py @@ -106,11 +106,17 @@ class CaffeOpMapper(OpMapper): raise ValueError('Unable to determine kernel parameter!') return default - def get_kernel_parameters(self, kind, params, kernel_default=[1, 1]): + def get_kernel_parameters(self, kind, params): assert kind in ['Convolution', 'Pooling', 'Deconvolution'] - k_h = self.get_kernel_value(params.kernel_h, params.kernel_size, 0, kernel_default[0]) - k_w = self.get_kernel_value(params.kernel_w, params.kernel_size, 1, kernel_default[1]) + k_h = self.get_kernel_value(params.kernel_h, + params.kernel_size, + 0, + default=1) + k_w = self.get_kernel_value(params.kernel_w, + params.kernel_size, + 1, + default=1) s_h = self.get_kernel_value(params.stride_h, params.stride, 0, @@ -144,6 +150,12 @@ class CaffeOpMapper(OpMapper): return c_o, kernel, stride, pad, dilation, group + def get_input_name(self, node): + if hasattr(node, "index"): + return node.layer_name + "[{}]".format(node.index) + else: + return node.layer_name + def Input(self, node): shape = list(node.layer.input_param.shape[0].dim)[1:] dtype = 'float32' @@ -159,6 +171,8 @@ class CaffeOpMapper(OpMapper): def Convolution(self, node): data = node.data + assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format( + node.layer_name, node.layer_type) data = self.adjust_parameters(node, data) self.weights[node.layer_name + '_weights'] = data[0] if len(data) == 2: @@ -187,6 +201,8 @@ class CaffeOpMapper(OpMapper): def Deconvolution(self, node): data = node.data + assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format( + node.layer_name, node.layer_type) data = self.adjust_parameters(node, data) self.weights[node.layer_name + '_weights'] = data[0] if len(data) == 2: @@ -216,13 +232,10 @@ class CaffeOpMapper(OpMapper): def Pooling(self, node): params = node.layer.pooling_param - shape = node.input_shape[0] global_pool = getattr(params, 'global_pooling', False) kernel_default = [1, 1] - if global_pool: - kernel_default = [shape[2],shape[3]] channel, kernel, stride, pad, dilation, group = self.get_kernel_parameters( - node.layer_type, params, kernel_default=kernel_default) + node.layer_type, params) if params.pool == 0: pool_type = 'max' else: @@ -237,6 +250,7 @@ class CaffeOpMapper(OpMapper): 'ceil_mode': True, 'pool_type': string(pool_type), 'exclusive': True, + 'global_pooling': global_pool, 'name': string(node.layer_name) } node.fluid_code.add_layer("pool2d", @@ -279,6 +293,8 @@ class CaffeOpMapper(OpMapper): def InnerProduct(self, node): data = node.data + assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format( + node.layer_name, node.layer_type) data = self.adjust_parameters(node, data) # Reshape the parameters to Paddle's ordering transpose_order = (1, 0) @@ -361,47 +377,308 @@ class CaffeOpMapper(OpMapper): params = node.layer.slice_param axis = params.axis points = list(params.slice_point) - shape = node.input_shape[0] - count = shape[axis] - sections = [] - idx = 0 - for p in points: - if idx == 0: - sections.append(p - 0) - elif idx == len(points) - 1: - sections.append(count - p) - else: - sections.append(points[idx + 1] - p) - idx += 1 - attr = { - 'dim': axis, - 'num_or_sections': sections, - 'name': string(node.layer_name + '_slice') - } - node.fluid_code.add_layer("split", - inputs=input, - output=node, - param_attr=attr) + maxint32 = 2147483647 + points = [0] + points + points.append(maxint32) + i = 0 + node.fluid_code.add_note('{} = []'.format(node.layer_name)) + for i in range(len(points)): + attr = { + 'axes': [axis], + 'starts': [points[i]], + 'ends': [points[i + 1]], + 'name': string(node.layer_name + '_' + str(i)) + } + node.fluid_code.add_layer("slice", + inputs=input, + output=string(node.layer_name + '_' + + str(i)), + param_attr=attr) + node.fluid_code.add_note('{}.append({})'.format( + node.layer_name, node.layer_name + '_' + str(i))) + if i == len(points) - 2: + break def Concat(self, node): assert len( - node.inputs) > 1, 'The count of Concat node\'s input is not more than 1.' + node.inputs + ) > 1, 'The count of Concat node\'s input is not more than 1.' inputs = [] for i in range(len(node.inputs)): input = self.graph.get_bottom_node(node, idx=i, copy=True) inputs.append(input) params = node.layer.concat_param axis = params.axis + attr = {'axis': axis, 'name': string(node.layer_name)} + node.fluid_code.add_layer("concat", + inputs=inputs, + output=node, + param_attr=attr) + + def PReLU(self, node): + assert len( + node.inputs) == 1, 'The count of PReLU node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + params = node.layer.prelu_param + mode_bool = params.channel_shared + if mode_bool: + mode = 'all' + else: + mode = 'channel' + data = node.data + assert data is not None, 'The parameter of {} (type is {}) is not set. You need to use python package of caffe to set the default value.'.format( + node.layer_name, node.layer_type) + self.weights[node.layer_name + '_weights'] = data[0] attr = { - 'axis': axis, - 'name': string(node.layer_name + '_concat') + 'mode': mode, + 'param_attr': string(node.layer_name + '_weights'), + 'name': string(node.layer_name) } - node.fluid_code.add_layer("concat", + node.fluid_code.add_layer("prelu", + inputs=input, + output=node, + param_attr=attr) + + def Sigmoid(self, node): + assert len( + node.inputs) == 1, 'The count of PReLU node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + attr = {'name': string(node.layer_name)} + node.fluid_code.add_layer("sigmoid", + inputs=input, + output=node, + param_attr=attr) + + def AbsVal(self, node): + assert len( + node.inputs) == 1, 'The count of PReLU node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + attr = {'name': string(node.layer_name)} + node.fluid_code.add_layer("absval", + inputs=input, + output=node, + param_attr=attr) + + def Accuracy(self, node): + assert len( + node.inputs) == 2, 'The count of Accuracy node\'s input is not 2.' + inputs = [] + inputs[0] = None + inputs[1] = None + i = 0 + for shape in node.input_shape: + if shape[1] == 1: + inputs[1] = self.graph.get_bottom_node(node, idx=i, copy=True) + else: + inputs[0] = self.graph.get_bottom_node(node, idx=i, copy=True) + i += 1 + params = node.layer.accuracy_param + top_k = params.top_k + axis = params.axis + ignore_label = params.ignore_label + # TODO(syf) + assert axis == 1, 'PaddlePaddle can not support the situation when the axis is not 1.' + assert not ignore_label >= 0, 'PaddlePaddle can not support the situation when the model has ignore label.' + attr = {'k': top_k} + node.fluid_code.add_layer("accuracy", inputs=inputs, output=node, param_attr=attr) - - - - - \ No newline at end of file + + def TanH(self, node): + assert len( + node.inputs) == 1, 'The count of TanH node\'s input is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + attr = {'name': string(node.layer_name)} + node.fluid_code.add_layer("tanh", + inputs=input, + output=node, + param_attr=attr) + + def Eltwise(self, node): + assert len( + node.inputs) == 2, 'The count of TanH node\'s input is not 2.' + params = node.layer.eltwise_param + mode = params.operation + inputs = [] + inputs.append(self.graph.get_bottom_node(node, idx=0, copy=True)) + inputs.append(self.graph.get_bottom_node(node, idx=1, copy=True)) + if mode == 0: + attr = {'act': None, 'name': string(node.layer_name)} + node.fluid_code.add_layer("elementwise_mul", + inputs=inputs, + output=node, + param_attr=attr) + elif mode == 1: + if hasattr(params, 'coeff') and len(params.coeff) == 2: + coeff = params.coeff + input1_name = self.get_input_name(inputs[0]) + attr = { + 'shape': [1], + 'value': coeff[0], + 'dtype': '{}.dtype'.format(input1_name) + } + node.fluid_code.add_layer("fill_constant", + inputs=None, + output=node.layer_name + '_const1', + param_attr=attr) + attr = {'act': None, 'name': string(node.layer_name + '_mul1')} + node.fluid_code.add_layer("elementwise_mul", + inputs=input1_name + ', ' + + node.layer_name + '_const1', + output=node.layer_name + '_mul1', + param_attr=attr) + input2_name = self.get_input_name(inputs[1]) + attr = { + 'shape': [1], + 'value': coeff[1], + 'dtype': '{}.dtype'.format(input2_name) + } + node.fluid_code.add_layer("fill_constant", + inputs=None, + output=node.layer_name + '_const2', + param_attr=attr) + attr = {'act': None, 'name': string(node.layer_name + '_mul2')} + node.fluid_code.add_layer("elementwise_mul", + inputs=input2_name + ', ' + + node.layer_name + '_const2', + output=node.layer_name + '_mul2', + param_attr=attr) + + attr = {'act': None, 'name': string(node.layer_name)} + node.fluid_code.add_layer("elementwise_add", + inputs='{}_mul1, {}_mul2'.format( + node.layer_name, node.layer_name), + output=node, + param_attr=attr) + else: + attr = {'act': None, 'name': string(node.layer_name)} + node.fluid_code.add_layer("elementwise_add", + inputs=inputs, + output=node, + param_attr=attr) + else: + attr = {'act': None, 'name': string(node.layer_name)} + node.fluid_code.add_layer("elementwise_max", + inputs=inputs, + output=node, + param_attr=attr) + + def BatchNorm(self, node): + assert len(node.inputs) == 1 and len( + node.outputs + ) == 1, 'The count of BatchNorm node\'s input and output is not 1.' + input = self.graph.get_bottom_node(node, idx=0, copy=True) + params = node.layer.batch_norm_param + if hasattr(params, eps): + eps = params.eps + else: + eps = 1e-5 + assert len(node.data) == 3 + node.data = [np.squeeze(i) for i in node.data] + mean, variance, scale = node.data + # Prescale the stats + scaling_factor = 1.0 / scale if scale != 0 else 0 + mean *= scaling_factor + variance *= scaling_factor + self.weights[node.layer_name + '_mean'] = mean + self.weights[node.layer_name + '_variance'] = variance + if self.graph.get_node(node.outputs[0]).layer_type == 'Scale': + data = self.graph.get_node(node.outputs[0]).data + self.weights[node.layer_name + '_scale'] = np.squeeze(data[0]) + self.weights[node.layer_name + '_offset'] = np.squeeze(data[1]) + attr = { + 'is_test': True, + 'param_attr': string(node.layer_name + '_scale'), + 'bias_attr': string(node.layer_name + '_offset'), + 'moving_mean_name': string(node.layer_name + '_mean'), + 'moving_variance_name': string(node.layer_name + '_variance'), + 'epsilon': eps, + 'name': string(node.layer_name) + } + else: + attr = { + 'is_test': True, + 'param_attr': None, + 'bias_attr': None, + 'moving_mean_name': string(node.layer_name + '_mean'), + 'moving_variance_name': string(node.layer_name + '_variance'), + 'epsilon': eps, + 'name': string(node.layer_name) + } + node.fluid_code.add_layer("batch_norm", + inputs=input, + output=node, + param_attr=attr) + + def Scale(self, node): + assert len( + node.outputs) == 1, 'The count of Scale node\'s output is not 1.' + if len(node.inputs) == 1 and self.graph.get_node( + node.inputs[0]).layer_type == 'BatchNorm': + return + else: + self.weights[node.layer_name + '_scale'] = np.squeeze(nose.data[0]) + self.weights[node.layer_name + '_offset'] = np.squeeze(node.data[1]) + params = node.layer.scale_params + axis = params.axis + num_axes = params.num_axes + assert num_axes == 1, "layer scale not support this num_axes[%d] now" % ( + num_axes) + inputs = [] + if len(node.inputs) == 2: + # for two tensor, here resets axis to 1. Maybe there is a bug for unkown case. + axis = 1 + bias_shape = node.input_shape[0][axis:axis + num_axes] + input0 = self.graph.get_bottom_node(node, idx=0, copy=True) + input1 = self.graph.get_bottom_node(node, idx=1, copy=True) + inputs.append(input0) + inputs.append(input1) + attr = {'axis': axis, 'name': string(node.layer_name + '_mul')} + node.fluid_code.add_layer("elementwise_mul", + inputs=inputs, + output=node.layer_name + '_mul', + param_attr=attr) + else: + bias_shape = node.input_shape[0][axis:axis + num_axes] + input0 = self.graph.get_bottom_node(node, idx=0, copy=True) + input0_name = self.get_input_name(input0) + attr = { + 'dtype': '{}.dtype'.formatr(input0_name), + 'shape': bias_shape, + 'name': string(node.layer_name + '_cparam1'), + 'attr': string(node.layer_name + '_scale'), + 'is_bias': True, + 'default_initializer': 'Constant(value=1.0)' + } + node.fluid_code.add_layer("create_parameter", + inputs=None, + output=node, + param_attr=attr) + inputs.append(input0) + inputs.append(node) + attr = {'axis': axis, 'name': string(node.layer_name + '_mul')} + node.fluid_code.add_layer("elementwise_mul", + inputs=inputs, + output=node.layer_name + '_mul', + param_attr=attr) + scale_shape = bias_shape + input0_name = self.get_input_name(input0) + attr = { + 'dtype': '{}.dtype'.formatr(input0_name), + 'shape': scale_shape, + 'name': string(node.layer_name + '_cparam2'), + 'attr': string(node.layer_name + '_offset'), + 'is_bias': True, + 'default_initializer': 'Constant(value=1.0)' + } + node.fluid_code.add_layer("create_parameter", + inputs=None, + output=node.layer_name + '_offset_param', + param_attr=attr) + attr = {'axis': axis, 'name': string(node.layer_name + '_add')} + node.fluid_code.add_layer("elementwise_add", + inputs='{}_mul, {}_offset_param'.format( + node.layer_name, node.layer_name), + output=node, + param_attr=attr) -- GitLab