Commit c86b3dd6 authored by minqiyang

Polish code

test=develop
Parent ddfb9f11
@@ -133,11 +133,11 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
           grad_in_vars.push_back(fwd_var_it->second->var_);
         } else {
           VarBase* var = vars[var_it->second];
-          if (!var->grads_->IsInitialized()) {
-            InitVar(var->var_, var->grads_);
+          if (!var->grads_->var_->IsInitialized()) {
+            InitVar(var->var_, var->grads_->var_);
           }
           // Douts.
-          grad_in_vars.push_back(var->grads_);
+          grad_in_vars.push_back(var->grads_->var_);
         }
       }
     }
@@ -149,10 +149,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         auto var_it = grad_to_var->find(grad_outvar);
         PADDLE_ENFORCE(var_it != grad_to_var->end());
         VarBase* var = vars[var_it->second];
-        if (!var->grads_->IsInitialized()) {
-          InitVar(var->var_, var->grads_);
+        if (!var->grads_->var_->IsInitialized()) {
+          InitVar(var->var_, var->grads_->var_);
         }
-        grad_out_vars.push_back(var->grads_);
+        grad_out_vars.push_back(var->grads_->var_);
       }
     }
   }
@@ -194,13 +194,13 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
       grad_input_vars.push_back(out->var_);
     }
     for (VarBase* out : outputs) {
-      grad_input_vars.push_back(out->grads_);
+      grad_input_vars.push_back(out->grads_->var_);
       if (!grad_input_vars.back()->IsInitialized()) {
        InitVar(out->var_, grad_input_vars.back());
       }
     }
     for (const VarBase* inp : inputs) {
-      grad_output_vars.push_back(inp->grads_);
+      grad_output_vars.push_back(inp->grads_->var_);
       if (!grad_output_vars.back()->IsInitialized()) {
        InitVar(inp->var_, grad_output_vars.back());
       }
...
@@ -55,18 +55,18 @@ class PyLayer(core.PyLayer):
         super(PyLayer, self).__init__()
 
     @staticmethod
-    def forward(inputs):
+    def forward(*inputs):
         raise NotImplementedError
 
     @staticmethod
-    def backward(douts):
+    def backward(*douts):
         raise NotImplementedError
 
     @classmethod
-    def __call__(cls, inputs):
+    def __call__(cls, *inputs):
         tracer = framework._imperative_tracer()
         block = framework.default_main_program().current_block()
-        inputs = [x._ivar for x in inputs]
+        ivar_inputs = [x._ivar for x in inputs]
 
         if not hasattr(cls, 'forward_id'):
             cls.forward_id = core.PyLayer.num_funcs() + 1
@@ -78,11 +78,11 @@ class PyLayer(core.PyLayer):
         iop.forward_id = cls.forward_id
         iop.backward_id = cls.backward_id
         block.ops.append(iop)
-        ivars = tracer.py_trace(iop, inputs, False)
+        ivars = tracer.py_trace(iop, ivar_inputs, False)
         # ivars = core.PyLayer.apply(cls.forward, inputs)
         ret = []
         for ivar in ivars:
-            tensor = ivar.value.get_tensor()
+            tensor = ivar.value().get_tensor()
             py_var = framework.Variable(
                 block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
...
@@ -97,35 +97,35 @@ class TestImperative(unittest.TestCase):
                 super(PyLayer1, self).__init__()
 
             @staticmethod
-            def forward(inputs):
-                return inputs
+            def forward(input):
+                return input
 
             @staticmethod
-            def backward(inputs):
-                return inputs
+            def backward(input):
+                return input
 
         class PyLayer2(fluid.imperative.PyLayer):
             def __init__(self):
                 super(PyLayer2, self).__init__()
 
             @staticmethod
-            def forward(inputs):
-                return inputs
+            def forward(input):
+                return input
 
             @staticmethod
-            def backward(inputs):
-                return inputs
+            def backward(input):
+                return input
 
         py_layer_1 = PyLayer1()
         py_layer_2 = PyLayer2()
-        py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))])
-        py_layer_2([fluid.imperative.base.to_variable(np.ones([2, 2]))])
+        py_layer_1(fluid.imperative.base.to_variable(np.ones([2, 2])))
+        py_layer_2(fluid.imperative.base.to_variable(np.ones([2, 2])))
         id = py_layer_1.forward_id
         self.assertGreater(id, 0)
         self.assertEqual(py_layer_1.backward_id, id + 1)
         self.assertEqual(py_layer_2.forward_id, id + 2)
         self.assertEqual(py_layer_2.backward_id, id + 3)
-        py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))])
+        py_layer_1(fluid.imperative.base.to_variable(np.ones([2, 2])))
         self.assertEqual(py_layer_1.forward_id, id)
 
     def test_pylayer(self):
@@ -133,7 +133,7 @@ class TestImperative(unittest.TestCase):
         with fluid.imperative.guard():
             my_py_layer = MyPyLayer()
             var_inp = fluid.imperative.base.to_variable(np_inp)
-            outs = my_py_layer([var_inp])
+            outs = my_py_layer(var_inp)
             dy_out = np.sum(outs[0]._numpy())
             outs[0]._backward()
             dy_grad = var_inp._gradient()
...
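Taken together, the Python-side changes make PyLayer take plain positional arguments: forward and backward are declared with *args, and an instance is called with variables directly rather than with a single list. Below is a minimal usage sketch that mirrors the updated unit test above; the subclass name IdentityLayer and its pass-through behavior are illustrative only and not part of this commit.

import numpy as np
import paddle.fluid as fluid

# Illustrative sketch only, modeled on the updated test above.
class IdentityLayer(fluid.imperative.PyLayer):
    def __init__(self):
        super(IdentityLayer, self).__init__()

    @staticmethod
    def forward(inp):
        # Receives its input as a positional argument.
        return inp

    @staticmethod
    def backward(dout):
        return dout


with fluid.imperative.guard():
    x = fluid.imperative.base.to_variable(np.ones([2, 2]))
    # After this change the layer is called with the variable itself,
    # no longer wrapped in a list.
    outs = IdentityLayer()(x)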