You need to sign in or sign up before continuing.
提交 5509b849 编写于 作者: Y Yu Yang

Merge Develop

上级 f7688bd6
...@@ -44,4 +44,4 @@ cc_library(paddle_pybind SHARED ...@@ -44,4 +44,4 @@ cc_library(paddle_pybind SHARED
add_op add_op
mean_op mean_op
cross_entropy_op cross_entropy_op
recurrent_network_op) recurrent_op)
...@@ -83,29 +83,28 @@ PYBIND11_PLUGIN(core) { ...@@ -83,29 +83,28 @@ PYBIND11_PLUGIN(core) {
self.Resize(make_ddim(dim)); self.Resize(make_ddim(dim));
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor &self, paddle::platform::GPUPlace &place) { [](Tensor &self, paddle::platform::GPUPlace &place) {
self.mutable_data<float>(place); self.mutable_data<float>(place);
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor &self, paddle::platform::CPUPlace &place) { [](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<float>(place); self.mutable_data<float>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor &self, paddle::platform::CPUPlace &place) { [](Tensor &self, paddle::platform::CPUPlace &place) {
self.mutable_data<int>(place); self.mutable_data<int>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor &self, paddle::platform::GPUPlace &place) { [](Tensor &self, paddle::platform::GPUPlace &place) {
self.mutable_data<int>(place); self.mutable_data<int>(place);
}) })
.def("set", paddle::pybind::PyCPUTensorSetFromArray<float>) .def("set", PyCPUTensorSetFromArray<float>)
.def("set", paddle::pybind::PyCPUTensorSetFromArray<int>) .def("set", PyCPUTensorSetFromArray<int>)
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
.def("set", paddle::pybind::PyCUDATensorSetFromArray<float>) .def("set", PyCUDATensorSetFromArray<float>)
.def("set", paddle::pybind::PyCUDATensorSetFromArray<int>) .def("set", PyCUDATensorSetFromArray<int>)
#endif #endif
.def("shape", .def("shape", [](Tensor &self) { return vectorize(self.dims()); });
[](pd::Tensor &self) { return pd::vectorize(self.dims()); });
py::class_<Variable>(m, "Variable", R"DOC(Variable Class. py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
...@@ -152,8 +151,8 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -152,8 +151,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def_submodule( m.def_submodule(
"var_names", "var_names",
"The module will return special predefined variable name in Paddle") "The module will return special predefined variable name in Paddle")
.def("empty", pd::OperatorBase::EMPTY_VAR_NAME) .def("empty", OperatorBase::EMPTY_VAR_NAME)
.def("temp", pd::OperatorBase::TMP_VAR_NAME); .def("temp", OperatorBase::TMP_VAR_NAME);
// clang-format off // clang-format off
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("create", .def_static("create",
...@@ -190,9 +189,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -190,9 +189,9 @@ All parameter, weight, gradient are variables in Paddle.
}); });
operator_base.def("backward", operator_base.def("backward",
[](const pd::OperatorBase &forwardOp, [](const OperatorBase &forwardOp,
const std::unordered_set<std::string> &no_grad_vars) { const std::unordered_set<std::string> &no_grad_vars) {
return pd::Backward(forwardOp, no_grad_vars); return Backward(forwardOp, no_grad_vars);
}); });
ExposeOperator(operator_base); ExposeOperator(operator_base);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册