提交 ffd438e4 编写于 作者: S SunAhong1993

fix the slice

上级 d0ef318d
...@@ -97,6 +97,11 @@ class OpMapper(object): ...@@ -97,6 +97,11 @@ class OpMapper(object):
import model import model
try: try:
inputs, outputs = model.x2paddle_net() inputs, outputs = model.x2paddle_net()
for i, out in enumerate(outputs):
if isinstance(out, list):
for out_part in out:
outputs.append(out_part)
del outputs[i]
input_names = [input.name for input in inputs] input_names = [input.name for input in inputs]
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
......
...@@ -26,6 +26,11 @@ def string(param): ...@@ -26,6 +26,11 @@ def string(param):
def run_net(param_dir="./"): def run_net(param_dir="./"):
import os import os
inputs, outputs = x2paddle_net() inputs, outputs = x2paddle_net()
for i, out in enumerate(outputs):
if isinstance(out, list):
for out_part in out:
outputs.append(out_part)
del outputs[i]
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
......
...@@ -399,9 +399,22 @@ class CaffeOpMapper(OpMapper): ...@@ -399,9 +399,22 @@ class CaffeOpMapper(OpMapper):
assert len( assert len(
node.inputs) == 1, 'The count of Slice node\'s input is not 1.' node.inputs) == 1, 'The count of Slice node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
top_len = len(node.layer.top)
params = node.layer.slice_param params = node.layer.slice_param
axis = params.axis axis = params.axis
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point) points = list(params.slice_point)
if len(points) == 0:
dims = node.input_shape[0][axis]
assert dims % top_len == 0, "the parameter of Slice is wrong"
part = dims / top_len
t = part
while t < dims:
points.append(int(t))
t += part
maxint32 = 2147483647 maxint32 = 2147483647
points = [0] + points points = [0] + points
points.append(maxint32) points.append(maxint32)
...@@ -421,7 +434,7 @@ class CaffeOpMapper(OpMapper): ...@@ -421,7 +434,7 @@ class CaffeOpMapper(OpMapper):
node.layer_name, node.layer_name + '_' + str(i))) node.layer_name, node.layer_name + '_' + str(i)))
if i == len(points) - 2: if i == len(points) - 2:
break break
def Concat(self, node): def Concat(self, node):
assert len( assert len(
node.inputs node.inputs
...@@ -570,9 +583,7 @@ class CaffeOpMapper(OpMapper): ...@@ -570,9 +583,7 @@ class CaffeOpMapper(OpMapper):
param_attr=attr) param_attr=attr)
def BatchNorm(self, node): def BatchNorm(self, node):
assert len(node.inputs) == 1 and len( assert len(node.inputs) == 1, 'The count of BatchNorm node\'s input is not 1.'
node.outputs
) == 1, 'The count of BatchNorm node\'s input and output is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True) input = self.graph.get_bottom_node(node, idx=0, copy=True)
params = node.layer.batch_norm_param params = node.layer.batch_norm_param
if hasattr(params, 'eps'): if hasattr(params, 'eps'):
......
...@@ -151,10 +151,22 @@ def shape_concat(layer, input_shape): ...@@ -151,10 +151,22 @@ def shape_concat(layer, input_shape):
def shape_slice(layer, input_shape): def shape_slice(layer, input_shape):
inshape = input_shape[0] inshape = input_shape[0]
top_len = len(layer.top)
params = layer.slice_param params = layer.slice_param
axis = params.axis axis = params.axis
count = inshape[axis] slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point) points = list(params.slice_point)
count = inshape[axis]
if len(points) == 0:
assert count % top_len == 0, "the parameter of Slice is wrong"
part = count / top_len
t = part
while t < count:
points.append(int(t))
t += part
points = [0] + points + [count] points = [0] + points + [count]
output_shape = [] output_shape = []
for i in range(len(points)): for i in range(len(points)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册