Commit 9a4314f0 authored by Xin Pan

imperative gan

test=develop
Parent a61e7d0f
@@ -101,7 +101,6 @@ class VarBase {
   // Owns `var` and `grad`
   VarBase(framework::Variable* var, VarBase* grad)
       : pre_op_(nullptr),
-        pre_op_out_name_(),
         pre_op_out_idx_(-1),
         var_desc_(nullptr),
         var_(var),
@@ -110,7 +109,6 @@ class VarBase {
   explicit VarBase(bool stop_gradient)
       : pre_op_(nullptr),
-        pre_op_out_name_(),
         pre_op_out_idx_(-1),
         var_desc_(nullptr),
         var_(new framework::Variable()),
@@ -127,6 +125,13 @@ class VarBase {
     }
   }

+  void Clear() {
+    delete grads_;
+    grads_ = new VarBase(true);
+    pre_op_ = nullptr;
+    pre_op_out_name_ = "";
+  }
+
   void RunBackward();

   framework::LoDTensor& GradValue();
......
@@ -133,6 +133,7 @@ PYBIND11_MODULE(core, m) {
            [](imperative::VarBase &self) { self.RunBackward(); })
       .def("_grad_name", &imperative::VarBase::GradName)
       .def("_grad_value", &imperative::VarBase::GradValue)
+      .def("_clear", &imperative::VarBase::Clear)
       .def("_grad_ivar",
            [](const imperative::VarBase &self) { return self.grads_; },
            py::return_value_policy::reference)
......
@@ -388,6 +388,9 @@ class Variable(object):
     def _gradient(self):
         return np.array(self._ivar._grad_value())

+    def _clear(self):
+        self._ivar._clear()
+
     def __str__(self):
         return self.to_string(True)
......
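The new `_clear()` forwards to `VarBase::Clear()` above: it replaces the stored gradient with a fresh one and detaches the variable from the op that produced it. A minimal sketch of the pattern this enables, condensed from the GAN test below; the `Discriminator`/`Generator` layers and the exact import paths for `to_variable` and `SGDOptimizer` are assumptions taken from that test rather than part of this hunk:

    import numpy as np
    import paddle.fluid as fluid
    from paddle.fluid.imperative.base import to_variable
    from paddle.fluid.optimizer import SGDOptimizer

    with fluid.imperative.guard():
        discriminator = Discriminator()   # imperative Layers, as defined in the test
        generator = Generator()

        # Discriminator step: build a loss, back-propagate, apply one SGD update.
        d_fake = discriminator(
            generator(to_variable(np.ones([2, 2], np.float32))))
        d_loss = fluid.layers.reduce_mean(
            fluid.layers.sigmoid_cross_entropy_with_logits(
                x=d_fake, label=to_variable(np.zeros([2, 1], np.float32))))
        d_loss._backward()
        SGDOptimizer(learning_rate=1e-3).minimize(d_loss)

        # Reset every parameter: _clear() drops the accumulated gradient and the
        # link to the producing op, so the generator step starts from a fresh
        # trace instead of back-propagating through the discriminator update again.
        for p in discriminator.parameters():
            p._clear()
        for p in generator.parameters():
            p._clear()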
@@ -69,8 +69,6 @@ class TestImperativeMnist(unittest.TestCase):
         generate_p.random_seed = seed

         scope = fluid.core.Scope()
-        exe = fluid.Executor(fluid.CPUPlace())
-        sys.stderr.write('1111\n')
         with new_program_scope(
                 main=discriminate_p, startup=startup, scope=scope):
             discriminator = Discriminator()
@@ -117,6 +115,8 @@ class TestImperativeMnist(unittest.TestCase):
             sgd = SGDOptimizer(learning_rate=1e-3)
             sgd.minimize(g_loss)

+        exe = fluid.Executor(fluid.CPUPlace())
+        static_params = dict()
         with fluid.scope_guard(scope):
             img = np.ones([2, 1], np.float32)
             noise = np.ones([2, 2], np.float32)
@@ -128,14 +128,14 @@ class TestImperativeMnist(unittest.TestCase):
             g_loss_val = exe.run(generate_p,
                                  feed={'noise': noise},
                                  fetch_list=[g_loss])[0]
-            sys.stderr.write('d_loss %s, g_loss: %s\n' %
-                             (d_loss_val, g_loss_val))
-            static_params = dict()
-            for param in discriminate_p.global_block().all_parameters():
-                sys.stderr.write('%s\n' % param.name)
+            for param in generate_p.global_block().all_parameters():
                 static_params[param.name] = np.array(
                     scope.find_var(param.name).get_tensor())
+                sys.stderr.write(
+                    'static_param_loss: %s: %s\n' %
+                    (param.name, np.sum(static_params[param.name])))
+            sys.stderr.write('d_loss %s, g_loss: %s\n' %
+                             (d_loss_val, g_loss_val))

         dy_params = dict()
         with fluid.imperative.guard():
@@ -158,15 +158,31 @@ class TestImperativeMnist(unittest.TestCase):
                     x=d_fake, label=to_variable(np.zeros([2, 1], np.float32))))
             d_loss = d_loss_real + d_loss_fake
-            sys.stderr.write('dy_d_loss: %s\n' % d_loss._numpy())
             d_loss._backward()
             sgd.minimize(d_loss)
             for p in discriminator.parameters():
-                dy_params[p.name] = p._numpy()
+                p._clear()
+            for p in generator.parameters():
+                p._clear()

-            for k, v in six.iteritems(dy_params):
-                sys.stderr.write('dy_param_loss: %s: %s\n' % (k, np.sum(v)))
-                sys.stderr.write('static_param_loss: %s: %s\n' % (k, np.sum(v)))
+            d_fake = discriminator(
+                generator(to_variable(np.ones([2, 2], np.float32))))
+            g_loss = fluid.layers.reduce_mean(
+                fluid.layers.sigmoid_cross_entropy_with_logits(
+                    x=d_fake, label=to_variable(np.ones([2, 1], np.float32))))
+            g_loss._backward()
+            sgd = SGDOptimizer(learning_rate=1e-3)
+            sgd.minimize(g_loss)
+            for p in discriminator.parameters():
+                dy_params[p.name] = p._numpy()
+                sys.stderr.write('dy_param_loss: %s: %s\n' %
+                                 (p.name, np.sum(dy_params[p.name])))
+            for p in generator.parameters():
+                dy_params[p.name] = p._numpy()
+                sys.stderr.write('dy_param_loss: %s: %s\n' %
+                                 (p.name, np.sum(dy_params[p.name])))
+            sys.stderr.write('dy_d_loss: %s, dy_g_loss: %s\n' %
+                             (d_loss._numpy(), g_loss._numpy()))

 if __name__ == '__main__':
......
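The visible diff ends before the test's final assertions, so they are not reproduced here. Purely as a hypothetical illustration, not part of this commit, a comparison of the `static_params` and `dy_params` dicts collected above could look like the sketch below (it assumes the parameter names line up between the static and imperative runs):

    # Hypothetical check, not part of the visible diff: the static-graph
    # parameters should match the imperative ones after equivalent updates.
    for name, static_value in static_params.items():
        np.testing.assert_allclose(static_value, dy_params[name], rtol=1e-5)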