From ffd438e4ec657f3969d2c86c17fc07b3f4bae068 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Thu, 22 Aug 2019 14:56:07 +0800 Subject: [PATCH] fix the slice --- x2paddle/core/op_mapper.py | 5 +++++ x2paddle/core/util.py | 5 +++++ x2paddle/op_mapper/caffe_op_mapper.py | 19 +++++++++++++++---- x2paddle/op_mapper/caffe_shape.py | 14 +++++++++++++- 4 files changed, 38 insertions(+), 5 deletions(-) diff --git a/x2paddle/core/op_mapper.py b/x2paddle/core/op_mapper.py index dacf4dd..042b279 100644 --- a/x2paddle/core/op_mapper.py +++ b/x2paddle/core/op_mapper.py @@ -97,6 +97,11 @@ class OpMapper(object): import model try: inputs, outputs = model.x2paddle_net() + for i, out in enumerate(outputs): + if isinstance(out, list): + for out_part in out: + outputs.append(out_part) + del outputs[i] input_names = [input.name for input in inputs] exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) diff --git a/x2paddle/core/util.py b/x2paddle/core/util.py index 54524b8..faafe83 100644 --- a/x2paddle/core/util.py +++ b/x2paddle/core/util.py @@ -26,6 +26,11 @@ def string(param): def run_net(param_dir="./"): import os inputs, outputs = x2paddle_net() + for i, out in enumerate(outputs): + if isinstance(out, list): + for out_part in out: + outputs.append(out_part) + del outputs[i] exe = fluid.Executor(fluid.CPUPlace()) exe.run(fluid.default_startup_program()) diff --git a/x2paddle/op_mapper/caffe_op_mapper.py b/x2paddle/op_mapper/caffe_op_mapper.py index 29ef329..18a6d9b 100644 --- a/x2paddle/op_mapper/caffe_op_mapper.py +++ b/x2paddle/op_mapper/caffe_op_mapper.py @@ -399,9 +399,22 @@ class CaffeOpMapper(OpMapper): assert len( node.inputs) == 1, 'The count of Slice node\'s input is not 1.' input = self.graph.get_bottom_node(node, idx=0, copy=True) + top_len = len(node.layer.top) params = node.layer.slice_param axis = params.axis + slice_dim = params.slice_dim + if slice_dim != 1 and axis == 1: + axis = slice_dim points = list(params.slice_point) + + if len(points) == 0: + dims = node.input_shape[0][axis] + assert dims % top_len == 0, "the parameter of Slice is wrong" + part = dims / top_len + t = part + while t < dims: + points.append(int(t)) + t += part maxint32 = 2147483647 points = [0] + points points.append(maxint32) @@ -421,7 +434,7 @@ class CaffeOpMapper(OpMapper): node.layer_name, node.layer_name + '_' + str(i))) if i == len(points) - 2: break - + def Concat(self, node): assert len( node.inputs @@ -570,9 +583,7 @@ class CaffeOpMapper(OpMapper): param_attr=attr) def BatchNorm(self, node): - assert len(node.inputs) == 1 and len( - node.outputs - ) == 1, 'The count of BatchNorm node\'s input and output is not 1.' + assert len(node.inputs) == 1, 'The count of BatchNorm node\'s input is not 1.' input = self.graph.get_bottom_node(node, idx=0, copy=True) params = node.layer.batch_norm_param if hasattr(params, 'eps'): diff --git a/x2paddle/op_mapper/caffe_shape.py b/x2paddle/op_mapper/caffe_shape.py index 6a26dfe..f8f3023 100644 --- a/x2paddle/op_mapper/caffe_shape.py +++ b/x2paddle/op_mapper/caffe_shape.py @@ -151,10 +151,22 @@ def shape_concat(layer, input_shape): def shape_slice(layer, input_shape): inshape = input_shape[0] + + top_len = len(layer.top) params = layer.slice_param axis = params.axis - count = inshape[axis] + slice_dim = params.slice_dim + if slice_dim != 1 and axis == 1: + axis = slice_dim points = list(params.slice_point) + count = inshape[axis] + if len(points) == 0: + assert count % top_len == 0, "the parameter of Slice is wrong" + part = count / top_len + t = part + while t < count: + points.append(int(t)) + t += part points = [0] + points + [count] output_shape = [] for i in range(len(points)): -- GitLab