未验证 提交 21e8d135 编写于 作者: J Jason 提交者: GitHub

Merge pull request #8 from Renwb1991/repair_scale

caffe2fluid: repair scale
...@@ -525,16 +525,30 @@ class Network(object): ...@@ -525,16 +525,30 @@ class Network(object):
num_axes) num_axes)
prefix = name + '_' prefix = name + '_'
scale_shape = input.shape[axis:axis + num_axes] if isinstance(input, list) and len(input) == 2:
# for two tensor, here resets axis to 1. Maybe there is a bug for unkown case.
axis = 1
output_shape = input[0].shape[axis:axis + num_axes]
scale_param = input[1]
input = input[0]
else:
output_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'scale') param_attr = fluid.ParamAttr(name=prefix + 'scale')
scale_param = fluid.layers.create_parameter( scale_param = fluid.layers.create_parameter(
shape=scale_shape, shape=output_shape,
dtype=input.dtype, dtype=input.dtype,
name=name, name=name,
attr=param_attr, attr=param_attr,
is_bias=True, is_bias=True,
default_initializer=fluid.initializer.Constant(value=1.0)) default_initializer=fluid.initializer.Constant(value=1.0))
output = fluid.layers.elementwise_mul(
input,
scale_param,
axis=axis,
name=self.get_unique_output_name(name, 'scale_mul'))
scale_shape = output_shape
offset_attr = fluid.ParamAttr(name=prefix + 'offset') offset_attr = fluid.ParamAttr(name=prefix + 'offset')
offset_param = fluid.layers.create_parameter( offset_param = fluid.layers.create_parameter(
shape=scale_shape, shape=scale_shape,
...@@ -544,11 +558,6 @@ class Network(object): ...@@ -544,11 +558,6 @@ class Network(object):
is_bias=True, is_bias=True,
default_initializer=fluid.initializer.Constant(value=0.0)) default_initializer=fluid.initializer.Constant(value=0.0))
output = fluid.layers.elementwise_mul(
input,
scale_param,
axis=axis,
name=self.get_unique_output_name(name, 'scale_mul'))
output = fluid.layers.elementwise_add( output = fluid.layers.elementwise_add(
output, output,
offset_param, offset_param,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册