提交 358261f0 编写于 作者: Q qijun

fix gpu build error

上级 2ddef137
...@@ -64,23 +64,25 @@ PYBIND11_PLUGIN(core) { ...@@ -64,23 +64,25 @@ PYBIND11_PLUGIN(core) {
self.Resize(pd::make_ddim(dim)); self.Resize(pd::make_ddim(dim));
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor& self, paddle::platform::Place& place) { [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<float>(place); self.mutable_data<float>(place);
}) })
.def("alloc_float", .def("alloc_float",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<float>(paddle::platform::CPUPlace()); self.mutable_data<float>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor& self, paddle::platform::Place& place) { [](pd::Tensor& self, paddle::platform::CPUPlace& place) {
self.mutable_data<int>(place); self.mutable_data<int>(place);
}) })
.def("alloc_int", .def("alloc_int",
[](pd::Tensor& self) { [](pd::Tensor& self, paddle::platform::GPUPlace& place) {
self.mutable_data<int>(paddle::platform::CPUPlace()); self.mutable_data<int>(place);
}) })
.def("set", paddle::pybind::PyTensorSetFromArray<float>) .def("set", paddle::pybind::PyCPUTensorSetFromArray<float>)
.def("set", paddle::pybind::PyTensorSetFromArray<int>) .def("set", paddle::pybind::PyCUDATensorSetFromArray<float>)
.def("set", paddle::pybind::PyCPUTensorSetFromArray<int>)
.def("set", paddle::pybind::PyCUDATensorSetFromArray<int>)
.def("shape", .def("shape",
[](pd::Tensor& self) { return pd::vectorize(self.dims()); }); [](pd::Tensor& self) { return pd::vectorize(self.dims()); });
...@@ -144,9 +146,9 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -144,9 +146,9 @@ All parameter, weight, gradient are variables in Paddle.
}) })
#endif #endif
; // NOLINT ; // NOLINT
py::class_<paddle::platform::Place>(m, "GPUPlace").def(py::init<int>()); py::class_<paddle::platform::GPUPlace>(m, "GPUPlace").def(py::init<int>());
py::class_<paddle::platform::Place>(m, "CPUPlace").def(py::init<>()); py::class_<paddle::platform::CPUPlace>(m, "CPUPlace").def(py::init<>());
py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base( py::class_<pd::OperatorBase, std::shared_ptr<pd::OperatorBase>> operator_base(
m, "Operator"); m, "Operator");
......
...@@ -61,7 +61,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> { ...@@ -61,7 +61,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
framework::Tensor dst_tensor; framework::Tensor dst_tensor;
if (paddle::platform::is_gpu_place(tensor.holder_->place())) { if (paddle::platform::is_gpu_place(tensor.holder_->place())) {
dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace()); dst_tensor.CopyFrom<CUR_TYPE>(tensor, platform::CPUPlace());
} else if (paddle::platform::is_gpu_place(tensor.holder_->place())) { } else if (paddle::platform::is_cpu_place(tensor.holder_->place())) {
dst_tensor = tensor; dst_tensor = tensor;
} }
return py::buffer_info( return py::buffer_info(
...@@ -84,9 +84,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) { ...@@ -84,9 +84,10 @@ inline py::buffer_info CastToPyBuffer(framework::Tensor &tensor) {
} }
template <typename T> template <typename T>
void PyTensorSetFromArray( void PyCPUTensorSetFromArray(
framework::Tensor &self, framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array) { py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::CPUPlace &place) {
std::vector<int> dims; std::vector<int> dims;
dims.reserve(array.ndim()); dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) { for (size_t i = 0; i < array.ndim(); ++i) {
...@@ -94,18 +95,26 @@ void PyTensorSetFromArray( ...@@ -94,18 +95,26 @@ void PyTensorSetFromArray(
} }
self.Resize(framework::make_ddim(dims)); self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(self.place()); auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size());
if (paddle::platform::is_cpu_place(self.place())) { }
std::memcpy(dst, array.data(), sizeof(T) * array.size());
} else if (paddle::platform::is_gpu_place(self.place())) { template <typename T>
#ifdef PADDLE_ONLY_CPU void PyCUDATensorSetFromArray(
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); framework::Tensor &self,
#else py::array_t<T, py::array::c_style | py::array::forcecast> array,
platform::GpuMemcpySync( paddle::platform::GPUPlace &place) {
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice); std::vector<int> dims;
#endif dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
dims.push_back((int)array.shape()[i]);
} }
self.Resize(framework::make_ddim(dims));
auto *dst = self.mutable_data<T>(place);
std::memcpy(dst, array.data(), sizeof(T) * array.size());
paddle::platform::GpuMemcpySync(
dst, array.data(), sizeof(T) * array.size(), cudaMemcpyHostToDevice);
} }
} // namespace pybind } // namespace pybind
......
...@@ -25,6 +25,7 @@ class OpTestMeta(type): ...@@ -25,6 +25,7 @@ class OpTestMeta(type):
self.assertIsNotNone(func) self.assertIsNotNone(func)
scope = core.Scope(None) scope = core.Scope(None)
place = core.CPUPlace()
kwargs = dict() kwargs = dict()
for in_name in func.all_input_args: for in_name in func.all_input_args:
...@@ -33,7 +34,7 @@ class OpTestMeta(type): ...@@ -33,7 +34,7 @@ class OpTestMeta(type):
var = scope.create_var(in_name).get_tensor() var = scope.create_var(in_name).get_tensor()
arr = getattr(self, in_name) arr = getattr(self, in_name)
var.set_dims(arr.shape) var.set_dims(arr.shape)
var.set(arr) var.set(arr, place)
else: else:
kwargs[in_name] = "@EMPTY@" kwargs[in_name] = "@EMPTY@"
......
...@@ -7,17 +7,18 @@ import paddle.v2.framework.create_op_creation_methods as creation ...@@ -7,17 +7,18 @@ import paddle.v2.framework.create_op_creation_methods as creation
class TestFc(unittest.TestCase): class TestFc(unittest.TestCase):
def test_fc(self): def test_fc(self):
scope = core.Scope(None) scope = core.Scope(None)
place = core.CPUPlace()
x = scope.create_var("X") x = scope.create_var("X")
x_tensor = x.get_tensor() x_tensor = x.get_tensor()
x_tensor.set_dims([1000, 784]) x_tensor.set_dims([1000, 784])
x_tensor.alloc_float() x_tensor.alloc_float(place)
w = scope.create_var("W") w = scope.create_var("W")
w_tensor = w.get_tensor() w_tensor = w.get_tensor()
w_tensor.set_dims([784, 100]) w_tensor.set_dims([784, 100])
w_tensor.alloc_float() w_tensor.alloc_float(place)
w_tensor.set(numpy.random.random((784, 100)).astype("float32")) w_tensor.set(numpy.random.random((784, 100)).astype("float32"), place)
# Set a real numpy array here. # Set a real numpy array here.
# x_tensor.set(numpy.array([])) # x_tensor.set(numpy.array([]))
......
...@@ -7,16 +7,16 @@ class TestScope(unittest.TestCase): ...@@ -7,16 +7,16 @@ class TestScope(unittest.TestCase):
def test_int_tensor(self): def test_int_tensor(self):
scope = core.Scope(None) scope = core.Scope(None)
var = scope.create_var("test_tensor") var = scope.create_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
tensor.alloc_int() tensor.alloc_int(place)
tensor_array = numpy.array(tensor) tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape) self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1 tensor_array[3, 9] = 1
tensor_array[19, 11] = 2 tensor_array[19, 11] = 2
tensor.set(tensor_array) tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor) tensor_array_2 = numpy.array(tensor)
self.assertEqual(1.0, tensor_array_2[3, 9]) self.assertEqual(1.0, tensor_array_2[3, 9])
...@@ -25,16 +25,17 @@ class TestScope(unittest.TestCase): ...@@ -25,16 +25,17 @@ class TestScope(unittest.TestCase):
def test_float_tensor(self): def test_float_tensor(self):
scope = core.Scope(None) scope = core.Scope(None)
var = scope.create_var("test_tensor") var = scope.create_var("test_tensor")
place = core.CPUPlace()
tensor = var.get_tensor() tensor = var.get_tensor()
tensor.set_dims([1000, 784]) tensor.set_dims([1000, 784])
tensor.alloc_float() tensor.alloc_float(place)
tensor_array = numpy.array(tensor) tensor_array = numpy.array(tensor)
self.assertEqual((1000, 784), tensor_array.shape) self.assertEqual((1000, 784), tensor_array.shape)
tensor_array[3, 9] = 1.0 tensor_array[3, 9] = 1.0
tensor_array[19, 11] = 2.0 tensor_array[19, 11] = 2.0
tensor.set(tensor_array) tensor.set(tensor_array, place)
tensor_array_2 = numpy.array(tensor) tensor_array_2 = numpy.array(tensor)
self.assertAlmostEqual(1.0, tensor_array_2[3, 9]) self.assertAlmostEqual(1.0, tensor_array_2[3, 9])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册