From d0dbc0610fd41d10ebb5abc133b25976e53484db Mon Sep 17 00:00:00 2001 From: dangqingqing Date: Mon, 11 Sep 2017 12:10:01 +0800 Subject: [PATCH] Correctly use host_vector in LoDTensor and expose LoDTensor to Python. --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/lod_tensor.h | 17 ++++- paddle/framework/lod_tensor_test.cu | 52 +++++++++++++++ paddle/operators/math/im2col_test.cc | 4 +- paddle/pybind/pybind.cc | 43 ++++++++++--- .../paddle/v2/framework/tests/test_tensor.py | 63 ++++++++++++++++++- 6 files changed, 164 insertions(+), 16 deletions(-) create mode 100644 paddle/framework/lod_tensor_test.cu diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index c0838d9b75..3371962c63 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -9,6 +9,7 @@ cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor) +nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor) cc_test(variable_test SRCS variable_test.cc) diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 154068fef6..bbddd6de9d 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -18,8 +18,10 @@ #ifndef PADDLE_ONLY_CPU #include #include +#include #endif +#include #include "paddle/framework/ddim.h" #include "paddle/framework/tensor.h" #include "paddle/platform/enforce.h" @@ -32,7 +34,8 @@ template using Vector = std::vector; #else template -using Vector = thrust::host_vector; +using Vector = thrust::host_vector< + T, thrust::system::cuda::experimental::pinned_allocator>; #endif using LoD = std::vector>; @@ -53,7 +56,17 @@ class LoDTensor { LoDTensor() {} LoDTensor(const LoD& lod, Tensor* t) : lod_(lod), tensor_(t) {} - void set_lod(const LoD& lod) { lod_ = lod; } + void set_lod(const LoD& lod) { + lod_ = lod; + LOG(INFO) << lod_[0][0]; + } + +#ifdef PADDLE_ONLY_CPU + void set_lod(const std::vector>& lod) { + lod_ = lod; + LOG(INFO) << lod_[0][0]; + } +#endif void set_tensor(Tensor* tensor) { tensor_ = tensor; } diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu new file mode 100644 index 0000000000..1079a36a2e --- /dev/null +++ b/paddle/framework/lod_tensor_test.cu @@ -0,0 +1,52 @@ +/* + Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +#include +#include +#include "paddle/framework/lod_tensor.h" +#include "paddle/platform/assert.h" + +#include + +__global__ void test(size_t* a, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += blockDim.x * gridDim.x) { + a[i] *= 2; + } +} + +TEST(LoDTensor, LoDInGPU) { + paddle::framework::Tensor tensor; + paddle::framework::LoDTensor lod_tensor; + paddle::platform::GPUPlace place(0); + + paddle::framework::LoD src_lod; + src_lod.push_back(std::vector{0, 2, 4, 6, 8, 10, 12, 14}); + + tensor.Resize({14, 16}); + tensor.mutable_data(place); + + lod_tensor.set_lod(src_lod); + lod_tensor.set_tensor(&tensor); + CHECK_EQ(lod_tensor.lod_element(0, 2), 4); + CHECK_EQ(lod_tensor.lod_element(0, 4), 8); + + auto lod = lod_tensor.lod(); + + test<<<1, 8>>>(lod[0].data(), lod[0].size()); + cudaDeviceSynchronize(); + + for (size_t i = 0; i < src_lod[0].size(); ++i) { + CHECK_EQ(lod[0].data()[i], src_lod[0].data()[i] * 2); + } +} diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index ee5fb98acd..f905600bb3 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -71,8 +71,10 @@ void testIm2col() { context = new paddle::platform::CPUDeviceContext(paddle::platform::CPUPlace()); } else { +#ifndef PADDLE_ONLY_CPU context = new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace()); +#endif } im2col(input, output_cfo, stride, stride, padding, padding, context); im2col_ocf(input, output_ocf, stride, stride, padding, padding, context); @@ -115,4 +117,4 @@ TEST(math, im2col) { #ifndef PADDLE_ONLY_CPU testIm2col(); #endif -} \ No newline at end of file +} diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 30189d538b..73fb7186ae 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -56,7 +56,8 @@ namespace paddle { namespace framework { using Tensor = framework::Tensor; -using LODTensor = framework::LODTensor; +using LoDTensor = framework::LoDTensor; +using LoD = framework::LoD; static size_t UniqueIntegerGenerator() { static std::atomic generator; @@ -116,23 +117,45 @@ PYBIND11_PLUGIN(core) { return self.data()[offset]; }); - py::class_(m, "LODTensor", R"DOC(LOD(Leval of Ddetails) Tensor. + py::class_(m, "LoDTensor", R"DOC(LoD(Leval of Ddetails) Tensor. -The tensor and LOD info should be created before creating the LODTensor, then +The tensor and LoD info should be created before creating the LoDTensor, then call the set_tensor and set_lod functions to set them. )DOC") .def("set_tensor", - [](LODTensor &self, Tensor *tensor) { self.set_tensor(tensor); }) + [](LoDTensor &self, Tensor *tensor) { self.set_tensor(tensor); }) .def("set_lod", - [](LODTensor &self, std::vector> &lod) { + [](LoDTensor &self, std::vector> &lod) { +#ifdef PADDLE_ONLY_CPU self.set_lod(lod); +#else + paddle::framework::LoD new_lod; + new_lod.reserve(lod.size()); + std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod)); + self.set_lod(new_lod); +#endif }) - .def("get_tensor", - [](LODTensor &self) -> Tensor & { return self.tensor(); }, + .def("tensor", + [](LoDTensor &self) -> Tensor & { return self.tensor(); }, py::return_value_policy::reference) - .def("get_lod", [](LODTensor &self) -> std::vector> { + .def("lod", [](LoDTensor &self) -> std::vector> { +#ifdef PADDLE_ONLY_CPU return self.lod(); +#else + auto lod = self.lod(); + std::vector> new_lod; + new_lod.reserve(lod.size()); + std::transform(lod.begin(), lod.end(), std::back_inserter(new_lod), + [](paddle::framework::Vector item) -> + std::vector { + std::vector v; + v.reserve(item.size()); + std::copy(item.begin(), item.end(), std::back_inserter(v)); + return v; + }); + return new_lod; +#endif }); py::class_(m, "Variable", R"DOC(Variable Class. @@ -147,8 +170,8 @@ All parameter, weight, gradient are variables in Paddle. [](Variable &self) -> Tensor * { return self.GetMutable(); }, py::return_value_policy::reference) .def("get_lod_tensor", - [](Variable &self) -> LODTensor * { - return self.GetMutable(); + [](Variable &self) -> LoDTensor * { + return self.GetMutable(); }, py::return_value_policy::reference) .def("get_net", diff --git a/python/paddle/v2/framework/tests/test_tensor.py b/python/paddle/v2/framework/tests/test_tensor.py index 1af39818a3..1bfe1370e2 100644 --- a/python/paddle/v2/framework/tests/test_tensor.py +++ b/python/paddle/v2/framework/tests/test_tensor.py @@ -3,7 +3,7 @@ import unittest import numpy -class TestScope(unittest.TestCase): +class TestTensor(unittest.TestCase): def test_int_tensor(self): scope = core.Scope() var = scope.new_var("test_tensor") @@ -20,8 +20,8 @@ class TestScope(unittest.TestCase): tensor.set(tensor_array, place) tensor_array_2 = numpy.array(tensor) - self.assertEqual(1.0, tensor_array_2[3, 9]) - self.assertEqual(2.0, tensor_array_2[19, 11]) + self.assertEqual(1, tensor_array_2[3, 9]) + self.assertEqual(2, tensor_array_2[19, 11]) def test_float_tensor(self): scope = core.Scope() @@ -43,6 +43,63 @@ class TestScope(unittest.TestCase): self.assertAlmostEqual(1.0, tensor_array_2[3, 9]) self.assertAlmostEqual(2.0, tensor_array_2[19, 11]) + def test_int_lod_tensor(self): + scope = core.Scope() + var = scope.new_var("test_tensor") + var_lod = scope.new_var("test_lod_tensor") + place = core.CPUPlace() + + tensor = var.get_tensor() + lod_tensor = var_lod.get_lod_tensor() + + tensor.set_dims([4, 4, 6]) + tensor.alloc_int(place) + array = numpy.array(tensor) + array[0, 0, 0] = 3 + array[3, 3, 5] = 10 + tensor.set(array, place) + + lod_tensor.set_tensor(tensor) + lod_tensor.set_lod([[0, 2, 4]]) + + lod_v = numpy.array(lod_tensor.tensor()) + self.assertTrue(numpy.alltrue(array == lod_v)) + + lod = lod_tensor.lod() + self.assertEqual(0, lod[0][0]) + self.assertEqual(2, lod[0][1]) + self.assertEqual(4, lod[0][2]) + + def test_float_lod_tensor(self): + scope = core.Scope() + var = scope.new_var("test_tensor") + var_lod = scope.new_var("test_lod_tensor") + place = core.CPUPlace() + + tensor = var.get_tensor() + lod_tensor = var_lod.get_lod_tensor() + + tensor.set_dims([5, 2, 3, 4]) + tensor.alloc_float(place) + + tensor_array = numpy.array(tensor) + self.assertEqual((5, 2, 3, 4), tensor_array.shape) + tensor_array[0, 0, 0, 0] = 1.0 + tensor_array[0, 0, 0, 1] = 2.0 + tensor.set(tensor_array, place) + + lod_tensor.set_tensor(tensor) + + lod_v = numpy.array(lod_tensor.tensor()) + self.assertAlmostEqual(1.0, lod_v[0, 0, 0, 0]) + self.assertAlmostEqual(2.0, lod_v[0, 0, 0, 1]) + self.assertEqual(len(lod_tensor.lod()), 0) + + lod_py = [[0, 2, 5], [0, 2, 4, 5]] + lod_tensor.set_lod(lod_py) + lod = lod_tensor.lod() + self.assertListEqual(lod_py, lod) + if __name__ == '__main__': unittest.main() -- GitLab