未验证 提交 405a2f18 编写于 作者: J Jason 提交者: GitHub

Merge pull request #92 from SunAhong1993/develop

fix the slice
......@@ -97,6 +97,11 @@ class OpMapper(object):
import model
try:
inputs, outputs = model.x2paddle_net()
for i, out in enumerate(outputs):
if isinstance(out, list):
for out_part in out:
outputs.append(out_part)
del outputs[i]
input_names = [input.name for input in inputs]
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
......
......@@ -26,6 +26,11 @@ def string(param):
def run_net(param_dir="./"):
import os
inputs, outputs = x2paddle_net()
for i, out in enumerate(outputs):
if isinstance(out, list):
for out_part in out:
outputs.append(out_part)
del outputs[i]
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
......
......@@ -399,9 +399,22 @@ class CaffeOpMapper(OpMapper):
assert len(
node.inputs) == 1, 'The count of Slice node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
top_len = len(node.layer.top)
params = node.layer.slice_param
axis = params.axis
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point)
if len(points) == 0:
dims = node.input_shape[0][axis]
assert dims % top_len == 0, "the parameter of Slice is wrong"
part = dims / top_len
t = part
while t < dims:
points.append(int(t))
t += part
maxint32 = 2147483647
points = [0] + points
points.append(maxint32)
......@@ -570,9 +583,7 @@ class CaffeOpMapper(OpMapper):
param_attr=attr)
def BatchNorm(self, node):
assert len(node.inputs) == 1 and len(
node.outputs
) == 1, 'The count of BatchNorm node\'s input and output is not 1.'
assert len(node.inputs) == 1, 'The count of BatchNorm node\'s input is not 1.'
input = self.graph.get_bottom_node(node, idx=0, copy=True)
params = node.layer.batch_norm_param
if hasattr(params, 'eps'):
......
......@@ -151,10 +151,22 @@ def shape_concat(layer, input_shape):
def shape_slice(layer, input_shape):
inshape = input_shape[0]
top_len = len(layer.top)
params = layer.slice_param
axis = params.axis
count = inshape[axis]
slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1:
axis = slice_dim
points = list(params.slice_point)
count = inshape[axis]
if len(points) == 0:
assert count % top_len == 0, "the parameter of Slice is wrong"
part = count / top_len
t = part
while t < count:
points.append(int(t))
t += part
points = [0] + points + [count]
output_shape = []
for i in range(len(points)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册