提交 18ccff05 编写于 作者: S SunAhong1993

modify the slice and shufflechannel

上级 4c4fb484
...@@ -8,13 +8,7 @@ def shufflechannel_shape(input_shape): ...@@ -8,13 +8,7 @@ def shufflechannel_shape(input_shape):
def shufflechannel_layer(inputs, group=None, input_shape=None, name=None): def shufflechannel_layer(inputs, group=None, input_shape=None, name=None):
input = inputs[0] input = inputs[0]
c_fm = fluid.layers.split(input, num_or_sections=input_shape[0][1], dim=1) out = fluid.layers.shuffle_channel(x=input, group=group)
size = int(input_shape[0][1] / group)
new_c_fm = []
for i in range(size):
for j in range(group):
new_c_fm.append(c_fm[j * size + i])
out = fluid.layers.concat(new_c_fm, axis=1)
return out return out
......
...@@ -450,35 +450,19 @@ class CaffeOpMapper(OpMapper): ...@@ -450,35 +450,19 @@ class CaffeOpMapper(OpMapper):
slice_dim = params.slice_dim slice_dim = params.slice_dim
if slice_dim != 1 and axis == 1: if slice_dim != 1 and axis == 1:
axis = slice_dim axis = slice_dim
points = list(params.slice_point) output_shape = node.output_shape
sections_list = []
if len(points) == 0: for s in output_shape:
dims = node.input_shape[0][axis] sections_list.append(s[axis])
assert dims % top_len == 0, "the parameter of Slice is wrong"
part = dims / top_len
t = part
while t < dims:
points.append(int(t))
t += part
maxint32 = 2147483647
points = [0] + points
points.append(maxint32)
i = 0
node.fluid_code.add_note('{} = []'.format(node.layer_name))
for i in range(len(points)):
attr = { attr = {
'axes': [axis], 'num_or_sections': sections_list,
'starts': [points[i]], 'dim': axis,
'ends': [points[i + 1]] 'name': string(node.layer_name)
} }
node.fluid_code.add_layer("slice", node.fluid_code.add_layer("split",
inputs=input, inputs=input,
output=node.layer_name + '_' + str(i), output=node.layer_name,
param_attr=attr) param_attr=attr)
node.fluid_code.add_note('{}.append({})'.format(
node.layer_name, node.layer_name + '_' + str(i)))
if i == len(points) - 2:
break
def Concat(self, node): def Concat(self, node):
assert len( assert len(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册