From 07dd3d25b39878b6ccc4736e189c015cfd2265d2 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 5 Feb 2018 01:53:43 -0800 Subject: [PATCH] "fix const warning" --- paddle/framework/CMakeLists.txt | 1 + paddle/framework/lod_tensor_test.cu | 22 -------- paddle/framework/mixed_vector_test.cu | 72 +++++++++++++++++++++++++++ 3 files changed, 73 insertions(+), 22 deletions(-) create mode 100644 paddle/framework/mixed_vector_test.cu diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 8b71f73c36..7c4ba3afb9 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -20,6 +20,7 @@ endif() cc_test(eigen_test SRCS eigen_test.cc DEPS tensor) +nv_test(mixed_vector_test SRCS mixed_vector_test.cu DEPS place paddle_memory device_context init) cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto) cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor paddle_memory) nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor init) diff --git a/paddle/framework/lod_tensor_test.cu b/paddle/framework/lod_tensor_test.cu index d4c9f00bd9..adea02e3b3 100644 --- a/paddle/framework/lod_tensor_test.cu +++ b/paddle/framework/lod_tensor_test.cu @@ -28,28 +28,6 @@ __global__ void test(size_t* a, int size) { } } -TEST(Vector, Normal) { - using namespace paddle::framework; - using namespace paddle::platform; - using namespace paddle::memory; - - paddle::framework::InitDevices(); - - paddle::framework::Vector vec({1, 2, 3}); - size_t* ptr = vec.data(); - for (size_t i = 0; i < vec.size(); ++i) { - EXPECT_EQ(vec[i], *(ptr + i)); - } - - vec.clear(); - vec.CopyFromCUDA(); - - std::vector v = {1, 2, 3}; - for (size_t i = 0; i < v.size(); ++i) { - EXPECT_EQ(v[i], vec[i]); - } -} - TEST(LoD, data) { paddle::framework::InitDevices(); diff --git a/paddle/framework/mixed_vector_test.cu b/paddle/framework/mixed_vector_test.cu new file mode 100644 index 0000000000..7b571788ad --- /dev/null +++ b/paddle/framework/mixed_vector_test.cu @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ +#include +#include +#include "gtest/gtest.h" + +#include "paddle/framework/init.h" +#include "paddle/framework/mixed_vector.h" + +using namespace paddle::framework; +using namespace paddle::platform; +using namespace paddle::memory; + +template +__global__ void test(T* data, int size) { + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size; + i += blockDim.x * gridDim.x) { + data[i] *= 2; + } +} + +TEST(Vector, Normal) { + // fill the device context pool. + InitDevices(); + + Vector vec({1, 2, 3}); + size_t* ptr = vec.data(); + for (size_t i = 0; i < vec.size(); ++i) { + EXPECT_EQ(vec[i], *(ptr + i)); + } + + vec.clear(); + vec.CopyFromCUDA(); + + std::vector v = {1, 2, 3}; + for (size_t i = 0; i < v.size(); ++i) { + EXPECT_EQ(v[i], vec[i]); + } +} + +TEST(Vector, MultipleCopy) { + InitDevices(); + Vector vec({1, 2, 3}); + CUDAPlace place(0); + vec.mutable_data(place); + auto vec2 = Vector(vec); + { + const size_t* ptr = vec2.data(CPUPlace()); + for (size_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(*(ptr + i), vec[i]); + } + } + test<<<3, 3>>>(vec2.mutable_data(place), vec2.size()); + vec2.CopyFromCUDA(); + { + const size_t* ptr = vec2.data(CPUPlace()); + for (size_t i = 0; i < vec2.size(); ++i) { + EXPECT_EQ(*(ptr + i), vec[i] * 2); + } + } +} -- GitLab