From 8760d23c7dbcb4ad5a5b941aca5917514467c86d Mon Sep 17 00:00:00 2001 From: sneaxiy Date: Mon, 10 Dec 2018 13:09:28 +0000 Subject: [PATCH] featue/py_func --- paddle/fluid/operators/CMakeLists.txt | 2 +- paddle/fluid/operators/py_func_op.cc | 131 ++++++++++++++++++++++++++ paddle/fluid/operators/py_func_op.h | 25 +++++ paddle/fluid/pybind/pybind.cc | 21 +++++ python/paddle/fluid/layers/nn.py | 112 +++++++++++++++++++++- 5 files changed, 289 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/operators/py_func_op.cc create mode 100644 paddle/fluid/operators/py_func_op.h diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 257bfc0a3f9..9379122faf3 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -82,7 +82,7 @@ endif() # op_library(unstack_op DEPS stack_op) # op_library(tensor_array_to_tensor_op DEPS concat_op) -set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS}) +set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS} python pybind) set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") cc_test(gather_test SRCS gather_test.cc DEPS tensor) diff --git a/paddle/fluid/operators/py_func_op.cc b/paddle/fluid/operators/py_func_op.cc new file mode 100644 index 00000000000..86914f30604 --- /dev/null +++ b/paddle/fluid/operators/py_func_op.cc @@ -0,0 +1,131 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/py_func_op.h" +#include +#include +#include +#include "Python.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +namespace py = pybind11; + +static std::mutex g_py_callables_mtx; +static std::vector g_py_callables; + +size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { + std::lock_guard guard(g_py_callables_mtx); + g_py_callables.emplace_back(py_obj); + return g_py_callables.size() - 1; +} + +static py::object *GetPythonCallableObject(size_t i) { + std::lock_guard guard(g_py_callables_mtx); + PADDLE_ENFORCE_LT(i, g_py_callables.size()); + return &g_py_callables[i]; +} + +void DoCallPythonFunc(py::object *callable, const std::string &func_token, + const std::vector &ins, + std::vector *out) { + py::gil_scoped_acquire guard{}; + py::tuple in_args(ins.size()); + for (size_t i = 0; i < ins.size(); ++i) { + in_args[i] = py::cast(ins[i]); + } + + auto ret = (*callable)(func_token, *in_args); + auto ret_tuple = py::cast(ret); + PADDLE_ENFORCE_EQ(py::len(ret_tuple), out->size(), "Output number not match"); + for (size_t i = 0; i < out->size(); ++i) { + try { + auto *out_tensor = py::cast(ret_tuple[i]); + PADDLE_ENFORCE_NOT_NULL(out_tensor, + "Output tensor should not be nullptr"); + (*out)[i]->set_lod(out_tensor->lod()); + (*out)[i]->ShareDataWith(*out_tensor); + } catch (py::cast_error &) { + PADDLE_THROW("Output %d is not LoDTensor", i); + } + } +} + +class PyFuncOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist"); + PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist"); + } +}; + +class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "Inputs of py_func op.").AsDuplicable(); + AddOutput("Out", "Outputs of py_func op").AsDuplicable(); + AddAttr("token", "function token"); + AddAttr("handle_idx", "handle index").SetDefault(0); + AddComment(R"DOC("PyFunc Op")DOC"); + } +}; + +class PyFuncOp : public framework::OperatorBase { + public: + using framework::OperatorBase::OperatorBase; + + protected: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &in_arg_names = Inputs("X"); + auto &out_arg_names = Outputs("Out"); + + std::vector inputs(in_arg_names.size()); + for (size_t i = 0; i < in_arg_names.size(); ++i) { + auto &in_tensor = + scope.FindVar(in_arg_names[i])->Get(); + if (platform::is_gpu_place(in_tensor.place())) { + framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]); + } else { + inputs[i].ShareDataWith(in_tensor); + } + inputs[i].set_lod(in_tensor.lod()); + } + + std::vector outputs(out_arg_names.size()); + for (size_t i = 0; i < out_arg_names.size(); ++i) { + auto *out_tensor = + scope.FindVar(out_arg_names[i])->GetMutable(); + outputs[i] = out_tensor; + } + + auto &token = Attr("token"); + auto handle_idx = static_cast(Attr("handle_idx")); + auto *py_callable = GetPythonCallableObject(handle_idx); + VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx " + << handle_idx; + DoCallPythonFunc(py_callable, token, inputs, &outputs); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker, + ops::PyFuncOpShapeInference, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/py_func_op.h b/paddle/fluid/operators/py_func_op.h new file mode 100644 index 00000000000..e85fa6b5bc3 --- /dev/null +++ b/paddle/fluid/operators/py_func_op.h @@ -0,0 +1,25 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" + +namespace paddle { +namespace operators { + +size_t AppendPythonCallableObjectAndReturnId(pybind11::object py_obj); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 58ef3da0b23..58da2cea347 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -36,6 +36,7 @@ limitations under the License. */ #include "paddle/fluid/framework/version.h" #include "paddle/fluid/memory/allocation/allocator_strategy.h" #include "paddle/fluid/operators/activation_op.h" +#include "paddle/fluid/operators/py_func_op.h" #include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h" #include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/enforce.h" @@ -100,6 +101,12 @@ PYBIND11_MODULE(core, m) { BindException(&m); + m.def( + "append_python_callable_object_and_return_id", + [](py::object py_obj) -> size_t { + return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj); + }); + py::class_(m, "Tensor", py::buffer_protocol()) .def_buffer( [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); }) @@ -525,6 +532,20 @@ All parameter, weight, gradient are variables in Paddle. py::class_(m, "Place") .def(py::init<>()) + .def("is_cpu_place", + [](platform::Place &self) { return platform::is_cpu_place(self); }) + .def("is_gpu_place", + [](platform::Place &self) { return platform::is_gpu_place(self); }) + .def("is_cuda_pinned_place", + [](platform::Place &self) { + return platform::is_cuda_pinned_place(self); + }) + .def("gpu_device_id", + [](platform::Place &self) { + PADDLE_ENFORCE(platform::is_gpu_place(self), + "gpu_device_id() only supports in CUDAPlace"); + return boost::get(self).device; + }) .def("set_place", [](platform::Place &self, const platform::CPUPlace &cpu_place) { self = cpu_place; diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 4833212d311..92cd53a6c36 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -18,10 +18,12 @@ All layers just related to the neural network. from __future__ import print_function import numpy as np +import six import os +import inspect from ..layer_helper import LayerHelper from ..initializer import Normal, Constant -from ..framework import Variable, OpProtoHolder +from ..framework import Variable, OpProtoHolder, Program from ..param_attr import ParamAttr from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_ from .tensor import concat @@ -172,6 +174,7 @@ __all__ = [ 'merge_selected_rows', 'get_tensor_from_selected_rows', 'lstm', + 'py_func', ] kIgnoreIndex = -100 @@ -9082,3 +9085,110 @@ def get_tensor_from_selected_rows(x, name=None): outputs={'Out': out}, attrs={}) return out + + +@templatedoc() +def py_func(func, x, out, backward_func=None): + """ + """ + + class PyFuncRegister(object): + _main_program_to_register = dict() + + @classmethod + def get_instance(cls, prog=None): + if prog is None: + prog = fluid.default_main_program() + + if not isinstance(prog, Program): + raise ValueError("prog must be None or type of Program") + + ret = cls._main_program_to_register.get(prog, None) + if ret is None: + ret = PyFuncRegister() + ret._idx = core.append_python_callable_object_and_return_id(ret) + ret._token_func_dict = dict() + ret._func_token_dict = dict() + cls._main_program_to_register[prog] = ret + + return ret + + @property + def handle_idx(self): + return self._idx + + def unique_token(self, func): + return self._register_func(func) + + def _register_func(self, func): + if func is None: + raise ValueError("func cannot be None") + + token = self._func_token_dict.get(func, None) + if token is not None: + return token + + token = unique_name.generate('py_func_op_token') + self._token_func_dict[token] = func + self._func_token_dict[func] = token + return token + + def __call__(self, token, *args): + func = self._token_func_dict.get(token, None) + if func is None: + raise ValueError("func has not been registered") + + arg_list = inspect.getargspec(func) + kwargs = dict() + idx = 0 + for arg in arg_list[0]: + kwargs[arg] = args[idx] + idx += 1 + + args = args[idx:] + ret0 = func(*args, **kwargs) + if ret0 is None: + return None + + if not isinstance(ret0, (list, tuple)): + ret0 = (ret0, ) + + ret = [] + for i in six.moves.range(len(ret0)): + if isinstance(ret0[i], core.LoDTensor): + ret.append(ret0[i]) + continue + + if isinstance(ret0[i], np.ndarray): + r = ret0[i] + else: + r = np.array(ret0[i]) + + t = core.LoDTensor() + t.set(r, core.CPUPlace()) + ret.append(t) + + return tuple(ret) + + helper = LayerHelper('py_func', **locals()) + if isinstance(x, Variable): + x = [x] + + if isinstance(out, Variable): + out = [out] + + for each_out in out: + if len(each_out.shape) == 0: + raise ValueError( + 'users should infer shapes of outputs of py_func op manually') + + py_func_reg = PyFuncRegister.get_instance(helper.main_program) + token = py_func_reg.unique_token(func) + + helper.append_op( + type='py_func', + inputs={'X': x}, + outputs={'Out': out}, + attrs={'handle_idx': py_func_reg.handle_idx, + 'token': token}) + return out -- GitLab