Commit e240ba29 authored by sneaxiy

implement backward

test=develop
Parent 8760d23c
@@ -34,6 +34,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
  public:
   CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
 
+  InferShapeOpPtr GetOp() const override { return &op_; }
+
   bool HasInput(const std::string &name) const override;
   bool HasOutput(const std::string &name) const override;
@@ -121,6 +121,8 @@ class OpDesc {
   BlockDesc *Block() { return this->block_; }
 
+  const BlockDesc *Block() const { return this->block_; }
+
  private:
   template <typename MapType>
   static std::vector<typename MapType::key_type> MapKeys(const MapType &map) {
@@ -481,6 +481,8 @@ class RuntimeInferShapeContext : public InferShapeContext {
   RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
       : op_(op), scope_(scope) {}
 
+  InferShapeOpPtr GetOp() const override { return &op_; }
+
   bool HasInput(const std::string& name) const override {
     // has only one input
     const auto& ins = op_.Inputs();

@@ -879,6 +881,9 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
       t = &(var->Get<SelectedRows>().value());
     }
     if (t != nullptr) {
+      PADDLE_ENFORCE(t->IsInitialized(),
+                     "Input %s(%s) does not exist in Operator %s",
+                     input.first, ipt_name, DebugString());
       int tmp = static_cast<int>(ToDataType(t->type()));
       PADDLE_ENFORCE(
           tmp == data_type || data_type == -1,
@@ -25,7 +25,10 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
+class OperatorBase;
+
 using InferShapeVarPtr = boost::variant<VarDesc *, Variable *>;
+using InferShapeOpPtr = boost::variant<const OpDesc *, const OperatorBase *>;
 
 class InferShapeContext {
  public:

@@ -38,6 +41,8 @@ class InferShapeContext {
   std::vector<proto::VarType::Type> GetOutputsVarType(
       const std::string &name) const;
 
+  virtual InferShapeOpPtr GetOp() const = 0;
+
   virtual bool HasInputs(const std::string &name) const = 0;
   virtual bool HasOutputs(const std::string &name) const = 0;
@@ -24,34 +24,34 @@ namespace operators {
 namespace py = pybind11;
 
+static std::mutex g_py_callables_mtx;
 static std::vector<py::object> g_py_callables;
 
 size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) {
+  std::lock_guard<std::mutex> guard(g_py_callables_mtx);
   g_py_callables.emplace_back(py_obj);
   return g_py_callables.size() - 1;
 }
 
 static py::object *GetPythonCallableObject(size_t i) {
+  std::lock_guard<std::mutex> guard(g_py_callables_mtx);
   PADDLE_ENFORCE_LT(i, g_py_callables.size());
   return &g_py_callables[i];
 }
 
-void DoCallPythonFunc(py::object *callable, const std::string &func_token,
-                      const std::vector<framework::LoDTensor> &ins,
-                      std::vector<framework::LoDTensor *> *out) {
+void CallPythonFunc(py::object *callable, const std::string &func_token,
+                    const std::vector<framework::LoDTensor> &ins,
+                    std::vector<framework::LoDTensor *> *out) {
   py::gil_scoped_acquire guard{};
   py::tuple in_args(ins.size());
   for (size_t i = 0; i < ins.size(); ++i) {
-    in_args[i] = py::cast(ins[i]);
+    in_args[i] = ins[i].IsInitialized() ? py::cast(ins[i]) : py::cast(nullptr);
   }
 
   auto ret = (*callable)(func_token, *in_args);
   auto ret_tuple = py::cast<py::tuple>(ret);
   PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match");
   for (size_t i = 0; i < out->size(); ++i) {
+    if ((*out)[i] == nullptr) {
+      continue;
+    }
     try {
       auto *out_tensor = py::cast<framework::LoDTensor *>(ret_tuple[i]);
       PADDLE_ENFORCE_NOT_NULL(out_tensor,

@@ -67,8 +67,43 @@ void DoCallPythonFunc(py::object *callable, const std::string &func_token,
 class PyFuncOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(!ctx->IsRuntime(),
+                   "Infer shape cannot be called in runtime.");
     PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist");
     PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist");
+
+    auto *op = boost::get<const framework::OpDesc *>(ctx->GetOp());
+    auto *block = op->Block();
+    // No need to infer shape in forward part
+    if (block->ForwardBlockID() < 0) {
+      return;
+    }
+
+    PADDLE_ENFORCE(!ctx->Attrs().Get<std::string>("token").empty(),
+                   "Function token cannot be empty");
+
+    const std::string kGradVarSuffix = framework::kGradVarSuffix;
+    auto out_vars = ctx->GetOutputVarPtrs("Out");
+    for (auto &out_var : out_vars) {
+      auto *out_var_desc = boost::get<framework::VarDesc *>(out_var);
+      auto out_name = out_var_desc->Name();
+      if (out_name == framework::kEmptyVarName ||
+          out_name.size() < kGradVarSuffix.size()) {
+        continue;
+      }
+
+      size_t len = out_name.size() - kGradVarSuffix.size();
+      if (out_name.substr(len) == kGradVarSuffix) {
+        auto fwd_var_name = out_name.substr(0, len);
+        auto *in_var_desc = block->FindVarRecursive(fwd_var_name);
+        PADDLE_ENFORCE_NOT_NULL(in_var_desc, "Forward variable %s not found",
+                                fwd_var_name);
+        out_var_desc->SetShape(in_var_desc->GetShape());
+        out_var_desc->SetDataType(in_var_desc->GetDataType());
+        out_var_desc->SetLoDLevel(in_var_desc->GetLoDLevel());
+        out_var_desc->SetType(in_var_desc->GetType());
+      }
+    }
   }
 };

@@ -77,12 +112,68 @@ class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
   void Make() override {
     AddInput("X", "Inputs of py_func op.").AsDuplicable();
     AddOutput("Out", "Outputs of py_func op").AsDuplicable();
-    AddAttr<std::string>("token", "function token");
-    AddAttr<int>("handle_idx", "handle index").SetDefault(0);
+    AddAttr<int>("handle_idx", "Index of the registered py_func handle")
+        .SetDefault(0);
+    AddAttr<std::string>("token", "Token of function token to be called")
+        .SetDefault("");
+    AddAttr<std::string>("backward_token",
+                         "Token of backward function to be called")
+        .SetDefault("");
     AddComment(R"DOC("PyFunc Op")DOC");
   }
 };
+class PyFuncOpGradDescMaker : public framework::GradOpDescMakerBase {
+ public:
+  using framework::GradOpDescMakerBase::GradOpDescMakerBase;
+
+  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
+    auto &fwd_attrs = Attrs();
+    if (fwd_attrs.at("backward_token").empty()) {
+      return {};
+    }
+
+    std::unique_ptr<framework::OpDesc> grad_op(new framework::OpDesc());
+    grad_op->SetType("py_func");
+
+    framework::AttributeMap bwd_attrs;
+    bwd_attrs["token"] = fwd_attrs.at("backward_token");
+    bwd_attrs["backward_token"] = std::string("");
+    grad_op->SetAttrMap(bwd_attrs);
+
+    auto bwd_in = Input("X");
+    auto fwd_out = Output("Out");
+    auto fwd_out_grad = OutputGrad("Out");
+    bwd_in.insert(bwd_in.end(), fwd_out.begin(), fwd_out.end());
+    bwd_in.insert(bwd_in.end(), fwd_out_grad.begin(), fwd_out_grad.end());
+
+    auto bwd_out = InputGrad("X", false);
+
+    if (VLOG_IS_ON(10)) {
+      std::string in_str = "PyFunc Grad Input: ";
+      for (auto &in : bwd_in) {
+        in_str += in;
+        in_str += " ";
+      }
+      VLOG(10) << in_str;
+
+      std::string out_str = "PyFunc Grad Output: ";
+      for (auto &out : bwd_out) {
+        out_str += out;
+        out_str += " ";
+      }
+      VLOG(10) << out_str;
+    }
+
+    grad_op->SetInput("X", bwd_in);
+    grad_op->SetOutput("Out", InputGrad("X", false));
+
+    std::vector<std::unique_ptr<framework::OpDesc>> ret(1);
+    ret[0] = std::move(grad_op);
+    return ret;
+  }
+};
 class PyFuncOp : public framework::OperatorBase {
  public:
   using framework::OperatorBase::OperatorBase;

@@ -95,8 +186,14 @@ class PyFuncOp : public framework::OperatorBase {
     std::vector<framework::LoDTensor> inputs(in_arg_names.size());
     for (size_t i = 0; i < in_arg_names.size(); ++i) {
-      auto &in_tensor =
-          scope.FindVar(in_arg_names[i])->Get<framework::LoDTensor>();
+      auto in_var = scope.FindVar(in_arg_names[i]);
+      if (in_var == nullptr) {
+        continue;
+      }
+      auto &in_tensor = in_var->Get<framework::LoDTensor>();
+      if (!in_tensor.IsInitialized()) {
+        continue;
+      }
       if (platform::is_gpu_place(in_tensor.place())) {
         framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]);
       } else {

@@ -107,8 +204,9 @@ class PyFuncOp : public framework::OperatorBase {
     std::vector<framework::LoDTensor *> outputs(out_arg_names.size());
     for (size_t i = 0; i < out_arg_names.size(); ++i) {
+      auto *out_var = scope.FindVar(out_arg_names[i]);
       auto *out_tensor =
-          scope.FindVar(out_arg_names[i])->GetMutable<framework::LoDTensor>();
+          out_var ? out_var->GetMutable<framework::LoDTensor>() : nullptr;
       outputs[i] = out_tensor;
     }

@@ -117,7 +215,7 @@ class PyFuncOp : public framework::OperatorBase {
     auto *py_callable = GetPythonCallableObject(handle_idx);
     VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx "
              << handle_idx;
-    DoCallPythonFunc(py_callable, token, inputs, &outputs);
+    CallPythonFunc(py_callable, token, inputs, &outputs);
   }
 };

@@ -127,5 +225,4 @@ class PyFuncOp : public framework::OperatorBase {
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
-                  ops::PyFuncOpShapeInference,
-                  paddle::framework::EmptyGradOpMaker);
+                  ops::PyFuncOpShapeInference, ops::PyFuncOpGradDescMaker);
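The PyFuncOpGradDescMaker above builds the backward op so that its "X" input is the concatenation of the forward inputs, the forward outputs, and the gradients of the forward outputs, and its "Out" output maps to the gradients of the forward inputs. Under that assumption, a callable registered through `backward_token` would receive tensors in that order and return one gradient per forward input. A minimal sketch, not part of this commit; the name is hypothetical and tensors are assumed to be numpy-convertible:

import numpy as np

def tanh_grad(x, y, dy):
    # Argument order follows bwd_in = X + Out + Out@GRAD built by the grad maker;
    # the return value is the gradient w.r.t. the single forward input X.
    return np.array(dy) * (1.0 - np.square(np.array(y)))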
@@ -328,7 +328,7 @@ void BindOpDesc(pybind11::module *m) {
       .def("infer_var_type", &pd::OpDesc::InferVarType)
       .def("set_is_target", &pd::OpDesc::SetIsTarget)
       .def("serialize_to_string", SerializeMessage<pd::OpDesc>)
-      .def("block", &pd::OpDesc::Block,
+      .def("block", [](pd::OpDesc &self) { return self.Block(); },
            pybind11::return_value_policy::reference);
 }
@@ -9096,12 +9096,9 @@ def py_func(func, x, out, backward_func=None):
     _main_program_to_register = dict()
 
     @classmethod
-    def get_instance(cls, prog=None):
-        if prog is None:
-            prog = fluid.default_main_program()
-
+    def get_instance(cls, prog):
         if not isinstance(prog, Program):
-            raise ValueError("prog must be None or type of Program")
+            raise TypeError("prog must be type of Program")
 
         ret = cls._main_program_to_register.get(prog, None)
         if ret is None:

@@ -9155,6 +9152,10 @@ def py_func(func, x, out, backward_func=None):
         ret = []
         for i in six.moves.range(len(ret0)):
+            if ret0[i] is None:
+                ret.append(None)
+                continue
+
             if isinstance(ret0[i], core.LoDTensor):
                 ret.append(ret0[i])
                 continue

@@ -9175,20 +9176,34 @@ def py_func(func, x, out, backward_func=None):
         x = [x]
 
     if isinstance(out, Variable):
-        out = [out]
+        out_list = [out]
+    else:
+        out_list = out
+
+    if func is None or not hasattr(func, '__call__'):
+        raise TypeError('Input func must be a function')
+
+    if backward_func is not None and not hasattr(backward_func, '__call__'):
+        raise TypeError('Input backward_func must be a function')
 
-    for each_out in out:
+    for each_out in out_list:
         if len(each_out.shape) == 0:
             raise ValueError(
-                'users should infer shapes of outputs of py_func op manually')
+                'Output shapes of py_func op should be provided by users manually'
+            )
 
     py_func_reg = PyFuncRegister.get_instance(helper.main_program)
-    token = py_func_reg.unique_token(func)
+    forward_token = py_func_reg.unique_token(func)
+    backward_token = py_func_reg.unique_token(
+        backward_func) if backward_func is not None else ''
 
     helper.append_op(
         type='py_func',
         inputs={'X': x},
-        outputs={'Out': out},
-        attrs={'handle_idx': py_func_reg.handle_idx,
-               'token': token})
+        outputs={'Out': out_list},
+        attrs={
+            'handle_idx': py_func_reg.handle_idx,
+            'token': forward_token,
+            'backward_token': backward_token
+        })
     return out
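For reference, a hypothetical end-to-end usage sketch of the Python layer changed above, assuming it is exposed as fluid.layers.py_func and that the registered callables receive numpy-convertible tensors; variable names and shapes are illustrative, not from this commit:

import numpy as np
import paddle.fluid as fluid

def forward_func(x):
    # Forward callable: y = tanh(x)
    return np.tanh(np.array(x))

def backward_func(x, y, dy):
    # Backward callable: receives X, Out and Out@GRAD, returns X@GRAD
    return np.array(dy) * (1.0 - np.square(np.array(y)))

x = fluid.layers.data(name='x', shape=[32], dtype='float32')
# Output variables need explicit shapes, since PyFuncOpShapeInference
# cannot infer them from the Python callable.
y = fluid.default_main_program().current_block().create_var(
    name='py_func_y', dtype='float32', shape=[-1, 32])
fluid.layers.py_func(func=forward_func, x=x, out=y, backward_func=backward_func)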