From e240ba291853856d29790ecd3b6c5493c5ab2cd3 Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Wed, 12 Dec 2018 03:16:34 +0000 Subject: [PATCH] implement backward test=develop --- paddle/fluid/framework/op_desc.cc | 2 + paddle/fluid/framework/op_desc.h | 2 + paddle/fluid/framework/operator.cc | 5 + paddle/fluid/framework/shape_inference.h | 5 + paddle/fluid/operators/py_func_op.cc | 127 ++++++++++++++++++++--- paddle/fluid/pybind/protobuf.cc | 2 +- python/paddle/fluid/layers/nn.py | 39 ++++--- 7 files changed, 154 insertions(+), 28 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index e8ecd90502..f8a9340df5 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -34,6 +34,8 @@ class CompileTimeInferShapeContext : public InferShapeContext { public: CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block); + InferShapeOpPtr GetOp() const override { return &op_; } + bool HasInput(const std::string &name) const override; bool HasOutput(const std::string &name) const override; diff --git a/paddle/fluid/framework/op_desc.h b/paddle/fluid/framework/op_desc.h index 30c8a26c3d..3b3f50bfa7 100644 --- a/paddle/fluid/framework/op_desc.h +++ b/paddle/fluid/framework/op_desc.h @@ -121,6 +121,8 @@ class OpDesc { BlockDesc *Block() { return this->block_; } + const BlockDesc *Block() const { return this->block_; } + private: template static std::vector MapKeys(const MapType &map) { diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index c6f3254e9f..188ab120be 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -481,6 +481,8 @@ class RuntimeInferShapeContext : public InferShapeContext { RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} + InferShapeOpPtr GetOp() const override { return &op_; } + bool HasInput(const std::string& name) const override { // has only one input const auto& ins = op_.Inputs(); @@ -879,6 +881,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType( t = &(var->Get().value()); } if (t != nullptr) { + PADDLE_ENFORCE(t->IsInitialized(), + "Input %s(%s) does not exist in Operator %s", + input.first, ipt_name, DebugString()); int tmp = static_cast(ToDataType(t->type())); PADDLE_ENFORCE( tmp == data_type || data_type == -1, diff --git a/paddle/fluid/framework/shape_inference.h b/paddle/fluid/framework/shape_inference.h index d73cca121e..2f95ab353e 100644 --- a/paddle/fluid/framework/shape_inference.h +++ b/paddle/fluid/framework/shape_inference.h @@ -25,7 +25,10 @@ limitations under the License. */ namespace paddle { namespace framework { +class OperatorBase; + using InferShapeVarPtr = boost::variant; +using InferShapeOpPtr = boost::variant; class InferShapeContext { public: @@ -38,6 +41,8 @@ class InferShapeContext { std::vector GetOutputsVarType( const std::string &name) const; + virtual InferShapeOpPtr GetOp() const = 0; + virtual bool HasInputs(const std::string &name) const = 0; virtual bool HasOutputs(const std::string &name) const = 0; diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc index 86914f3060..46a6125f97 100644 --- a/paddle/fluid/operators/py_func_op.cc +++ b/paddle/fluid/operators/py_func_op.cc @@ -24,34 +24,34 @@ namespace operators { namespace py = pybind11; -static std::mutex g_py_callables_mtx; static std::vector g_py_callables; size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { - std::lock_guard guard(g_py_callables_mtx); g_py_callables.emplace_back(py_obj); return g_py_callables.size() - 1; } static py::object *GetPythonCallableObject(size_t i) { - std::lock_guard guard(g_py_callables_mtx); PADDLE_ENFORCE_LT(i, g_py_callables.size()); return &g_py_callables[i]; } -void DoCallPythonFunc(py::object *callable, const std::string &func_token, - const std::vector &ins, - std::vector *out) { +void CallPythonFunc(py::object *callable, const std::string &func_token, + const std::vector &ins, + std::vector *out) { py::gil_scoped_acquire guard{}; py::tuple in_args(ins.size()); for (size_t i = 0; i < ins.size(); ++i) { - in_args[i] = py::cast(ins[i]); + in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr); } auto ret = (*callable)(func_token, *in_args); auto ret_tuple = py::cast(ret); PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match"); for (size_t i = 0; i < out->size(); ++i) { + if ((*out)[i] == nullptr) { + continue; + } try { auto *out_tensor = py::cast(ret_tuple[i]); PADDLE_ENFORCE_NOT_NULL(out_tensor, @@ -67,8 +67,43 @@ void DoCallPythonFunc(py::object *callable, const std::string &func_token, class PyFuncOpShapeInference : public framework::InferShapeBase { public: void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(!ctx->IsRuntime(), + "Infer shape cannot be called in runtime."); PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist"); PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist"); + + auto *op = boost::get(ctx->GetOp()); + auto *block = op->Block(); + // No need to infer shape in forward part + if (block->ForwardBlockID() < 0) { + return; + } + + PADDLE_ENFORCE(!ctx->Attrs().Get("token").empty(), + "Function token cannot be empty"); + + const std::string kGradVarSuffix = framework::kGradVarSuffix; + auto out_vars = ctx->GetOutputVarPtrs("Out"); + for (auto &out_var : out_vars) { + auto *out_var_desc = boost::get(out_var); + auto out_name = out_var_desc->Name(); + if (out_name == framework::kEmptyVarName || + out_name.size() < kGradVarSuffix.size()) { + continue; + } + + size_t len = out_name.size() - kGradVarSuffix.size(); + if (out_name.substr(len) == kGradVarSuffix) { + auto fwd_var_name = out_name.substr(0, len); + auto *in_var_desc = block->FindVarRecursive(fwd_var_name); + PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found", + fwd_var_name); + out_var_desc->SetShape(in_var_desc->GetShape()); + out_var_desc->SetDataType(in_var_desc->GetDataType()); + out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel()); + out_var_desc->SetType(in_var_desc->GetType()); + } + } } }; @@ -77,12 +112,68 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { void Make() override { AddInput("X", "Inputs of py_func op.").AsDuplicable(); AddOutput("Out", "Outputs of py_func op").AsDuplicable(); - AddAttr("token", "function token"); - AddAttr("handle_idx", "handle index").SetDefault(0); + AddAttr("handle_idx", "Index of the registered py_func handle") + .SetDefault(0); + AddAttr("token", "Token of function token to be called") + .SetDefault(""); + AddAttr("backward_token", + "Token of backward function to be called") + .SetDefault(""); AddComment(R"DOC("PyFunc Op")DOC"); } }; +class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase { + public: + using framework::GradOpDescMakerBase::GradOpDescMakerBase; + + std::vector> operator()() const override { + auto &fwd_attrs = Attrs(); + if (fwd_attrs.at("backward_token").empty()) { + return {}; + } + + std::unique_ptr grad_op(new framework::OpDesc()); + grad_op->SetType("py_func"); + + framework::AttributeMap bwd_attrs; + bwd_attrs["token"] = fwd_attrs.at("backward_token"); + bwd_attrs["backward_token"] = std::string(""); + grad_op->SetAttrMap(bwd_attrs); + + auto bwd_in = Input("X"); + auto fwd_out = Output("Out"); + auto fwd_out_grad = OutputGrad("Out"); + bwd_in.insert(bwd_in.end(), fwd_out.begin(), fwd_out.end()); + bwd_in.insert(bwd_in.end(), fwd_out_grad.begin(), fwd_out_grad.end()); + + auto bwd_out = InputGrad("X", false); + + if (VLOG_IS_ON(10)) { + std::string in_str = "PyFunc Grad Input: "; + for (auto &in : bwd_in) { + in_str += in; + in_str += " "; + } + VLOG(10) << in_str; + + std::string out_str = "PyFunc Grad Output: "; + for (auto &out : bwd_out) { + out_str += out; + out += " "; + } + VLOG(10) << out_str; + } + + grad_op->SetInput("X", bwd_in); + grad_op->SetOutput("Out", InputGrad("X", false)); + + std::vector> ret(1); + ret[0] = std::move(grad_op); + return ret; + } +}; + class PyFuncOp : public framework::OperatorBase { public: using framework::OperatorBase::OperatorBase; @@ -95,8 +186,14 @@ class PyFuncOp : public framework::OperatorBase { std::vector inputs(in_arg_names.size()); for (size_t i = 0; i < in_arg_names.size(); ++i) { - auto &in_tensor = - scope.FindVar(in_arg_names[i])->Get(); + auto in_var = scope.FindVar(in_arg_names[i]); + if (in_var == nullptr) { + continue; + } + auto &in_tensor = in_var->Get(); + if (!in_tensor.IsInitialized()) { + continue; + } if (platform::is_gpu_place(in_tensor.place())) { framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]); } else { @@ -107,8 +204,9 @@ class PyFuncOp : public framework::OperatorBase { std::vector outputs(out_arg_names.size()); for (size_t i = 0; i < out_arg_names.size(); ++i) { + auto *out_var = scope.FindVar(out_arg_names[i]); auto *out_tensor = - scope.FindVar(out_arg_names[i])->GetMutable(); + out_var ? out_var->GetMutable() : nullptr; outputs[i] = out_tensor; } @@ -117,7 +215,7 @@ class PyFuncOp : public framework::OperatorBase { auto *py_callable = GetPythonCallableObject(handle_idx); VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx " << handle_idx; - DoCallPythonFunc(py_callable, token, inputs, &outputs); + CallPythonFunc(py_callable, token, inputs, &outputs); } }; @@ -127,5 +225,4 @@ class PyFuncOp : public framework::OperatorBase { namespace ops = paddle::operators; REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, - ops::PyFuncOpShapeInference, - paddle::framework::EmptyGradOpMaker); + ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker); diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index ac406b27b5..4b218fb3a2 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) { .def("infer_var_type", &pd::OpDesc::InferVarType) .def("set_is_target", &pd::OpDesc::SetIsTarget) .def("serialize_to_string", SerializeMessage) - .def("block", &pd::OpDesc::Block, + .def("block", [](pd::OpDesc &self) { return self.Block(); }, pybind11::return_value_policy::reference); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 92cd53a6c3..66c98c935d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -9096,12 +9096,9 @@ def py_func(func, x, out, backward_func=None): _main_program_to_register = dict() @classmethod - def get_instance(cls, prog=None): - if prog is None: - prog = fluid.default_main_program() - + def get_instance(cls, prog): if not isinstance(prog, Program): - raise ValueError("prog must be None or type of Program") + raise TypeError("prog must be type of Program") ret = cls._main_program_to_register.get(prog, None) if ret is None: @@ -9155,6 +9152,10 @@ def py_func(func, x, out, backward_func=None): ret = [] for i in six.moves.range(len(ret0)): + if ret0[i] is None: + ret.append(None) + continue + if isinstance(ret0[i], core.LoDTensor): ret.append(ret0[i]) continue @@ -9175,20 +9176,34 @@ def py_func(func, x, out, backward_func=None): x = [x] if isinstance(out, Variable): - out = [out] + out_list = [out] + else: + out_list = out + + if func is None or not hasattr(func, '__call__'): + raise TypeError('Input func must be a function') - for each_out in out: + if backward_func is not None and not hasattr(backward_func, '__call__'): + raise TypeError('Input backward_func must be a function') + + for each_out in out_list: if len(each_out.shape) == 0: raise ValueError( - 'users should infer shapes of outputs of py_func op manually') + 'Output shapes of py_func op should be provided by users manually' + ) py_func_reg = PyFuncRegister.get_instance(helper.main_program) - token = py_func_reg.unique_token(func) + forward_token = py_func_reg.unique_token(func) + backward_token = py_func_reg.unique_token( + backward_func) if backward_func is not None else '' helper.append_op( type='py_func', inputs={'X': x}, - outputs={'Out': out}, - attrs={'handle_idx': py_func_reg.handle_idx, - 'token': token}) + outputs={'Out': out_list}, + attrs={ + 'handle_idx': py_func_reg.handle_idx, + 'token': forward_token, + 'backward_token': backward_token + }) return out -- GitLab