提交 a61e7d0f 编写于 作者: X Xin Pan

dy gan mostly working

test=develop
上级 03fe3109
......@@ -27,18 +27,21 @@ class Layer(core.Layer):
"""Layers composed of operators."""
def __init__(self, dtype=core.VarDesc.VarType.FP32, name=None):
self._once_built = False
self._built = False
self._dtype = dtype
def parameters(self):
return []
def _build_once(self, inputs):
pass
def __call__(self, *inputs):
if not self._once_built:
if not self._built:
self._build_once(*inputs)
self._once_built = True
outputs = self.forward(*inputs)
self._built = True
return outputs
def forward(self, *inputs):
......
......@@ -220,11 +220,14 @@ class FC(layers.Layer):
self._dtype = dtype
from ..layer_helper import LayerHelper
self._helper = LayerHelper(
'FC',
param_attr=param_attr,
bias_attr=bias_attr,
act=act,
name=name)
'FC', param_attr=param_attr, act=act, name=name)
self._bias_attr = bias_attr
def parameters(self):
if self._bias_attr:
return [self._w, self._b]
else:
return [self._w]
def _build_once(self, input):
input_shape = input.shape
......@@ -255,8 +258,20 @@ class FC(layers.Layer):
inputs={"X": [tmp]},
outputs={"Out": out},
attrs={"use_mkldnn": False})
if not self._bias_attr:
return out
# add bias
pre_activation = self._helper.append_bias_op(
out, dim_start=self._num_flatten_dims)
size = list(out.shape[1:])
if not self._built:
self._b = self._layer.create_parameter(
attr=self._bias_attr, shape=size, dtype=out.dtype, is_bias=True)
bias_out = self.create_variable_for_type_inference(dtype=out.dtype)
self.append_op(
type='elementwise_add',
inputs={'X': [out],
'Y': [self._b]},
outputs={'Out': [bias_out]},
attrs={'axis': 1})
# add activation
return self._helper.append_activation(pre_activation)
return self._helper.append_activation(bias_out)
......@@ -23,6 +23,7 @@ import paddle.fluid as fluid
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.imperative.nn import Conv2D, Pool2D, FC
from test_imperative_base import new_program_scope
from paddle.fluid.imperative.base import to_variable
class Discriminator(fluid.imperative.Layer):
......@@ -31,6 +32,9 @@ class Discriminator(fluid.imperative.Layer):
self._fc1 = FC(size=32, act='elu', name="d_fc1")
self._fc2 = FC(size=1, name="d_fc2")
def parameters(self):
return self._fc1.parameters() + self._fc2.parameters()
def forward(self, inputs):
x = self._fc1(inputs)
return self._fc2(x)
......@@ -43,6 +47,10 @@ class Generator(fluid.imperative.Layer):
self._fc2 = FC(size=64, act='elu', name="g_fc2")
self._fc3 = FC(size=1, name="g_fc3")
def parameters(self):
return self._fc1.parameters() + self._fc2.parameters(
) + self._fc3.parameters()
def forward(self, inputs):
x = self._fc1(inputs)
x = self._fc2(x)
......@@ -56,12 +64,15 @@ class TestImperativeMnist(unittest.TestCase):
startup = fluid.Program()
startup.random_seed = seed
discriminate_p = fluid.Program()
generate_p = fluid.Program()
discriminate_p.random_seed = seed
generate_p.random_seed = seed
scope = fluid.core.Scope()
exe = fluid.Executor(fluid.CPUPlace())
sys.stderr.write('1111\n')
with new_program_scope(
main=discriminate_p, startup=startup, scope=scope):
fluid.default_main_program().random_seed = seed
discriminator = Discriminator()
generator = Generator()
......@@ -70,64 +81,92 @@ class TestImperativeMnist(unittest.TestCase):
noise = fluid.layers.data(
name="noise", shape=[2, 2], append_batch_size=False)
label = fluid.layers.data(
name='label',
shape=[2, 1],
dtype='float32',
append_batch_size=False)
d_real = discriminator(img)
d_loss_real = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_real, label=label))
x=d_real,
label=fluid.layers.fill_constant(
shape=[2, 1], dtype='float32', value=1.0)))
d_fake = discriminator(generator(noise))
d_loss_fake = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_fake, label=label))
x=d_fake,
label=fluid.layers.fill_constant(
shape=[2, 1], dtype='float32', value=0.0)))
d_loss = d_loss_real + d_loss_fake
sgd = SGDOptimizer(learning_rate=1e-3)
sgd.minimize(d_loss)
generate_p = fluid.Program()
with new_program_scope(main=generate_p, startup=startup, scope=scope):
fluid.default_main_program().random_seed = seed
discriminator = Discriminator()
generator = Generator()
noise = fluid.layers.data(
name="noise", shape=[2, 2], append_batch_size=False)
label = fluid.layers.data(
name='label',
shape=[2, 1],
dtype='float32',
append_batch_size=False)
d_fake = discriminator(generator(noise))
g_loss = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_fake, label=label))
x=d_fake,
label=fluid.layers.fill_constant(
shape=[2, 1], dtype='float32', value=1.0)))
sgd = SGDOptimizer(learning_rate=1e-3)
sgd.minimize(g_loss)
img = np.ones([2, 1], np.float32)
label = np.ones([2, 1], np.float32)
noise = np.ones([2, 2], np.float32)
exe.run(startup)
d_loss_val = exe.run(discriminate_p,
feed={'img': img,
'noise': noise,
'label': label},
fetch_list=[d_loss])[0]
g_loss_val = exe.run(generate_p,
feed={'noise': noise,
'label': label},
fetch_list=[g_loss])[0]
sys.stderr.write('d_loss %s, g_loss: %s\n' % (d_loss_val, g_loss_val))
with fluid.scope_guard(scope):
img = np.ones([2, 1], np.float32)
noise = np.ones([2, 2], np.float32)
exe.run(startup)
d_loss_val = exe.run(discriminate_p,
feed={'img': img,
'noise': noise},
fetch_list=[d_loss])[0]
g_loss_val = exe.run(generate_p,
feed={'noise': noise},
fetch_list=[g_loss])[0]
sys.stderr.write('d_loss %s, g_loss: %s\n' %
(d_loss_val, g_loss_val))
static_params = dict()
for param in discriminate_p.global_block().all_parameters():
sys.stderr.write('%s\n' % param.name)
static_params[param.name] = np.array(
scope.find_var(param.name).get_tensor())
dy_params = dict()
with fluid.imperative.guard():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
discriminator = Discriminator()
generator = Generator()
sgd = SGDOptimizer(learning_rate=1e-3)
d_real = discriminator(to_variable(np.ones([2, 1], np.float32)))
d_loss_real = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_real, label=to_variable(np.ones([2, 1], np.float32))))
d_fake = discriminator(
generator(to_variable(np.ones([2, 2], np.float32))))
d_loss_fake = fluid.layers.reduce_mean(
fluid.layers.sigmoid_cross_entropy_with_logits(
x=d_fake, label=to_variable(np.zeros([2, 1], np.float32))))
d_loss = d_loss_real + d_loss_fake
sys.stderr.write('dy_d_loss: %s\n' % d_loss._numpy())
d_loss._backward()
sgd.minimize(d_loss)
for p in discriminator.parameters():
dy_params[p.name] = p._numpy()
for k, v in six.iteritems(dy_params):
sys.stderr.write('dy_param_loss: %s: %s\n' % (k, np.sum(v)))
sys.stderr.write('static_param_loss: %s: %s\n' % (k, np.sum(v)))
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册