提交 e2ba1337 编写于 作者: Q qijun

enable operator gpu unittest

上级 72b5bd93
...@@ -137,6 +137,8 @@ class Tensor { ...@@ -137,6 +137,8 @@ class Tensor {
const DDim& dims() const { return dims_; } const DDim& dims() const { return dims_; }
paddle::platform::Place place() const { return holder_->place(); }
private: private:
// Placeholder hides type T, so it doesn't appear as a template // Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable. // parameter of Variable.
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h" #include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h" #include "pybind11/numpy.h"
#include "pybind11/pybind11.h" #include "pybind11/pybind11.h"
...@@ -62,12 +63,12 @@ PYBIND11_PLUGIN(core) { ...@@ -62,12 +63,12 @@ PYBIND11_PLUGIN(core) {
self.Resize(pd::make_ddim(dim)); self.Resize(pd::make_ddim(dim));
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::Place& place) {
self.mutable_data<float>(paddle::platform::CPUPlace()); self.mutable_data<float>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::Place& place) {
self.mutable_data<int>(paddle::platform::CPUPlace()); self.mutable_data<int>(place);
}) })
.def("set", paddle::pybind::PyTensorSetFromArray<float>) .def("set", paddle::pybind::PyTensorSetFromArray<float>)
.def("set", paddle::pybind::PyTensorSetFromArray<int>) .def("set", paddle::pybind::PyTensorSetFromArray<int>)
...@@ -122,9 +123,20 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -122,9 +123,20 @@ All parameter, weight, gradient are variables in Paddle.
.def("temp", pd::OperatorBase::TMP_VAR_NAME); .def("temp", pd::OperatorBase::TMP_VAR_NAME);
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext") py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
.def_static("cpu_context", []() -> paddle::platform::DeviceContext* { .def_static(
return new paddle::platform::CPUDeviceContext(); "create",
}); [](paddle::platform::Place) -> paddle::platform::DeviceContext* {
if (paddle::platform::is_gpu_place(place)) {
return new paddle::platform::GPUDeviceContext(place);
} else if (paddle::platform::is_cpu_place(place)) {
return new paddle::platform::CPUDeviceContext();
}
});
py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>());
.def(py::init<>());
py::class_<paddle::platform::Place>(m, "CPUPlace").def(py::init<>());
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base( py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator"); m, "Operator");
......
...@@ -13,9 +13,10 @@ ...@@ -13,9 +13,10 @@
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <paddle/framework/tensor.h> #include "paddle/framework/tensor.h"
#include <pybind11/numpy.h> #include "paddle/memory/memcpy.h"
#include <pybind11/pybind11.h> #include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace py = pybind11; namespace py = pybind11;
...@@ -56,7 +57,6 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -56,7 +57,6 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod; strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1]; prod *= dims_outside[i - 1];
} }
return py::buffer_info( return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()), tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
sizeof(CUR_TYPE), sizeof(CUR_TYPE),
...@@ -87,8 +87,25 @@ void PyTensorSetFromArray( ...@@ -87,8 +87,25 @@ void PyTensorSetFromArray(
} }
self.Resize(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace()); auto *dst = self.mutable_data<T>(self.place());
std::memcpy(dst, array.data(), sizeof(T) * array.size());
if (paddle::platform::is_cpu_place(self.place())) {
paddle::memory::Copy<paddle::platform::CPUPlace,
paddle::platform::CPUPlace>(
place, dst, place, array.data(), sizeof(T) * array.size());
} else if (paddle::platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
paddle::memory::Copy<paddle::platform::GPUPlace,
paddle::platform::CPUPlace>(
place,
dst,
paddle::platform::CPUPlace(),
array.data(),
sizeof(T) * array.size());
#endif
}
} }
} // namespace pybind } // namespace pybind
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册