提交 8760d23c 编写于 作者: S sneaxiy

feature/py_func

上级 943ad478
......@@ -82,7 +82,7 @@ endif()
# op_library(unstack_op DEPS stack_op)
# op_library(tensor_array_to_tensor_op DEPS concat_op)
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS} python pybind)
set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/py_func_op.h"
#include <deque>
#include <mutex>  // NOLINT
#include <set>
#include <string>
#include <vector>
#include "Python.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
namespace py = pybind11;
static std::mutex g_py_callables_mtx;
static std::vector<py::object> g_py_callables;
size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) {
std::lock_guard<std::mutex> guard(g_py_callables_mtx);
g_py_callables.emplace_back(py_obj);
return g_py_callables.size() - 1;
}
static py::object *GetPythonCallableObject(size_t i) {
std::lock_guard<std::mutex> guard(g_py_callables_mtx);
PADDLE_ENFORCE_LT(i, g_py_callables.size());
return &g_py_callables[i];
}
// Calls `*callable` as callable(func_token, ins[0], ins[1], ...) and stores
// the returned tensors in the tensors pointed to by `*out`.
//
// The Python GIL is acquired for the whole call. The value returned by the
// callable is converted to a tuple whose length must equal out->size(); each
// element must be convertible to a LoDTensor pointer. For every output the
// LoD is copied and the data is shared with the returned tensor.
void DoCallPythonFunc(py::object *callable, const std::string &func_token,
                      const std::vector<framework::LoDTensor> &ins,
                      std::vector<framework::LoDTensor *> *out) {
  py::gil_scoped_acquire gil_guard{};

  // Pack the input tensors into a Python tuple used as positional arguments.
  const size_t num_inputs = ins.size();
  py::tuple py_inputs(num_inputs);
  for (size_t pos = 0; pos < num_inputs; ++pos) {
    py_inputs[pos] = py::cast(ins[pos]);
  }

  // The token goes first, followed by the unpacked inputs.
  auto py_result = (*callable)(func_token, *py_inputs);
  auto result_tuple = py::cast<py::tuple>(py_result);
  PADDLE_ENFORCE_EQ(py::len(result_tuple), out->size(),
                    "Output number not match");

  size_t pos = 0;
  for (auto *dst_tensor : *out) {
    try {
      auto *src_tensor = py::cast<framework::LoDTensor *>(result_tuple[pos]);
      PADDLE_ENFORCE_NOT_NULL(src_tensor,
                              "Output tensor should not be nullptr");
      dst_tensor->set_lod(src_tensor->lod());
      dst_tensor->ShareDataWith(*src_tensor);
    } catch (py::cast_error &) {
      // The element returned by Python could not be cast to a LoDTensor.
      PADDLE_THROW("Output %d is not LoDTensor", pos);
    }
    ++pos;
  }
}
// Shape inference functor for the py_func op.
// It only checks that the "X" inputs and the "Out" outputs are present; it
// does not set any output shape.
class PyFuncOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Input(X) must exist");
PADDLE_ENFORCE(ctx->HasOutputs("Out"), "Output(Out) must exist");
}
};
// Declares the inputs, outputs and attributes of the py_func op.
class PyFuncOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
// "X" and "Out" are both duplicable, i.e. each may name several variables.
AddInput("X", "Inputs of py_func op.").AsDuplicable();
AddOutput("Out", "Outputs of py_func op").AsDuplicable();
// "token" is passed as the first argument to the Python callable.
AddAttr<std::string>("token", "function token");
// "handle_idx" is the index used to look up the registered Python callable;
// it defaults to 0.
AddAttr<int>("handle_idx", "handle index").SetDefault(0);
AddComment(R"DOC("PyFunc Op")DOC");
}
};
// Operator that hands its input tensors to a registered Python callable and
// stores the tensors returned by that callable in its outputs.
class PyFuncOp : public framework::OperatorBase {
 public:
  using framework::OperatorBase::OperatorBase;

 protected:
  // Runs the op:
  //   1. collects every "X" input as a CPU-accessible LoDTensor,
  //   2. collects pointers to every "Out" output tensor,
  //   3. looks up the Python callable by the "handle_idx" attribute and
  //      calls it with the "token" attribute and the inputs.
  // Throws if an input or output variable cannot be found in `scope`.
  // `place` is not consulted by this implementation.
  void RunImpl(const framework::Scope &scope,
               const platform::Place &place) const override {
    auto &in_arg_names = Inputs("X");
    auto &out_arg_names = Outputs("Out");

    std::vector<framework::LoDTensor> inputs(in_arg_names.size());
    for (size_t i = 0; i < in_arg_names.size(); ++i) {
      // The lookup result is dereferenced below, so fail with a readable
      // error instead of dereferencing a null pointer.
      auto *in_var = scope.FindVar(in_arg_names[i]);
      if (in_var == nullptr) {
        PADDLE_THROW("Input variable %s of py_func op is not found",
                     in_arg_names[i]);
      }
      auto &in_tensor = in_var->Get<framework::LoDTensor>();
      if (platform::is_gpu_place(in_tensor.place())) {
        // GPU tensors are copied to CPU before being handed to Python.
        framework::TensorCopySync(in_tensor, platform::CPUPlace(), &inputs[i]);
      } else {
        inputs[i].ShareDataWith(in_tensor);
      }
      inputs[i].set_lod(in_tensor.lod());
    }

    std::vector<framework::LoDTensor *> outputs(out_arg_names.size());
    for (size_t i = 0; i < out_arg_names.size(); ++i) {
      auto *out_var = scope.FindVar(out_arg_names[i]);
      if (out_var == nullptr) {
        PADDLE_THROW("Output variable %s of py_func op is not found",
                     out_arg_names[i]);
      }
      outputs[i] = out_var->GetMutable<framework::LoDTensor>();
    }

    auto &token = Attr<std::string>("token");
    auto handle_idx = static_cast<size_t>(Attr<int>("handle_idx"));
    auto *py_callable = GetPythonCallableObject(handle_idx);
    VLOG(10) << "Call py_func_op with token " << token << ", and handle_idx "
             << handle_idx;
    DoCallPythonFunc(py_callable, token, inputs, &outputs);
  }
};
} // namespace operators
} // namespace paddle
// Registers the "py_func" operator together with its proto maker and shape
// inference functor. EmptyGradOpMaker is supplied as the gradient op maker.
// NOTE(review): presumably this means no gradient op is generated for
// py_func -- confirm against EmptyGradOpMaker's definition.
namespace ops = paddle::operators;
REGISTER_OPERATOR(py_func, ops::PyFuncOp, ops::PyFuncOpMaker,
ops::PyFuncOpShapeInference,
paddle::framework::EmptyGradOpMaker);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "pybind11/pybind11.h"
namespace paddle {
namespace operators {
// Registers a Python object for use by the py_func op and returns the id
// (a size_t index) under which it was stored.
size_t AppendPythonCallableObjectAndReturnId(pybind11::object py_obj);
} // namespace operators
} // namespace paddle
......@@ -36,6 +36,7 @@ limitations under the License. */
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/memory/allocation/allocator_strategy.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/py_func_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/enforce.h"
......@@ -100,6 +101,12 @@ PYBIND11_MODULE(core, m) {
BindException(&m);
m.def(
"append_python_callable_object_and_return_id",
[](py::object py_obj) -> size_t {
return paddle::operators::AppendPythonCallableObjectAndReturnId(py_obj);
});
py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
.def_buffer(
[](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
......@@ -525,6 +532,20 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<platform::Place>(m, "Place")
.def(py::init<>())
.def("is_cpu_place",
[](platform::Place &self) { return platform::is_cpu_place(self); })
.def("is_gpu_place",
[](platform::Place &self) { return platform::is_gpu_place(self); })
.def("is_cuda_pinned_place",
[](platform::Place &self) {
return platform::is_cuda_pinned_place(self);
})
.def("gpu_device_id",
[](platform::Place &self) {
PADDLE_ENFORCE(platform::is_gpu_place(self),
"gpu_device_id() only supports in CUDAPlace");
return boost::get<platform::CUDAPlace>(self).device;
})
.def("set_place",
[](platform::Place &self, const platform::CPUPlace &cpu_place) {
self = cpu_place;
......
......@@ -18,10 +18,12 @@ All layers just related to the neural network.
from __future__ import print_function
import numpy as np
import six
import os
import inspect
from ..layer_helper import LayerHelper
from ..initializer import Normal, Constant
from ..framework import Variable, OpProtoHolder
from ..framework import Variable, OpProtoHolder, Program
from ..param_attr import ParamAttr
from .layer_function_generator import autodoc, templatedoc, _generate_doc_string_
from .tensor import concat
......@@ -172,6 +174,7 @@ __all__ = [
'merge_selected_rows',
'get_tensor_from_selected_rows',
'lstm',
'py_func',
]
kIgnoreIndex = -100
......@@ -9082,3 +9085,110 @@ def get_tensor_from_selected_rows(x, name=None):
outputs={'Out': out},
attrs={})
return out
class PyFuncRegister(object):
    """
    Registry of the Python functions used by the py_func ops of one Program.

    One instance is created per Program and is itself registered on the C++
    side as a callable; the op calls it with a token followed by the input
    tensors, and the instance dispatches to the function stored under that
    token.

    The class lives at module level so that `_main_program_to_register` is
    shared by every call of `py_func`. When the class was defined inside
    `py_func`, the cache was rebuilt on every call and therefore never hit.
    """

    # Maps a Program to the registry instance created for it.
    _main_program_to_register = dict()

    @classmethod
    def get_instance(cls, prog=None):
        """
        Return the registry bound to `prog`, creating it on first use.

        Args:
            prog (Program|None): the program to look up; None selects the
                default main program.

        Raises:
            ValueError: if `prog` is neither None nor a Program.
        """
        if prog is None:
            # Resolved through the framework module this file already imports
            # from, rather than through a `fluid` name.
            from ..framework import default_main_program
            prog = default_main_program()

        if not isinstance(prog, Program):
            raise ValueError("prog must be None or type of Program")

        ret = cls._main_program_to_register.get(prog, None)
        if ret is None:
            ret = cls()
            ret._idx = core.append_python_callable_object_and_return_id(ret)
            ret._token_func_dict = dict()
            ret._func_token_dict = dict()
            cls._main_program_to_register[prog] = ret

        return ret

    @property
    def handle_idx(self):
        """Id returned when this registry was registered on the C++ side."""
        return self._idx

    def unique_token(self, func):
        """Return the token for `func`, registering it if necessary."""
        return self._register_func(func)

    def _register_func(self, func):
        """
        Store `func` under a freshly generated token, or return the token it
        already has.

        Raises:
            ValueError: if `func` is None.
        """
        if func is None:
            raise ValueError("func cannot be None")

        token = self._func_token_dict.get(func, None)
        if token is not None:
            return token

        token = unique_name.generate('py_func_op_token')
        self._token_func_dict[token] = func
        self._func_token_dict[func] = token
        return token

    def __call__(self, token, *args):
        """
        Call the function registered under `token` with `args`.

        Returns:
            None if the function returned None; otherwise a tuple of
            core.LoDTensor. Return values that are not already LoDTensor are
            converted through numpy and placed on CPU.

        Raises:
            ValueError: if no function is registered under `token`.
        """
        func = self._token_func_dict.get(token, None)
        if func is None:
            raise ValueError("func has not been registered")

        # Arguments are forwarded positionally. Re-binding them as keywords
        # by parameter name gives the same result whenever it succeeds, but
        # fails for functions with default arguments, for signatures that mix
        # named parameters with *args, and for bound methods.
        ret0 = func(*args)
        if ret0 is None:
            return None

        if not isinstance(ret0, (list, tuple)):
            ret0 = (ret0, )

        ret = []
        for each_ret in ret0:
            if isinstance(each_ret, core.LoDTensor):
                ret.append(each_ret)
                continue

            if isinstance(each_ret, np.ndarray):
                r = each_ret
            else:
                r = np.array(each_ret)

            t = core.LoDTensor()
            t.set(r, core.CPUPlace())
            ret.append(t)

        return tuple(ret)


@templatedoc()
def py_func(func, x, out, backward_func=None):
    """
    Append a py_func op that runs the Python function `func` as an operator.

    Args:
        func (callable): the function to run. It is called with the input
            tensors as positional arguments.
        x (Variable|list(Variable)): the input variable(s) of the op.
        out (Variable|list(Variable)): the output variable(s) of the op. The
            caller must have set the shape of each of them.
        backward_func: not used by this implementation.

    Returns:
        list(Variable): `out`, wrapped in a list if a single Variable was
        passed.

    Raises:
        ValueError: if `func` is None, or if an output variable has an empty
            shape.
    """
    helper = LayerHelper('py_func', **locals())
    if isinstance(x, Variable):
        x = [x]

    if isinstance(out, Variable):
        out = [out]

    for each_out in out:
        if len(each_out.shape) == 0:
            raise ValueError(
                'users should infer shapes of outputs of py_func op manually')

    py_func_reg = PyFuncRegister.get_instance(helper.main_program)
    token = py_func_reg.unique_token(func)

    helper.append_op(
        type='py_func',
        inputs={'X': x},
        outputs={'Out': out},
        attrs={'handle_idx': py_func_reg.handle_idx,
               'token': token})
    return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册