提交 9a4f5786 编写于 作者: R Renwb1991 提交者: qingqing01

caffe2fluid:fix bug in scale (#1668)

上级 0e1ef228
...@@ -440,7 +440,8 @@ class Network(object): ...@@ -440,7 +440,8 @@ class Network(object):
if need_transpose: if need_transpose:
order = range(dims) order = range(dims)
order.remove(axis).append(axis) order.remove(axis)
order.append(axis)
input = fluid.layers.transpose( input = fluid.layers.transpose(
input, input,
perm=order, perm=order,
...@@ -525,11 +526,21 @@ class Network(object): ...@@ -525,11 +526,21 @@ class Network(object):
scale_shape = input.shape[axis:axis + num_axes] scale_shape = input.shape[axis:axis + num_axes]
param_attr = fluid.ParamAttr(name=prefix + 'scale') param_attr = fluid.ParamAttr(name=prefix + 'scale')
scale_param = fluid.layers.create_parameter( scale_param = fluid.layers.create_parameter(
shape=scale_shape, dtype=input.dtype, name=name, attr=param_attr) shape=scale_shape,
dtype=input.dtype,
name=name,
attr=param_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=1.0))
offset_attr = fluid.ParamAttr(name=prefix + 'offset') offset_attr = fluid.ParamAttr(name=prefix + 'offset')
offset_param = fluid.layers.create_parameter( offset_param = fluid.layers.create_parameter(
shape=scale_shape, dtype=input.dtype, name=name, attr=offset_attr) shape=scale_shape,
dtype=input.dtype,
name=name,
attr=offset_attr,
is_bias=True,
default_initializer=fluid.initializer.Constant(value=0.0))
output = fluid.layers.elementwise_mul( output = fluid.layers.elementwise_mul(
input, input,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册