提交 b6291333 编写于 作者: X Xin Pan

checkpoint runnable PyLayer

test=develop
上级 0d0bc612
......@@ -17,6 +17,9 @@
#include <map>
#include <string>
#include <vector>
#include "pybind11/pybind11.h"
#include "Python.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/var_desc.h"
......@@ -25,6 +28,8 @@
namespace paddle {
namespace imperative {
namespace py = ::pybind11;
class PreparedOp {
public:
PreparedOp(const framework::OperatorBase& op,
......@@ -152,10 +157,48 @@ class Layer {
std::vector<VarBase> vars;
return vars;
}
};
virtual std::vector<VarBase> Backward(const std::vector<VarBase>& inputs) {
std::vector<VarBase> vars;
return vars;
static void CallPythonFunc(py::object* callable,
const std::vector<framework::LoDTensor>& ins,
std::vector<framework::LoDTensor*>* outs) {
py::gil_scoped_acquire guard;
py::tuple in_args(ins.size());
for (size_t i = 0; i < ins.size(); ++i) {
in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
}
auto ret = (*callable)(in_args);
auto ret_tuple = py::cast<py::tuple>(ret);
size_t ret_num = py::len(ret_tuple);
for (size_t i = 0; i < ret_num; ++i) {
try {
auto* py_out_tensor = py::cast<framework::LoDTensor*>(ret_tuple[i]);
PADDLE_ENFORCE_NOT_NULL(py_out_tensor,
"Output tensor %d should not be nullptr", i);
outs->push_back(py_out_tensor);
} catch (py::cast_error&) {
PADDLE_THROW("The %d-th output must be LoDTensor", i);
}
}
}
class PyLayer {
public:
virtual ~PyLayer() {}
static std::vector<VarBase> Apply(py::object* callable,
const std::vector<VarBase>& inputs) {
std::vector<VarBase> outputs;
std::vector<framework::LoDTensor> tensor_inputs;
std::vector<framework::LoDTensor*> tensor_outputs;
for (const VarBase& in : inputs) {
tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
}
CallPythonFunc(callable, tensor_inputs, &tensor_outputs);
return outputs;
}
};
......
......@@ -31,12 +31,6 @@ class Layer : public imperative::Layer {
PYBIND11_OVERLOAD(std::vector<imperative::VarBase>, Layer, Forward,
inputs); // NOLINT
}
std::vector<imperative::VarBase> Backward(
const std::vector<imperative::VarBase>& inputs) override {
PYBIND11_OVERLOAD(std::vector<imperative::VarBase>, Layer, Backward,
inputs); // NOLINT
}
};
class PyOpBase : public imperative::OpBase {
......
......@@ -172,15 +172,20 @@ PYBIND11_MODULE(core, m) {
py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
layer.def(py::init<>())
.def("forward",
[](imperative::Layer &self,
const std::vector<imperative::VarBase> &inputs) {
return self.Forward(inputs);
})
.def("backward", [](imperative::Layer &self,
const std::vector<imperative::VarBase> &inputs) {
return self.Backward(inputs);
.def("forward", [](imperative::Layer &self,
const std::vector<imperative::VarBase> &inputs) {
return self.Forward(inputs);
});
py::class_<paddle::imperative::PyLayer>(m, "PyLayer")
.def(py::init<>())
.def_static("apply",
[](py::object *callable,
const std::vector<imperative::VarBase> &inputs)
-> std::vector<imperative::VarBase> {
return imperative::PyLayer::Apply(callable, inputs);
});
BindTracer(&m);
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
......
......@@ -20,7 +20,7 @@ from paddle.fluid import core
from paddle.fluid import framework
from paddle.fluid.imperative import base
__all__ = ['Layer']
__all__ = ['Layer', 'PyLayer']
class Layer(core.Layer):
......@@ -48,14 +48,24 @@ class Layer(core.Layer):
raise ValueError("Layer shouldn't implement backward")
class PyLayer(core.Layer):
# TODO(panyx0718): Inherit from C++ base class.
class PyLayer(core.PyLayer):
"""Layers composed of user-defined python codes."""
def __call__(self, *inputs):
pass
def __init__(self):
super(PyLayer, self).__init__()
def forward(self, *inputs):
@staticmethod
def forward(inputs):
raise NotImplementedError
def backward(self, *inputs):
@staticmethod
def backward(inputs):
raise NotImplementedError
@classmethod
def __call__(cls, inputs):
inputs = map(base.to_variable, inputs)
inputs = [x._ivar for x in inputs]
sys.stderr.write('%s\n' % inputs)
return core.PyLayer.apply(cls.forward, inputs)
......@@ -15,6 +15,7 @@
import contextlib
import unittest
import numpy as np
import sys
import paddle.fluid as fluid
from paddle.fluid import core
......@@ -34,6 +35,24 @@ class MyLayer(fluid.imperative.Layer):
return [x]
class MyPyLayer(fluid.imperative.PyLayer):
def __init__(self):
super(MyPyLayer, self).__init__()
@staticmethod
def forward(inputs):
sys.stderr.write('before forward\n')
ret = np.tanh(inputs[0])
sys.stderr.write('after forward: %s\n' % ret)
tensor = core.LoDTensor()
tensor.set(ret, core.CPUPlace())
return tuple([tensor])
@staticmethod
def backward(douts, outs):
return np.array(douts[0]) * (1 - np.square(np.array(outs[0])))
class MLP(fluid.imperative.Layer):
def __init__(self):
super(MLP, self).__init__()
......@@ -59,6 +78,13 @@ class TestImperative(unittest.TestCase):
l = fluid.imperative.Layer()
self.assertRaises(NotImplementedError, l.forward, [])
def test_pylayer(self):
with fluid.imperative.guard():
my_py_layer = MyPyLayer()
out = my_py_layer([np.ones([2, 2], np.float32)])
sys.stderr.write('%s\n' % np.array(out))
# out.backward()
def test_layer_in_out(self):
np_inp = np.array([1.0, 2.0, -1.0], dtype=np.float32)
with fluid.imperative.guard():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册