Commit c86b3dd6 authored by minqiyang

Polish code

test=develop
Parent ddfb9f11
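
This commit reads as two coordinated changes: in the C++ tracer, the grads_ member of VarBase is now itself a VarBase, so every call site dereferences the wrapped variable through grads_->var_; on the Python side, PyLayer.forward, PyLayer.backward, and PyLayer.__call__ switch from taking a single list to taking varargs. As a quick illustration of the user-facing part, here is a minimal sketch mirroring the updated unit tests; the Identity class name is hypothetical, and the fluid.imperative calls are assumed from the hunks below:

    import numpy as np
    import paddle.fluid as fluid

    class Identity(fluid.imperative.PyLayer):
        # Hypothetical subclass; mirrors PyLayer1/PyLayer2 in the tests.
        def __init__(self):
            super(Identity, self).__init__()

        @staticmethod
        def forward(input):   # now receives unpacked inputs, not a list
            return input

        @staticmethod
        def backward(input):  # likewise receives unpacked grad outputs
            return input

    with fluid.imperative.guard():
        layer = Identity()
        x = fluid.imperative.base.to_variable(np.ones([2, 2]))
        outs = layer(x)  # previously: layer([x])

After the call, outs[0]._backward() and x._gradient() work as before, as test_pylayer at the end of this diff exercises.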
@@ -133,11 +133,11 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
         grad_in_vars.push_back(fwd_var_it->second->var_);
       } else {
         VarBase* var = vars[var_it->second];
-        if (!var->grads_->IsInitialized()) {
-          InitVar(var->var_, var->grads_);
+        if (!var->grads_->var_->IsInitialized()) {
+          InitVar(var->var_, var->grads_->var_);
         }
         // Douts.
-        grad_in_vars.push_back(var->grads_);
+        grad_in_vars.push_back(var->grads_->var_);
       }
     }
   }
@@ -149,10 +149,10 @@ void Tracer::Trace(OpBase* op, const VarBasePtrMap& inputs,
       auto var_it = grad_to_var->find(grad_outvar);
       PADDLE_ENFORCE(var_it != grad_to_var->end());
       VarBase* var = vars[var_it->second];
-      if (!var->grads_->IsInitialized()) {
-        InitVar(var->var_, var->grads_);
+      if (!var->grads_->var_->IsInitialized()) {
+        InitVar(var->var_, var->grads_->var_);
       }
-      grad_out_vars.push_back(var->grads_);
+      grad_out_vars.push_back(var->grads_->var_);
     }
   }
 }
@@ -194,13 +194,13 @@ std::vector<VarBase*> Tracer::PyTrace(OpBase* op,
     grad_input_vars.push_back(out->var_);
   }
   for (VarBase* out : outputs) {
-    grad_input_vars.push_back(out->grads_);
+    grad_input_vars.push_back(out->grads_->var_);
     if (!grad_input_vars.back()->IsInitialized()) {
       InitVar(out->var_, grad_input_vars.back());
     }
   }
   for (const VarBase* inp : inputs) {
-    grad_output_vars.push_back(inp->grads_);
+    grad_output_vars.push_back(inp->grads_->var_);
     if (!grad_output_vars.back()->IsInitialized()) {
       InitVar(inp->var_, grad_output_vars.back());
     }
......
@@ -55,18 +55,18 @@ class PyLayer(core.PyLayer):
         super(PyLayer, self).__init__()

     @staticmethod
-    def forward(inputs):
+    def forward(*inputs):
         raise NotImplementedError

     @staticmethod
-    def backward(douts):
+    def backward(*douts):
         raise NotImplementedError

     @classmethod
-    def __call__(cls, inputs):
+    def __call__(cls, *inputs):
         tracer = framework._imperative_tracer()
         block = framework.default_main_program().current_block()
-        inputs = [x._ivar for x in inputs]
+        ivar_inputs = [x._ivar for x in inputs]

         if not hasattr(cls, 'forward_id'):
             cls.forward_id = core.PyLayer.num_funcs() + 1
@@ -78,11 +78,11 @@ class PyLayer(core.PyLayer):
             iop.forward_id = cls.forward_id
             iop.backward_id = cls.backward_id
             block.ops.append(iop)
-            ivars = tracer.py_trace(iop, inputs, False)
+            ivars = tracer.py_trace(iop, ivar_inputs, False)
             # ivars = core.PyLayer.apply(cls.forward, inputs)
             ret = []
             for ivar in ivars:
-                tensor = ivar.value.get_tensor()
+                tensor = ivar.value().get_tensor()
                 py_var = framework.Variable(
                     block,
                     type=core.VarDesc.VarType.LOD_TENSOR,
......
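
Two smaller fixes ride along in the hunks above: the traced ivars are stored in ivar_inputs so that __call__ no longer rebinds its inputs argument, and ivar.value is now invoked as a method before get_tensor(). From user code the same tensor is normally reached through _numpy(); a minimal sketch, assuming to_variable and _numpy behave as in the tests below:

    import numpy as np
    import paddle.fluid as fluid

    with fluid.imperative.guard():
        x = fluid.imperative.base.to_variable(np.ones([2, 2]))
        # _numpy() reads the underlying LoDTensor, the same object that
        # __call__ fetches via ivar.value().get_tensor() above.
        print(np.sum(x._numpy()))  # expected: 4.0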
@@ -97,35 +97,35 @@ class TestImperative(unittest.TestCase):
                 super(PyLayer1, self).__init__()

             @staticmethod
-            def forward(inputs):
-                return inputs
+            def forward(input):
+                return input

             @staticmethod
-            def backward(inputs):
-                return inputs
+            def backward(input):
+                return input

         class PyLayer2(fluid.imperative.PyLayer):
             def __init__(self):
                 super(PyLayer2, self).__init__()

             @staticmethod
-            def forward(inputs):
-                return inputs
+            def forward(input):
+                return input

             @staticmethod
-            def backward(inputs):
-                return inputs
+            def backward(input):
+                return input

         py_layer_1 = PyLayer1()
         py_layer_2 = PyLayer2()
-        py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))])
-        py_layer_2([fluid.imperative.base.to_variable(np.ones([2, 2]))])
+        py_layer_1(fluid.imperative.base.to_variable(np.ones([2, 2])))
+        py_layer_2(fluid.imperative.base.to_variable(np.ones([2, 2])))
         id = py_layer_1.forward_id
         self.assertGreater(id, 0)
         self.assertEqual(py_layer_1.backward_id, id + 1)
         self.assertEqual(py_layer_2.forward_id, id + 2)
         self.assertEqual(py_layer_2.backward_id, id + 3)
-        py_layer_1([fluid.imperative.base.to_variable(np.ones([2, 2]))])
+        py_layer_1(fluid.imperative.base.to_variable(np.ones([2, 2])))
         self.assertEqual(py_layer_1.forward_id, id)

     def test_pylayer(self):
@@ -133,7 +133,7 @@ class TestImperative(unittest.TestCase):
         with fluid.imperative.guard():
             my_py_layer = MyPyLayer()
             var_inp = fluid.imperative.base.to_variable(np_inp)
-            outs = my_py_layer([var_inp])
+            outs = my_py_layer(var_inp)
             dy_out = np.sum(outs[0]._numpy())
             outs[0]._backward()
             dy_grad = var_inp._gradient()
......