Commit d0dbc061 authored by dangqingqing

Correctly use host_vector in LoDTensor and expose LoDTensor to Python.

Parent 5f905248
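A minimal sketch of how the newly exposed bindings fit together, mirroring the unit test added at the bottom of this commit: create and fill the Tensor first, then attach it and the LoD offsets to the LoDTensor. The import path of the compiled `core` module is an assumption; the diff does not show it.

    import numpy
    import paddle.v2.framework.core as core  # assumed import path

    scope = core.Scope()
    place = core.CPUPlace()

    # Create and fill the underlying tensor first.
    var = scope.new_var("test_tensor")
    tensor = var.get_tensor()
    tensor.set_dims([4, 4, 6])
    tensor.alloc_float(place)
    tensor.set(numpy.ones((4, 4, 6), dtype=numpy.float32), place)

    # Then wrap it and attach the LoD (level-of-detail) offsets.
    var_lod = scope.new_var("test_lod_tensor")
    lod_tensor = var_lod.get_lod_tensor()
    lod_tensor.set_tensor(tensor)
    lod_tensor.set_lod([[0, 2, 4]])

    assert lod_tensor.lod() == [[0, 2, 4]]
    assert numpy.array(lod_tensor.tensor()).shape == (4, 4, 6)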
@@ -9,6 +9,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
 cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor)
 cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor)
+nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
 cc_test(variable_test SRCS variable_test.cc)
...
@@ -18,8 +18,10 @@
 #ifndef PADDLE_ONLY_CPU
 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
+#include <thrust/system/cuda/experimental/pinned_allocator.h>
 #endif
+#include <glog/logging.h>
 #include "paddle/framework/ddim.h"
 #include "paddle/framework/tensor.h"
 #include "paddle/platform/enforce.h"
@@ -32,7 +34,8 @@ template <typename T>
 using Vector = std::vector<T>;
 #else
 template <typename T>
-using Vector = thrust::host_vector<T>;
+using Vector = thrust::host_vector<
+    T, thrust::system::cuda::experimental::pinned_allocator<T>>;
 #endif
 using LoD = std::vector<Vector<size_t>>;
@@ -53,7 +56,17 @@ class LoDTensor {
   LoDTensor() {}
   LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {}
-  void set_lod(const LoD& lod) { lod_ = lod; }
+  void set_lod(const LoD& lod) {
+    lod_ = lod;
+    LOG(INFO) << lod_[0][0];
+  }
+
+#ifdef PADDLE_ONLY_CPU
+  void set_lod(const std::vector<std::vector<size_t>>& lod) {
+    lod_ = lod;
+    LOG(INFO) << lod_[0][0];
+  }
+#endif
   void set_tensor(Tensor* tensor) { tensor_ = tensor; }
...
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License. */

#include <cuda.h>
#include <cuda_runtime.h>
#include "paddle/framework/lod_tensor.h"
#include "paddle/platform/assert.h"
#include <gtest/gtest.h>
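// The kernel below doubles every element of `a` in place. The test launches
// it on the LoD's pinned host buffer to verify that the device can read and
// write that memory directly.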
__global__ void test(size_t* a, int size) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    a[i] *= 2;
  }
}

TEST(LoDTensor, LoDInGPU) {
  paddle::framework::Tensor tensor;
  paddle::framework::LoDTensor lod_tensor;
  paddle::platform::GPUPlace place(0);

  paddle::framework::LoD src_lod;
  src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});

  tensor.Resize({14, 16});
  tensor.mutable_data<float>(place);

  lod_tensor.set_lod(src_lod);
  lod_tensor.set_tensor(&tensor);

  CHECK_EQ(lod_tensor.lod_element(0, 2), 4);
  CHECK_EQ(lod_tensor.lod_element(0, 4), 8);

  auto lod = lod_tensor.lod();

  test<<<1, 8>>>(lod[0].data(), lod[0].size());
  cudaDeviceSynchronize();

  for (size_t i = 0; i < src_lod[0].size(); ++i) {
    CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2);
  }
}
@@ -71,8 +71,10 @@ void testIm2col() {
     context =
         new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace());
   } else {
+#ifndef PADDLE_ONLY_CPU
     context =
         new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
+#endif
   }
   im2col(input, output_cfo, stride, stride, padding, padding, context);
   im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
...
@@ -56,7 +56,8 @@ namespace paddle {
 namespace framework {
 using Tensor = framework::Tensor;
-using LODTensor = framework::LODTensor;
+using LoDTensor = framework::LoDTensor;
+using LoD = framework::LoD;
 static size_t UniqueIntegerGenerator() {
   static std::atomic<size_t> generator;
@@ -116,23 +117,45 @@ PYBIND11_PLUGIN(core) {
         return self.data<float>()[offset];
       });
-  py::class_<LODTensor>(m, "LODTensor", R"DOC(LOD(Leval of Ddetails) Tensor.
-The tensor and LOD info should be created before creating the LODTensor, then
+  py::class_<LoDTensor>(m, "LoDTensor", R"DOC(LoD (Level of Details) Tensor.
+The tensor and LoD info should be created before creating the LoDTensor, then
 call the set_tensor and set_lod functions to set them.
 )DOC")
       .def("set_tensor",
-           [](LODTensor &self, Tensor *tensor) { self.set_tensor(tensor); })
+           [](LoDTensor &self, Tensor *tensor) { self.set_tensor(tensor); })
       .def("set_lod",
-           [](LODTensor &self, std::vector<std::vector<size_t>> &lod) {
+           [](LoDTensor &self, std::vector<std::vector<size_t>> &lod) {
+#ifdef PADDLE_ONLY_CPU
             self.set_lod(lod);
+#else
+            paddle::framework::LoD new_lod;
+            new_lod.reserve(lod.size());
+            std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
+            self.set_lod(new_lod);
+#endif
           })
-      .def("get_tensor",
-           [](LODTensor &self) -> Tensor & { return self.tensor(); },
+      .def("tensor",
+           [](LoDTensor &self) -> Tensor & { return self.tensor(); },
           py::return_value_policy::reference)
-      .def("get_lod", [](LODTensor &self) -> std::vector<std::vector<size_t>> {
+      .def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
+#ifdef PADDLE_ONLY_CPU
        return self.lod();
+#else
+        auto lod = self.lod();
+        std::vector<std::vector<size_t>> new_lod;
+        new_lod.reserve(lod.size());
+        std::transform(lod.begin(), lod.end(), std::back_inserter(new_lod),
+                       [](paddle::framework::Vector<size_t> item) ->
+                           std::vector<size_t> {
+                         std::vector<size_t> v;
+                         v.reserve(item.size());
+                         std::copy(item.begin(), item.end(),
+                                   std::back_inserter(v));
+                         return v;
+                       });
+        return new_lod;
+#endif
       });
   py::class_<Variable>(m, "Variable", R"DOC(Variable Class.
@@ -147,8 +170,8 @@ All parameter, weight, gradient are variables in Paddle.
           [](Variable &self) -> Tensor * { return self.GetMutable<Tensor>(); },
           py::return_value_policy::reference)
       .def("get_lod_tensor",
-           [](Variable &self) -> LODTensor * {
-             return self.GetMutable<LODTensor>();
+           [](Variable &self) -> LoDTensor * {
+             return self.GetMutable<LoDTensor>();
           },
           py::return_value_policy::reference)
       .def("get_net",
...
@@ -3,7 +3,7 @@ import unittest
 import numpy

-class TestScope(unittest.TestCase):
+class TestTensor(unittest.TestCase):
     def test_int_tensor(self):
         scope = core.Scope()
         var = scope.new_var("test_tensor")
@@ -20,8 +20,8 @@ class TestScope(unittest.TestCase):
         tensor.set(tensor_array, place)
         tensor_array_2 = numpy.array(tensor)
-        self.assertEqual(1.0, tensor_array_2[3, 9])
-        self.assertEqual(2.0, tensor_array_2[19, 11])
+        self.assertEqual(1, tensor_array_2[3, 9])
+        self.assertEqual(2, tensor_array_2[19, 11])

     def test_float_tensor(self):
         scope = core.Scope()
@@ -43,6 +43,63 @@ class TestScope(unittest.TestCase):
         self.assertAlmostEqual(1.0, tensor_array_2[3, 9])
         self.assertAlmostEqual(2.0, tensor_array_2[19, 11])
+
+    def test_int_lod_tensor(self):
+        scope = core.Scope()
+        var = scope.new_var("test_tensor")
+        var_lod = scope.new_var("test_lod_tensor")
+        place = core.CPUPlace()
+
+        tensor = var.get_tensor()
+        lod_tensor = var_lod.get_lod_tensor()
+
+        tensor.set_dims([4, 4, 6])
+        tensor.alloc_int(place)
+        array = numpy.array(tensor)
+        array[0, 0, 0] = 3
+        array[3, 3, 5] = 10
+        tensor.set(array, place)
+
+        lod_tensor.set_tensor(tensor)
+        lod_tensor.set_lod([[0, 2, 4]])
+
+        lod_v = numpy.array(lod_tensor.tensor())
+        self.assertTrue(numpy.alltrue(array == lod_v))
+
+        lod = lod_tensor.lod()
+        self.assertEqual(0, lod[0][0])
+        self.assertEqual(2, lod[0][1])
+        self.assertEqual(4, lod[0][2])
+
+    def test_float_lod_tensor(self):
+        scope = core.Scope()
+        var = scope.new_var("test_tensor")
+        var_lod = scope.new_var("test_lod_tensor")
+        place = core.CPUPlace()
+
+        tensor = var.get_tensor()
+        lod_tensor = var_lod.get_lod_tensor()
+
+        tensor.set_dims([5, 2, 3, 4])
+        tensor.alloc_float(place)
+
+        tensor_array = numpy.array(tensor)
+        self.assertEqual((5, 2, 3, 4), tensor_array.shape)
+        tensor_array[0, 0, 0, 0] = 1.0
+        tensor_array[0, 0, 0, 1] = 2.0
+        tensor.set(tensor_array, place)
+
+        lod_tensor.set_tensor(tensor)
+
+        lod_v = numpy.array(lod_tensor.tensor())
+        self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0])
+        self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1])
+        self.assertEqual(len(lod_tensor.lod()), 0)
+
+        lod_py = [[0, 2, 5], [0, 2, 4, 5]]
+        lod_tensor.set_lod(lod_py)
+        lod = lod_tensor.lod()
+        self.assertListEqual(lod_py, lod)
+
 if __name__ == '__main__':
     unittest.main()