提交 deb0d41c 编写于 作者: S sneaxiy

fix cmake

fix cmake again
test=develop
上级 e7c5c9d2
......@@ -42,8 +42,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32)
......@@ -82,7 +81,7 @@ endif()
# op_library(unstack_op DEPS stack_op)
# op_library(tensor_array_to_tensor_op DEPS concat_op)
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS} python pybind)
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
......@@ -94,4 +93,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
......@@ -22,7 +22,7 @@
namespace paddle {
namespace operators {
namespace py = pybind11;
namespace py = ::pybind11;
static std::vector<py::object> g_py_callables;
......@@ -30,7 +30,7 @@ const char kForwardPythonCallableId[] = "forward_callable_id";
const char kBackwardPythonCallableId[] = "backward_callable_id";
const char kPyFuncBackwardSkipVars[] = "backward_skip_vars";
size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) {
size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
g_py_callables.emplace_back(py_obj);
return g_py_callables.size() - 1;
}
......
......@@ -19,7 +19,7 @@
namespace paddle {
namespace operators {
size_t AppendPythonCallableObjectAndReturnId(pybind11::object py_obj);
size_t AppendPythonCallableObjectAndReturnId(const ::pybind11::object &py_obj);
} // namespace operators
} // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc)
if(WITH_PYTHON)
......
......@@ -9173,31 +9173,22 @@ class PyFuncWrapper(object):
kwargs[arg] = args[idx]
idx += 1
ret0 = self._func(*args[idx:], **kwargs)
if ret0 is None:
return None
if not isinstance(ret0, (list, tuple)):
ret0 = (ret0, )
func_ret = self._func(*args[idx:], **kwargs)
if not isinstance(func_ret, (list, tuple)):
func_ret = (func_ret, )
ret = []
for i in six.moves.range(len(ret0)):
if ret0[i] is None:
ret.append(None)
continue
if isinstance(ret0[i], core.LoDTensor):
ret.append(ret0[i])
for each_ret in func_ret:
if each_ret is None or isinstance(each_ret, core.LoDTensor):
ret.append(each_ret)
continue
if isinstance(ret0[i], np.ndarray):
r = ret0[i]
else:
r = np.array(ret0[i])
if not isinstance(each_ret, np.ndarray):
each_ret = np.array(each_ret)
t = core.LoDTensor()
t.set(r, core.CPUPlace())
ret.append(t)
tensor = core.LoDTensor()
tensor.set(each_ret, core.CPUPlace())
ret.append(tensor)
return tuple(ret)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册