提交 65e5aebd 编写于 作者: Q qiaolongfei

fix mixed_vector_test

上级 da035fc6
......@@ -23,7 +23,7 @@ endif()
cc_test(eigen_test SRCS eigen_test.cc DEPS tensor)
if(WITH_GPU)
nv_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
nv_test(mixed_vector_test SRCS mixed_vector_test.cc mixed_vector_test.cu DEPS place memory device_context tensor)
else()
cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif()
......
......@@ -12,18 +12,11 @@
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
#endif
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/gpu_info.h"
template <typename T>
using vec = paddle::framework::Vector<T>;
......@@ -77,58 +70,3 @@ TEST(mixed_vector, Resize) {
vec.push_back(0);
vec.push_back(0);
}
#ifdef PADDLE_WITH_CUDA
static __global__ void multiply_10(int* ptr) {
for (int i = 0; i < 10; ++i) {
ptr[i] *= 10;
}
}
cudaStream_t GetCUDAStream(paddle::platform::CUDAPlace place) {
return reinterpret_cast<const paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->stream();
}
TEST(mixed_vector, GPU_VECTOR) {
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu)>>>(tmp.MutableData(gpu));
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp[i], i * 10);
}
}
TEST(mixed_vector, MultiGPU) {
if (paddle::platform::GetCUDADeviceCount() < 2) {
LOG(WARNING) << "Skip mixed_vector.MultiGPU since there are not multiple "
"GPUs in your machine.";
return;
}
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu0(0);
paddle::platform::SetDeviceId(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0));
paddle::platform::CUDAPlace gpu1(1);
auto* gpu1_ptr = tmp.MutableData(gpu1);
paddle::platform::SetDeviceId(1);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu1)>>>(gpu1_ptr);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp[i], i * 100);
}
}
#endif
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda_runtime.h>
#include <memory>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include "paddle/fluid/framework/mixed_vector.h"
#include "paddle/fluid/platform/gpu_info.h"
template <typename T>
using vec = paddle::framework::Vector<T>;
static __global__ void multiply_10(int* ptr) {
for (int i = 0; i < 10; ++i) {
ptr[i] *= 10;
}
}
cudaStream_t GetCUDAStream(paddle::platform::CUDAPlace place) {
return reinterpret_cast<const paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(place))
->stream();
}
TEST(mixed_vector, GPU_VECTOR) {
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu)>>>(tmp.MutableData(gpu));
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp[i], i * 10);
}
}
TEST(mixed_vector, MultiGPU) {
if (paddle::platform::GetCUDADeviceCount() < 2) {
LOG(WARNING) << "Skip mixed_vector.MultiGPU since there are not multiple "
"GPUs in your machine.";
return;
}
vec<int> tmp;
for (int i = 0; i < 10; ++i) {
tmp.push_back(i);
}
ASSERT_EQ(tmp.size(), 10UL);
paddle::platform::CUDAPlace gpu0(0);
paddle::platform::SetDeviceId(0);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu0)>>>(tmp.MutableData(gpu0));
paddle::platform::CUDAPlace gpu1(1);
auto* gpu1_ptr = tmp.MutableData(gpu1);
paddle::platform::SetDeviceId(1);
multiply_10<<<1, 1, 0, GetCUDAStream(gpu1)>>>(gpu1_ptr);
for (int i = 0; i < 10; ++i) {
ASSERT_EQ(tmp[i], i * 100);
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册