提交 deb0d41c 编写于 作者: S sneaxiy

fix cmake

fix cmake again
test=develop
上级 e7c5c9d2
...@@ -42,8 +42,7 @@ if (WITH_DISTRIBUTE) ...@@ -42,8 +42,7 @@ if (WITH_DISTRIBUTE)
SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch) SET(OP_PREFETCH_DEPS ${OP_PREFETCH_DEPS} parameter_prefetch)
endif() endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS}) register_operators(EXCLUDES py_func_op warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS} ${OP_PREFETCH_DEPS})
# warpctc_op needs cudnn 7 above # warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32) if (WITH_GPU AND NOT WIN32)
...@@ -82,7 +81,7 @@ endif() ...@@ -82,7 +81,7 @@ endif()
# op_library(unstack_op DEPS stack_op) # op_library(unstack_op DEPS stack_op)
# op_library(tensor_array_to_tensor_op DEPS concat_op) # op_library(tensor_array_to_tensor_op DEPS concat_op)
set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS} python pybind) set(OPERATOR_DEPS ${OPERATOR_DEPS} ${COMMON_OP_DEPS})
set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies") set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_test(gather_test SRCS gather_test.cc DEPS tensor)
...@@ -94,4 +93,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op) ...@@ -94,4 +93,8 @@ cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op) cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor) nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif()
set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library") set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace py = pybind11; namespace py = ::pybind11;
static std::vector<py::object> g_py_callables; static std::vector<py::object> g_py_callables;
...@@ -30,7 +30,7 @@ const char kForwardPythonCallableId[] = "forward_callable_id"; ...@@ -30,7 +30,7 @@ const char kForwardPythonCallableId[] = "forward_callable_id";
const char kBackwardPythonCallableId[] = "backward_callable_id"; const char kBackwardPythonCallableId[] = "backward_callable_id";
const char kPyFuncBackwardSkipVars[] = "backward_skip_vars"; const char kPyFuncBackwardSkipVars[] = "backward_skip_vars";
size_t AppendPythonCallableObjectAndReturnId(py::object py_obj) { size_t AppendPythonCallableObjectAndReturnId(const py::object &py_obj) {
g_py_callables.emplace_back(py_obj); g_py_callables.emplace_back(py_obj);
return g_py_callables.size() - 1; return g_py_callables.size() - 1;
} }
......
...@@ -19,7 +19,7 @@ ...@@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
size_t AppendPythonCallableObjectAndReturnId(pybind11::object py_obj); size_t AppendPythonCallableObjectAndReturnId(const ::pybind11::object &py_obj);
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer) set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder parallel_executor profiler layer)
if(WITH_PYTHON)
list(APPEND PYBIND_DEPS py_func_op)
endif()
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc) set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc recordio.cc async_executor_py.cc imperative.cc)
if(WITH_PYTHON) if(WITH_PYTHON)
......
...@@ -9173,31 +9173,22 @@ class PyFuncWrapper(object): ...@@ -9173,31 +9173,22 @@ class PyFuncWrapper(object):
kwargs[arg] = args[idx] kwargs[arg] = args[idx]
idx += 1 idx += 1
ret0 = self._func(*args[idx:], **kwargs) func_ret = self._func(*args[idx:], **kwargs)
if ret0 is None: if not isinstance(func_ret, (list, tuple)):
return None func_ret = (func_ret, )
if not isinstance(ret0, (list, tuple)):
ret0 = (ret0, )
ret = [] ret = []
for i in six.moves.range(len(ret0)): for each_ret in func_ret:
if ret0[i] is None: if each_ret is None or isinstance(each_ret, core.LoDTensor):
ret.append(None) ret.append(each_ret)
continue
if isinstance(ret0[i], core.LoDTensor):
ret.append(ret0[i])
continue continue
if isinstance(ret0[i], np.ndarray): if not isinstance(each_ret, np.ndarray):
r = ret0[i] each_ret = np.array(each_ret)
else:
r = np.array(ret0[i])
t = core.LoDTensor() tensor = core.LoDTensor()
t.set(r, core.CPUPlace()) tensor.set(each_ret, core.CPUPlace())
ret.append(t) ret.append(tensor)
return tuple(ret) return tuple(ret)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册