提交 bfa2621f 编写于 作者: X Xin Pan

fix bias

test=develop
上级 9a4314f0
......@@ -221,13 +221,10 @@ class FC(layers.Layer):
from ..layer_helper import LayerHelper
self._helper = LayerHelper(
'FC', param_attr=param_attr, act=act, name=name)
self._bias_attr = bias_attr
self._bias_attr = bias_attr if bias_attr else ParamAttr()
def parameters(self):
if self._bias_attr:
return [self._w, self._b]
else:
return [self._w]
def _build_once(self, input):
input_shape = input.shape
......@@ -264,10 +261,11 @@ class FC(layers.Layer):
# add bias
size = list(out.shape[1:])
if not self._built:
self._b = self._layer.create_parameter(
self._b = self._helper.create_parameter(
attr=self._bias_attr, shape=size, dtype=out.dtype, is_bias=True)
bias_out = self.create_variable_for_type_inference(dtype=out.dtype)
self.append_op(
bias_out = self._helper.create_variable_for_type_inference(
dtype=out.dtype)
self._helper.append_op(
type='elementwise_add',
inputs={'X': [out],
'Y': [self._b]},
......
......@@ -405,8 +405,7 @@ class LayerHelper(object):
"""
size = list(input_var.shape[dim_start:dim_end])
bias_attr = self.bias_attr
if not bias_attr:
return input_var
assert bias_attr is not None
b = self.create_parameter(
attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
......
......@@ -121,21 +121,21 @@ class TestImperativeMnist(unittest.TestCase):
img = np.ones([2, 1], np.float32)
noise = np.ones([2, 2], np.float32)
exe.run(startup)
d_loss_val = exe.run(discriminate_p,
static_d_loss = exe.run(discriminate_p,
feed={'img': img,
'noise': noise},
fetch_list=[d_loss])[0]
g_loss_val = exe.run(generate_p,
static_g_loss = exe.run(generate_p,
feed={'noise': noise},
fetch_list=[g_loss])[0]
# generate_p contains all parameters needed.
for param in generate_p.global_block().all_parameters():
static_params[param.name] = np.array(
scope.find_var(param.name).get_tensor())
sys.stderr.write(
'static_param_loss: %s: %s\n' %
(param.name, np.sum(static_params[param.name])))
sys.stderr.write('d_loss %s, g_loss: %s\n' %
(d_loss_val, g_loss_val))
dy_params = dict()
with fluid.imperative.guard():
......@@ -181,8 +181,14 @@ class TestImperativeMnist(unittest.TestCase):
dy_params[p.name] = p._numpy()
sys.stderr.write('dy_param_loss: %s: %s\n' %
(p.name, np.sum(dy_params[p.name])))
sys.stderr.write('dy_d_loss: %s, dy_g_loss: %s\n' %
(d_loss._numpy(), g_loss._numpy()))
dy_g_loss = g_loss._numpy()
dy_d_loss = d_loss._numpy()
self.assertEqual(dy_g_loss, static_g_loss)
self.assertEqual(dy_d_loss, static_d_loss)
for k, v in six.iteritems(dy_params):
self.assertTrue(np.allclose(v, static_params[k]))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册