Commit 11d4d39c authored by Xin Pan

forward working

test=develop
Parent b6291333
@@ -161,13 +161,14 @@ class Layer {
static void CallPythonFunc(py::object* callable,
const std::vector<framework::LoDTensor>& ins,
- std::vector<framework::LoDTensor*>* outs) {
+ std::vector<VarBase*>* outs) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
}
+ // TODO(panyx0718): Who owns the returned LoDTensor.
auto ret = (*callable)(in_args);
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
@@ -176,7 +177,11 @@ static void CallPythonFunc(py::object* callable,
auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
"Output tensor %d should not be nullptr", i);
- outs->push_back(py_out_tensor);
+ VarBase* var = new VarBase();
+ auto* tensor = var->var_->GetMutable<framework::LoDTensor>();
+ tensor->ShareDataWith(*py_out_tensor);
+ tensor->set_lod(py_out_tensor->lod());
+ outs->push_back(var);
} catch (py::cast_error&) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}
@@ -187,18 +192,16 @@ class PyLayer {
public:
virtual ~PyLayer() {}
- static std::vector<VarBase> Apply(py::object* callable,
+ static std::vector<VarBase*> Apply(py::object* callable,
const std::vector<VarBase>& inputs) {
- std::vector<VarBase> outputs;
std::vector<framework::LoDTensor> tensor_inputs;
- std::vector<framework::LoDTensor*> tensor_outputs;
+ std::vector<VarBase*> ret;
for (const VarBase& in : inputs) {
tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
}
- CallPythonFunc(callable, tensor_inputs, &tensor_outputs);
- return outputs;
+ CallPythonFunc(callable, tensor_inputs, &ret);
+ return ret;
}
};
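Note on the contract the new CallPythonFunc imposes: each input is handed to the Python callable as a LoDTensor (or None when uninitialized, see the IsInitialized() branch above), and the callable must return a tuple of LoDTensors, each of which is then wrapped in a fresh VarBase that shares its data and LoD. A minimal sketch of a conforming callable; the name "forward" and the tanh body are illustrative, not part of this commit:

    import numpy as np
    import paddle.fluid.core as core

    def forward(inputs):
        # `inputs` is the tuple built inside CallPythonFunc:
        # LoDTensors, or None for uninitialized tensors.
        x = np.array(inputs[0])        # read the first input into numpy
        y = np.tanh(x)                 # any element-wise computation
        out = core.LoDTensor()
        out.set(y, core.CPUPlace())    # build an output LoDTensor on CPU
        return (out,)                  # return a tuple; each element is cast
                                       # back to framework::LoDTensor*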
@@ -182,9 +182,10 @@ PYBIND11_MODULE(core, m) {
.def_static("apply",
[](py::object *callable,
const std::vector<imperative::VarBase> &inputs)
- -> std::vector<imperative::VarBase> {
+ -> std::vector<imperative::VarBase *> {
return imperative::PyLayer::Apply(callable, inputs);
- });
+ },
+ py::return_value_policy::take_ownership);
BindTracer(&m);
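Because the binding uses py::return_value_policy::take_ownership, the raw VarBase* pointers allocated in PyLayer::Apply become ordinary Python-owned objects on the caller's side. A hedged sketch of the caller's view, assuming the imports from the previous sketch; MyPyLayer comes from the test below and vars_in is a placeholder list of imperative Variables:

    ivars = core.PyLayer.apply(MyPyLayer.forward, [v._ivar for v in vars_in])
    first = np.array(ivars[0].value.get_tensor())  # data shared with the callable's output
    del ivars                                      # the VarBase objects are freed by Python, not by C++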
@@ -372,6 +372,9 @@ class Variable(object):
        self.stop_gradient = stop_gradient
        self.is_data = is_data
        if _in_imperative_mode():
+            if 'ivar' in kwargs:
+                self._ivar = kwargs['ivar']
+            else:
                self._ivar = core.VarBase()
            self._ivar.desc = self.desc
            self._ivar.stop_gradient = stop_gradient
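The new ivar keyword lets a block-level Variable wrap an imperative VarBase that already exists, as PyLayer.__call__ does below, instead of always allocating a fresh one. A small sketch of both paths, assuming an imperative guard is active, block is the current block, and core/framework are imported as in framework.py:

    existing = core.VarBase()                                      # e.g. a handle returned by core.PyLayer.apply
    wrapped = framework.Variable(block, name=None, ivar=existing)  # reuses the given VarBase
    fresh = framework.Variable(block, name=None)                   # falls back to a new core.VarBase()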
@@ -67,5 +67,17 @@ class PyLayer(core.PyLayer):
    def __call__(cls, inputs):
        inputs = map(base.to_variable, inputs)
        inputs = [x._ivar for x in inputs]
-        sys.stderr.write('%s\n' % inputs)
-        return core.PyLayer.apply(cls.forward, inputs)
+        ivars = core.PyLayer.apply(cls.forward, inputs)
+        ret = []
+        for ivar in ivars:
+            tensor = ivar.value.get_tensor()
+            block = framework.default_main_program().current_block()
+            py_var = framework.Variable(
+                block,
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                name=None,
+                shape=tensor.shape(),
+                dtype=tensor._dtype(),
+                ivar=ivar)
+            ret.append(py_var)
+        return ret
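Each VarBase returned by core.PyLayer.apply is thus re-exposed as a regular framework.Variable whose _ivar holds the shared tensor, so results can be read back like any other imperative variable. Roughly, mirroring the updated test below (MyPyLayer is the test's fluid.imperative.PyLayer subclass, whose forward body is not shown in this diff):

    import numpy as np
    import paddle.fluid as fluid

    with fluid.imperative.guard():
        my_py_layer = MyPyLayer()
        outs = my_py_layer([np.ones([2, 2], np.float32)])  # list of framework.Variable
        print(outs[0]._numpy())                            # numpy view of the first output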
@@ -81,8 +81,8 @@ class TestImperative(unittest.TestCase):
    def test_pylayer(self):
        with fluid.imperative.guard():
            my_py_layer = MyPyLayer()
-            out = my_py_layer([np.ones([2, 2], np.float32)])
-            sys.stderr.write('%s\n' % np.array(out))
+            outs = my_py_layer([np.ones([2, 2], np.float32)])
+            sys.stderr.write('%s\n' % outs[0]._numpy())
            # out.backward()

    def test_layer_in_out(self):