Commit 2349acea authored by Xin Pan

checkpoint

test=develop
Parent 11d4d39c
@@ -27,6 +27,8 @@
namespace paddle {
namespace imperative {
+std::map<int, py::object> py_funcs_;
using framework::Variable;
void AddTo(Variable* src, Variable* dst) {
@@ -183,5 +185,22 @@ void VarBase::RunBackward() {
  Autograd().RunBackward(this);
}
+void PyLayer::RegisterFunc(int func_id, const py::object& py_func) {
+  py_funcs_[func_id] = py_func;
+}
+
+std::vector<VarBase*> PyLayer::Apply(int func_id,
+                                     const std::vector<VarBase>& inputs) {
+  std::vector<framework::LoDTensor> tensor_inputs;
+  std::vector<VarBase*> ret;
+  for (const VarBase& in : inputs) {
+    tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
+  }
+  PADDLE_ENFORCE(py_funcs_.find(func_id) != py_funcs_.end());
+  CallPythonFunc(py_funcs_[func_id], tensor_inputs, &ret);
+  return ret;
+}
} // namespace imperative
} // namespace paddle
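PyLayer::Apply above resolves the py::object registered under func_id and forwards the input LoDTensors to CallPythonFunc, which (see the header hunk below) casts whatever the callable returns into a tuple. A minimal sketch of a Python callable that fits that contract; the name my_forward, the numpy conversion, and the tanh body are illustrative assumptions, not part of this commit:

import numpy as np

# Hypothetical callable to be registered via PyLayer::RegisterFunc.
def my_forward(inputs):
    # Assumption: the inputs arrive bundled from CallPythonFunc and are
    # convertible to numpy arrays.
    x = np.array(inputs[0])
    # The C++ side does py::cast<py::tuple>(ret), so always return a tuple,
    # even for a single output.
    return (np.tanh(x),)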
@@ -82,6 +82,7 @@ class PreparedOp {
  framework::OperatorWithKernel::OpKernelFunc func;
  platform::DeviceContext* dev_ctx;
};
class OpBase;
class VarBase {
@@ -128,7 +129,11 @@ class VarBase {
class OpBase {
 public:
-  OpBase() : op_desc_(nullptr), grad_op_desc_(nullptr) {}
+  OpBase()
+      : op_desc_(nullptr),
+        grad_op_desc_(nullptr),
+        forward_id_(-1),
+        backward_id_(-1) {}
  virtual ~OpBase() {
    if (grad_op_desc_) delete grad_op_desc_;
@@ -139,6 +144,9 @@ class OpBase {
  framework::OpDesc* op_desc_;
  framework::OpDesc* grad_op_desc_;
+  int forward_id_;
+  int backward_id_;
  std::map<std::string, std::vector<VarBase*>> input_vars_;
  std::map<std::string, std::vector<VarBase*>> output_vars_;
  std::map<std::string, std::vector<OpBase*>> pre_ops_;
@@ -159,7 +167,7 @@ class Layer {
  }
};
-static void CallPythonFunc(py::object* callable,
+static void CallPythonFunc(const py::object& callable,
                            const std::vector<framework::LoDTensor>& ins,
                            std::vector<VarBase*>* outs) {
  py::gil_scoped_acquire guard;
@@ -169,7 +177,7 @@ static void CallPythonFunc(py::object* callable,
  }
  // TODO(panyx0718): Who owns the returned LoDTensor.
-  auto ret = (*callable)(in_args);
+  auto ret = callable(in_args);
  auto ret_tuple = py::cast<py::tuple>(ret);
  size_t ret_num = py::len(ret_tuple);
  for (size_t i = 0; i < ret_num; ++i) {
@@ -192,17 +200,10 @@ class PyLayer {
 public:
  virtual ~PyLayer() {}
-  static std::vector<VarBase*> Apply(py::object* callable,
-                                     const std::vector<VarBase>& inputs) {
-    std::vector<framework::LoDTensor> tensor_inputs;
-    std::vector<VarBase*> ret;
-    for (const VarBase& in : inputs) {
-      tensor_inputs.push_back(in.var_->Get<framework::LoDTensor>());
-    }
-    CallPythonFunc(callable, tensor_inputs, &ret);
-    return ret;
-  }
+  static void RegisterFunc(int func_id, const py::object& py_func);
+  static std::vector<VarBase*> Apply(int func_id,
+                                     const std::vector<VarBase>& inputs);
};
} // namespace imperative
......
@@ -172,6 +172,21 @@ class Tracer {
    op->block_ = block;
  }
+  std::vector<VarBase*> PyTrace(OpBase* op,
+                                const std::vector<VarBase>& inputs) {
+    std::vector<VarBase*> outputs = PyLayer::Apply(op->forward_id_, inputs);
+    /*
+    for (const VarBase& inp : inputs) {
+      if (inp.pre_op_) {
+        op->pre_ops_[it.first].push_back(inp->pre_op_);
+        op->pre_ops_out_idx_[it.first].push_back(inp->pre_op_out_idx_);
+      } else {
+        op->pre_ops_[it.first].push_back(nullptr);
+      }
+    }*/
+    return outputs;
+  }
 private:
  framework::BlockDesc* root_block_;
};
......
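The new Tracer::PyTrace dispatches to PyLayer::Apply with the op's forward_id and, for now, leaves the pre-op bookkeeping commented out. The intended driving pattern from Python, mirrored by the Python-side hunk at the end of this commit, looks roughly like the sketch below; my_forward is the hypothetical callable from the earlier sketch and ivar_inputs stands in for a list of imperative VarBase objects:

from paddle.fluid import core, framework

core.PyLayer.register_func(1, my_forward)   # id -> Python callable

iop = core.OpBase()
iop.forward_id = 1                          # property bound in the pybind hunk below
out_ivars = framework._imperative_tracer().py_trace(iop, ivar_inputs)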
@@ -26,7 +26,9 @@ void BindTracer(pybind11::module *m) {
          [](imperative::Tracer &self, framework::BlockDesc *root_block) {
            new (&self) imperative::Tracer(root_block);
          })
-      .def("trace", &imperative::Tracer::Trace);
+      .def("trace", &imperative::Tracer::Trace)
+      .def("py_trace", &imperative::Tracer::PyTrace,
+           pybind11::return_value_policy::take_ownership);
}
} // namespace pybind
......
@@ -168,6 +168,13 @@ PYBIND11_MODULE(core, m) {
            self.op_desc_ = op_desc;
          }
        },
+        py::return_value_policy::reference)
+      .def_property(
+          "forward_id",
+          [](const imperative::OpBase &self) { return self.forward_id_; },
+          [](imperative::OpBase &self, int forward_id) {
+            self.forward_id_ = forward_id;
+          },
        py::return_value_policy::reference);
  py::class_<imperative::Layer, Layer /* <--- trampoline*/> layer(m, "Layer");
@@ -179,13 +186,16 @@ PYBIND11_MODULE(core, m) {
  py::class_<paddle::imperative::PyLayer>(m, "PyLayer")
      .def(py::init<>())
-      .def_static("apply",
-                  [](py::object *callable,
-                     const std::vector<imperative::VarBase> &inputs)
+      .def_static(
+          "apply",
+          [](int func_id, const std::vector<imperative::VarBase> &inputs)
              -> std::vector<imperative::VarBase *> {
-                    return imperative::PyLayer::Apply(callable, inputs);
+            return imperative::PyLayer::Apply(func_id, inputs);
          },
-                  py::return_value_policy::take_ownership);
+          py::return_value_policy::take_ownership)
+      .def_static("register_func", [](int func_id, const py::object &callable) {
+        imperative::PyLayer::RegisterFunc(func_id, callable);
+      });
  BindTracer(&m);
......
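These bindings also make the func_id indirection usable without the tracer: a callable can be registered once and invoked through the static apply, the path that the Python layer code below comments out in favour of py_trace. A sketch under the same assumptions as above (my_forward and ivar_inputs are placeholders, not names from this commit):

from paddle.fluid import core

core.PyLayer.register_func(2, my_forward)
out_ivars = core.PyLayer.apply(2, ivar_inputs)   # returns imperative VarBase outputs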
@@ -48,7 +48,6 @@ class Layer(core.Layer):
        raise ValueError("Layer shouldn't implement backward")
-# TODO(panyx0718): Inherit from C++ base class.
class PyLayer(core.PyLayer):
    """Layers composed of user-defined python codes."""
@@ -65,13 +64,21 @@ class PyLayer(core.PyLayer):
    @classmethod
    def __call__(cls, inputs):
+        tracer = framework._imperative_tracer()
+        block = framework.default_main_program().current_block()
        inputs = map(base.to_variable, inputs)
        inputs = [x._ivar for x in inputs]
-        ivars = core.PyLayer.apply(cls.forward, inputs)
+        PyLayer.register_func(1, cls.forward)
+        iop = core.OpBase()
+        iop.forward_id = 1
+        block.ops.append(iop)
+        ivars = tracer.py_trace(iop, inputs)
+        # ivars = core.PyLayer.apply(cls.forward, inputs)
        ret = []
        for ivar in ivars:
            tensor = ivar.value.get_tensor()
-            block = framework.default_main_program().current_block()
            py_var = framework.Variable(
                block,
                type=core.VarDesc.VarType.LOD_TENSOR,
......
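Putting the Python side together, a user-defined layer under this checkpoint would look roughly like the sketch below. MyPyLayer, the staticmethod decorator, and the tanh body are illustrative assumptions; only the base class, registering cls.forward under an id, and the tuple return contract come from the commit itself:

import numpy as np
from paddle.fluid.imperative.layers import PyLayer  # the module patched above

class MyPyLayer(PyLayer):
    @staticmethod
    def forward(inputs):
        # Called through the func_id machinery; return a tuple of results.
        return (np.tanh(np.array(inputs[0])),)

# Calling MyPyLayer()([x]) goes through __call__ above: it registers forward
# under id 1, appends an OpBase with that forward_id to the current block,
# and runs tracer.py_trace to obtain the output VarBase list.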