提交 bfa2621f 编写于 作者: X Xin Pan

fix bias

test=develop
上级 9a4314f0
...@@ -221,13 +221,10 @@ class FC(layers.Layer): ...@@ -221,13 +221,10 @@ class FC(layers.Layer):
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
self._helper = LayerHelper( self._helper = LayerHelper(
'FC', param_attr=param_attr, act=act, name=name) 'FC', param_attr=param_attr, act=act, name=name)
self._bias_attr = bias_attr self._bias_attr = bias_attr if bias_attr else ParamAttr()
def parameters(self): def parameters(self):
if self._bias_attr:
return [self._w, self._b] return [self._w, self._b]
else:
return [self._w]
def _build_once(self, input): def _build_once(self, input):
input_shape = input.shape input_shape = input.shape
...@@ -264,10 +261,11 @@ class FC(layers.Layer): ...@@ -264,10 +261,11 @@ class FC(layers.Layer):
# add bias # add bias
size = list(out.shape[1:]) size = list(out.shape[1:])
if not self._built: if not self._built:
self._b = self._layer.create_parameter( self._b = self._helper.create_parameter(
attr=self._bias_attr, shape=size, dtype=out.dtype, is_bias=True) attr=self._bias_attr, shape=size, dtype=out.dtype, is_bias=True)
bias_out = self.create_variable_for_type_inference(dtype=out.dtype) bias_out = self._helper.create_variable_for_type_inference(
self.append_op( dtype=out.dtype)
self._helper.append_op(
type='elementwise_add', type='elementwise_add',
inputs={'X': [out], inputs={'X': [out],
'Y': [self._b]}, 'Y': [self._b]},
......
...@@ -405,8 +405,7 @@ class LayerHelper(object): ...@@ -405,8 +405,7 @@ class LayerHelper(object):
""" """
size = list(input_var.shape[dim_start:dim_end]) size = list(input_var.shape[dim_start:dim_end])
bias_attr = self.bias_attr bias_attr = self.bias_attr
if not bias_attr: assert bias_attr is not None
return input_var
b = self.create_parameter( b = self.create_parameter(
attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True) attr=bias_attr, shape=size, dtype=input_var.dtype, is_bias=True)
......
...@@ -121,21 +121,21 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -121,21 +121,21 @@ class TestImperativeMnist(unittest.TestCase):
img = np.ones([2, 1], np.float32) img = np.ones([2, 1], np.float32)
noise = np.ones([2, 2], np.float32) noise = np.ones([2, 2], np.float32)
exe.run(startup) exe.run(startup)
d_loss_val = exe.run(discriminate_p, static_d_loss = exe.run(discriminate_p,
feed={'img': img, feed={'img': img,
'noise': noise}, 'noise': noise},
fetch_list=[d_loss])[0] fetch_list=[d_loss])[0]
g_loss_val = exe.run(generate_p, static_g_loss = exe.run(generate_p,
feed={'noise': noise}, feed={'noise': noise},
fetch_list=[g_loss])[0] fetch_list=[g_loss])[0]
# generate_p contains all parameters needed.
for param in generate_p.global_block().all_parameters(): for param in generate_p.global_block().all_parameters():
static_params[param.name] = np.array( static_params[param.name] = np.array(
scope.find_var(param.name).get_tensor()) scope.find_var(param.name).get_tensor())
sys.stderr.write( sys.stderr.write(
'static_param_loss: %s: %s\n' % 'static_param_loss: %s: %s\n' %
(param.name, np.sum(static_params[param.name]))) (param.name, np.sum(static_params[param.name])))
sys.stderr.write('d_loss %s, g_loss: %s\n' %
(d_loss_val, g_loss_val))
dy_params = dict() dy_params = dict()
with fluid.imperative.guard(): with fluid.imperative.guard():
...@@ -181,8 +181,14 @@ class TestImperativeMnist(unittest.TestCase): ...@@ -181,8 +181,14 @@ class TestImperativeMnist(unittest.TestCase):
dy_params[p.name] = p._numpy() dy_params[p.name] = p._numpy()
sys.stderr.write('dy_param_loss: %s: %s\n' % sys.stderr.write('dy_param_loss: %s: %s\n' %
(p.name, np.sum(dy_params[p.name]))) (p.name, np.sum(dy_params[p.name])))
sys.stderr.write('dy_d_loss: %s, dy_g_loss: %s\n' %
(d_loss._numpy(), g_loss._numpy())) dy_g_loss = g_loss._numpy()
dy_d_loss = d_loss._numpy()
self.assertEqual(dy_g_loss, static_g_loss)
self.assertEqual(dy_d_loss, static_d_loss)
for k, v in six.iteritems(dy_params):
self.assertTrue(np.allclose(v, static_params[k]))
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册