提交 e2ba1337，作者：Q qijun

enable operator gpu unittest

上级 72b5bd93
......@@ -137,6 +137,8 @@ class Tensor {
// Returns the tensor's shape; a reference, so no DDim copy is made.
const DDim& dims() const { return dims_; }
// Returns the device (Place) of the underlying allocation.
// NOTE(review): dereferences holder_ unconditionally -- calling place() on a
// never-allocated Tensor would be a null dereference; confirm callers
// always allocate first.
paddle::platform::Place place() const { return holder_->place(); }
private:
// Placeholder hides type T, so it doesn't appear as a template
// parameter of Variable.
......
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/place.h"
#include "paddle/pybind/tensor_bind.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
......@@ -62,12 +63,12 @@ PYBIND11_PLUGIN(core) {
self.Resize(pd::make_ddim(dim));
})
.def("alloc_float",
[](pd::Tensor& self) {
self.mutable_data<float>(paddle::platform::CPUPlace());
[](pd::Tensor& self, paddle::platform::Place& place) {
self.mutable_data<float>(place);
})
.def("alloc_int",
[](pd::Tensor& self) {
self.mutable_data<int>(paddle::platform::CPUPlace());
[](pd::Tensor& self, paddle::platform::Place& place) {
self.mutable_data<int>(place);
})
.def("set", paddle::pybind::PyTensorSetFromArray<float>)
.def("set", paddle::pybind::PyTensorSetFromArray<int>)
......@@ -122,9 +123,20 @@ All parameter, weight, gradient are variables in Paddle.
.def("temp", pd::OperatorBase::TMP_VAR_NAME);
// DeviceContext is exposed to Python only through a factory: the backend
// (CPU vs GPU) is chosen at runtime from the given Place. This replaces the
// removed "cpu_context" static, whose leftover lines had broken the chain
// (a stray `});` before `.def_static`).
py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
    .def_static(
        "create",
        // The parameter must be named: the original lambda declared an
        // anonymous `paddle::platform::Place` yet used `place` in the body.
        [](paddle::platform::Place& place)
            -> paddle::platform::DeviceContext* {
          if (paddle::platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
            // Keep CPU-only builds compiling/linking: no GPU context exists.
            PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
            return new paddle::platform::GPUDeviceContext(place);
#endif
          } else if (paddle::platform::is_cpu_place(place)) {
            return new paddle::platform::CPUDeviceContext();
          }
          // Never fall off the end of a value-returning lambda (UB).
          PADDLE_THROW("Unsupported Place for DeviceContext::create");
        });
// Bind the concrete place types. The original registered
// paddle::platform::Place twice (as both "GPUPlace" and "CPUPlace"), which
// pybind11 rejects at import time, and left an orphan `.def(py::init<>());`
// after a statement-ending semicolon.
py::class_<paddle::platform::GPUPlace>(m, "GPUPlace")
    .def(py::init<int>())
    .def(py::init<>());
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator");
......
......@@ -13,9 +13,10 @@
limitations under the License. */
#pragma once
#include <paddle/framework/tensor.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include "paddle/framework/tensor.h"
#include "paddle/memory/memcpy.h"
#include "pybind11/numpy.h"
#include "pybind11/pybind11.h"
namespace py = pybind11;
......@@ -56,7 +57,6 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
strides[i - 1] = sizeof(CUR_TYPE) * prod;
prod *= dims_outside[i - 1];
}
return py::buffer_info(
tensor.mutable_data<CUR_TYPE>(tensor.holder_->place()),
sizeof(CUR_TYPE),
......@@ -87,8 +87,25 @@ void PyTensorSetFromArray(
}
// Reshape the destination tensor to match the incoming numpy array.
self.Resize(framework::make_ddim(dims));
// NOTE(review): the next two lines are the pre-change CPU-only path, left
// fused with the new code below by the diff rendering -- `dst` ends up
// declared twice; only the Place-aware path below should survive.
auto *dst = self.mutable_data<T>(paddle::platform::CPUPlace());
std::memcpy(dst, array.data(), sizeof(T) * array.size());
// Allocate on the tensor's current place, then copy the host data in.
auto *dst = self.mutable_data<T>(self.place());
// NOTE(review): `place` is not declared anywhere in this hunk (the function
// signature is outside the visible lines); presumably it is a
// paddle::platform::Place parameter -- confirm, otherwise this cannot
// compile. Also note the mix of `self.place()` and `place`: verify they are
// intended to be the same device.
if (paddle::platform::is_cpu_place(self.place())) {
// CPU -> CPU: plain host copy of the numpy buffer into the tensor.
paddle::memory::Copy<paddle::platform::CPUPlace,
paddle::platform::CPUPlace>(
place, dst, place, array.data(), sizeof(T) * array.size());
} else if (paddle::platform::is_gpu_place(place)) {
#ifdef PADDLE_ONLY_CPU
// CPU-only builds have no device memcpy; fail loudly rather than silently.
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#else
// Host -> device: source is the numpy buffer (CPU), destination is GPU.
paddle::memory::Copy<paddle::platform::GPUPlace,
paddle::platform::CPUPlace>(
place,
dst,
paddle::platform::CPUPlace(),
array.data(),
sizeof(T) * array.size());
#endif
}
}
} // namespace pybind
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册