Commit 45372842 authored by: Y Yang Yu

Merge branch 'develop' of github.com:baidu/Paddle into feature/optimize_adam_speed

......@@ -16,8 +16,6 @@ cmake_minimum_required(VERSION 3.0)
set(CMAKE_MODULE_PATH ${CMAKE_MODULE_PATH} "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
set(PADDLE_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(PADDLE_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
include(system)
......@@ -201,6 +199,10 @@ if(WITH_GOLANG)
endif(WITH_GOLANG)
set(PADDLE_PYTHON_BUILD_DIR "${CMAKE_CURRENT_BINARY_DIR}/python/build")
SET(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
SET(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
add_subdirectory(paddle)
if(WITH_PYTHON)
add_subdirectory(python)
......
......@@ -59,6 +59,7 @@ cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
cc_test(threadpool_test SRCS threadpool_test.cc)
cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
......
......@@ -54,7 +54,7 @@ bool InitDevices(const std::vector<std::string> &devices) {
#ifdef PADDLE_WITH_CUDA
auto pos = string::RFind(p, ':', string::Piece::npos);
auto number = device.substr(pos + 1);
places.emplace_back(platform::GPUPlace(std::stoi(number)));
places.emplace_back(platform::CUDAPlace(std::stoi(number)));
#else
LOG(WARNING)
<< "'GPU' is not supported, Please re-compile with WITH_GPU option";
......
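For reference, a minimal sketch of how the renamed place type flows through `InitDevices`, mirroring the test `main()` near the end of this diff; the header paths are assumed from the file layout shown here, and the device strings are illustrative.

```cpp
#include <string>
#include <vector>
#include "paddle/framework/init.h"   // InitDevices (path assumed)
#include "paddle/memory/memory.h"    // memory::Used (path assumed)
#include "paddle/platform/place.h"   // CPUPlace / CUDAPlace (path assumed)

int main() {
  // Touch the CPU allocator, then collect the devices to initialize.
  paddle::memory::Used(paddle::platform::CPUPlace());
  std::vector<std::string> devs = {"CPU"};
#ifdef PADDLE_WITH_CUDA
  // "GPU:0" is parsed by InitDevices into platform::CUDAPlace(0).
  paddle::memory::Used(paddle::platform::CUDAPlace(0));
  devs.push_back("GPU:0");
#endif
  paddle::framework::InitDevices(devs);
  return 0;
}
```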
......@@ -224,7 +224,7 @@ void SerializeToStream(std::ostream &os, const LoDTensor &tensor,
while (size != 0) {
size_t size_to_write = std::min(kBufSize, static_cast<size_t>(size));
memory::Copy(cpu, buf.get(),
boost::get<platform::GPUPlace>(tensor.place()),
boost::get<platform::CUDAPlace>(tensor.place()),
reinterpret_cast<const void *>(data), size_to_write,
gpu_dev_ctx.stream());
gpu_dev_ctx.Wait();
......
......@@ -27,7 +27,7 @@ __global__ void test(size_t* a, int size) {
TEST(LoDTensor, LoDInGPU) {
paddle::framework::LoDTensor lod_tensor;
paddle::platform::GPUPlace place(0);
paddle::platform::CUDAPlace place(0);
paddle::framework::LoD src_lod;
src_lod.push_back(std::vector<size_t>{0, 2, 4, 6, 8, 10, 12, 14});
......
......@@ -37,13 +37,13 @@ TEST(OpKernelType, Hash) {
using OpKernelType = paddle::framework::OpKernelType;
using DataType = paddle::framework::proto::DataType;
using CPUPlace = paddle::platform::CPUPlace;
using GPUPlace = paddle::platform::GPUPlace;
using CUDAPlace = paddle::platform::CUDAPlace;
using DataLayout = paddle::framework::DataLayout;
using LibraryType = paddle::framework::LibraryType;
OpKernelType op_kernel_type_1(DataType::FP32, CPUPlace(), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType op_kernel_type_2(DataType::FP32, GPUPlace(0), DataLayout::kNCHW,
OpKernelType op_kernel_type_2(DataType::FP32, CUDAPlace(0), DataLayout::kNCHW,
LibraryType::kCUDNN);
OpKernelType::Hash hasher;
......
......@@ -188,7 +188,7 @@ class OpKernelRegistrar : public Registrar {
}
#define REGISTER_OP_CUDA_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::GPUPlace, __VA_ARGS__)
REGISTER_OP_KERNEL(op_type, CUDA, ::paddle::platform::CUDAPlace, __VA_ARGS__)
#define REGISTER_OP_CPU_KERNEL(op_type, ...) \
REGISTER_OP_KERNEL(op_type, CPU, ::paddle::platform::CPUPlace, __VA_ARGS__)
......
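For illustration, the registration pattern under the updated macro, mirroring the `reshape` registration further down in this diff; `my_op` and `MyKernel` are placeholder names, and this fragment belongs in an op's `.cu` file.

```cpp
// Fragment for an op's .cu file; my_op and MyKernel are placeholders.
// REGISTER_OP_CUDA_KERNEL now expands to a registration on CUDAPlace.
REGISTER_OP_CUDA_KERNEL(
    my_op, paddle::operators::MyKernel<paddle::platform::CUDAPlace, float>);
```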
......@@ -71,7 +71,7 @@ private:
```
```c++
typedef boost::variant<GpuPlace, CpuPlace> Place;
typedef boost::variant<CUDAPlace, CpuPlace> Place;
typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>,
Dim<6>, Dim<7>, Dim<8>, Dim<9>> DDimVar;
typedef boost::variant<
......
......@@ -125,11 +125,11 @@ inline void* Tensor::mutable_data(platform::Place place, std::type_index type) {
boost::get<platform::CPUPlace>(place), size, type));
} else if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
}
#else
holder_.reset(new PlaceholderImpl<platform::GPUPlace>(
boost::get<platform::GPUPlace>(place), size, type));
holder_.reset(new PlaceholderImpl<platform::CUDAPlace>(
boost::get<platform::CUDAPlace>(place), size, type));
}
#endif
offset_ = 0;
......
......@@ -80,20 +80,20 @@ TEST(Tensor, MutableData) {
float* p1 = nullptr;
float* p2 = nullptr;
// initialization
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), GPUPlace());
p1 = src_tensor.mutable_data<float>(make_ddim({1, 2, 3}), CUDAPlace());
EXPECT_NE(p1, nullptr);
// set src_tensor a new dim with large size
// memory is supposed to be re-allocated
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), GPUPlace());
p2 = src_tensor.mutable_data<float>(make_ddim({3, 4}), CUDAPlace());
EXPECT_NE(p2, nullptr);
EXPECT_NE(p1, p2);
// set src_tensor a new dim with same size
// memory block is supposed to be unchanged
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), GPUPlace());
p1 = src_tensor.mutable_data<float>(make_ddim({2, 2, 3}), CUDAPlace());
EXPECT_EQ(p1, p2);
// set src_tensor a new dim with smaller size
// memory block is supposed to be unchanged
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), GPUPlace());
p2 = src_tensor.mutable_data<float>(make_ddim({2, 2}), CUDAPlace());
EXPECT_EQ(p1, p2);
}
#endif
......@@ -130,7 +130,7 @@ TEST(Tensor, ShareDataWith) {
{
Tensor src_tensor;
Tensor dst_tensor;
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), GPUPlace());
src_tensor.mutable_data<int>(make_ddim({2, 3, 4}), CUDAPlace());
dst_tensor.ShareDataWith(src_tensor);
ASSERT_EQ(src_tensor.data<int>(), dst_tensor.data<int>());
}
......@@ -166,7 +166,7 @@ TEST(Tensor, Slice) {
#ifdef PADDLE_WITH_CUDA
{
Tensor src_tensor;
src_tensor.mutable_data<double>(make_ddim({6, 9}), GPUPlace());
src_tensor.mutable_data<double>(make_ddim({6, 9}), CUDAPlace());
Tensor slice_tensor = src_tensor.Slice(2, 6);
DDim slice_dims = slice_tensor.dims();
ASSERT_EQ(arity(slice_dims), 2);
......@@ -176,11 +176,11 @@ TEST(Tensor, Slice) {
uintptr_t src_data_address =
reinterpret_cast<uintptr_t>(src_tensor.data<double>());
uintptr_t src_mutable_data_address = reinterpret_cast<uintptr_t>(
src_tensor.mutable_data<double>(src_tensor.dims(), GPUPlace()));
src_tensor.mutable_data<double>(src_tensor.dims(), CUDAPlace()));
uintptr_t slice_data_address =
reinterpret_cast<uintptr_t>(slice_tensor.data<double>());
uintptr_t slice_mutable_data_address = reinterpret_cast<uintptr_t>(
slice_tensor.mutable_data<double>(slice_tensor.dims(), GPUPlace()));
slice_tensor.mutable_data<double>(slice_tensor.dims(), CUDAPlace()));
EXPECT_EQ(src_data_address, src_mutable_data_address);
EXPECT_EQ(slice_data_address, slice_mutable_data_address);
EXPECT_EQ(src_data_address + 9 * 2 * sizeof(double), slice_data_address);
......
......@@ -47,11 +47,11 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src_place) && // NOLINT
platform::is_cpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_cpu_place = boost::get<platform::CPUPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_cpu_place, dst_ptr, src_gpu_place, src_ptr, size,
......@@ -59,21 +59,21 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_cpu_place = boost::get<platform::CPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(dst_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_cpu_place, src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
auto src_gpu_place = boost::get<platform::GPUPlace>(src_place);
auto dst_gpu_place = boost::get<platform::GPUPlace>(dst_place);
auto src_gpu_place = boost::get<platform::CUDAPlace>(src_place);
auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
auto ctx_place = ctx.GetPlace();
PADDLE_ENFORCE(platform::is_gpu_place(ctx_place));
auto ctx_gpu_place = boost::get<platform::GPUPlace>(ctx_place);
auto ctx_gpu_place = boost::get<platform::CUDAPlace>(ctx_place);
PADDLE_ENFORCE_EQ(src_gpu_place, ctx_gpu_place);
memory::Copy(
dst_gpu_place, dst_ptr, src_gpu_place, src_ptr, size,
......@@ -82,6 +82,28 @@ inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
#endif
}
/**
* @brief CopyFrom supporting CPU <-> CPU copies only
*/
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
Tensor* dst) {
src.check_memory_size();
dst->Resize(src.dims());
auto src_place = src.place();
auto src_ptr = src.data<void>();
auto dst_ptr = dst->mutable_data(dst_place, src.type());
auto size = src.numel() * SizeOfType(src.type());
PADDLE_ENFORCE(platform::is_cpu_place(src_place) &&
platform::is_cpu_place(dst_place));
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
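A minimal usage sketch of the new device-context-free `CopyFrom` overload added above, matching the updated `CopyFrom(src_tensor, *cpu_place, &dst_tensor)` call in the test further down; header paths and tensor contents are illustrative.

```cpp
#include "paddle/framework/tensor.h"       // Tensor, make_ddim (path assumed)
#include "paddle/framework/tensor_util.h"  // CopyFrom (path assumed)

void CopyOnCpu() {
  namespace f = paddle::framework;
  namespace p = paddle::platform;
  f::Tensor src, dst;
  int* src_ptr = src.mutable_data<int>(f::make_ddim({2, 3}), p::CPUPlace());
  for (int i = 0; i < 6; ++i) src_ptr[i] = i;
  // No DeviceContext argument: both source and destination must be on CPUPlace.
  f::CopyFrom(src, p::CPUPlace(), &dst);
}
```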
/**
* @brief Copy the content of an external vector to a tensor.
*
......@@ -108,13 +130,28 @@ inline void CopyFromVector(const std::vector<T>& src,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(dst_place)) { // NOLINT
memory::Copy(
boost::get<platform::GPUPlace>(dst_place), dst_ptr, src_place, src_ptr,
boost::get<platform::CUDAPlace>(dst_place), dst_ptr, src_place, src_ptr,
size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
/**
* @brief CopyFromVector CPU vector -> CPU Tensor
*/
template <typename T>
inline void CopyFromVector(const std::vector<T>& src, Tensor* dst) {
platform::CPUPlace dst_place = platform::CPUPlace();
auto src_ptr = static_cast<const void*>(src.data());
platform::CPUPlace src_place;
dst->Resize({static_cast<int64_t>(src.size())});
auto dst_ptr = static_cast<void*>(dst->mutable_data<T>(dst_place));
auto size = src.size() * sizeof(T);
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
}
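Likewise, a sketch of the new two-argument `CopyFromVector`, matching the `CopyFromVector<int>(src_vec, &cpu_tensor)` calls in the updated test; the vector contents are illustrative.

```cpp
#include <vector>
#include "paddle/framework/tensor.h"       // Tensor (path assumed)
#include "paddle/framework/tensor_util.h"  // CopyFromVector (path assumed)

void FillTensorFromVector() {
  std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
  paddle::framework::Tensor cpu_tensor;
  // The overload resizes cpu_tensor to {9} and copies entirely on the CPU.
  paddle::framework::CopyFromVector<int>(src_vec, &cpu_tensor);
}
```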
/**
* @brief Copy the content of a tensor to a vector
*
......@@ -141,12 +178,30 @@ inline void CopyToVector(const Tensor& src, const platform::DeviceContext& ctx,
#ifdef PADDLE_WITH_CUDA
else if (platform::is_gpu_place(src.place())) { // NOLINT
memory::Copy(
dst_place, dst_ptr, boost::get<platform::GPUPlace>(src.place()),
dst_place, dst_ptr, boost::get<platform::CUDAPlace>(src.place()),
src_ptr, size,
reinterpret_cast<const platform::CUDADeviceContext&>(ctx).stream());
}
#endif
}
/**
* @brief CopyToVector CPU Tensor -> CPU vector
*/
template <typename T>
inline void CopyToVector(const Tensor& src, std::vector<T>* dst) {
auto src_ptr = static_cast<const void*>(src.data<T>());
auto size = src.numel() * sizeof(T);
platform::CPUPlace dst_place;
dst->resize(src.numel());
auto dst_ptr = static_cast<void*>(dst->data());
PADDLE_ENFORCE(platform::is_cpu_place(src.place()));
memory::Copy(dst_place, dst_ptr, boost::get<platform::CPUPlace>(src.place()),
src_ptr, size);
}
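And the reverse direction with the new two-argument `CopyToVector`, reading a CPU tensor back into a `std::vector`, as the updated `CopyToVector<int>(src, &dst)` test does; the function below is a sketch.

```cpp
#include <vector>
#include "paddle/framework/tensor.h"       // Tensor (path assumed)
#include "paddle/framework/tensor_util.h"  // CopyToVector (path assumed)

void ReadTensorIntoVector(const paddle::framework::Tensor& src) {
  // src must live on CPUPlace; the overload enforces this with PADDLE_ENFORCE.
  std::vector<int> dst;
  paddle::framework::CopyToVector<int>(src, &dst);
}
```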
} // namespace framework
} // namespace paddle
......@@ -17,6 +17,7 @@
namespace paddle {
namespace framework {
TEST(CopyFrom, Tensor) {
Tensor src_tensor;
Tensor dst_tensor;
......@@ -29,7 +30,7 @@ TEST(CopyFrom, Tensor) {
memcpy(src_ptr, arr, 9 * sizeof(int));
auto cpu_place = new platform::CPUPlace();
CopyFrom(src_tensor, *cpu_place, cpu_ctx, &dst_tensor);
CopyFrom(src_tensor, *cpu_place, &dst_tensor);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
......@@ -58,7 +59,7 @@ TEST(CopyFrom, Tensor) {
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
auto gpu_place = new platform::GPUPlace(0);
auto gpu_place = new platform::CUDAPlace(0);
platform::CUDADeviceContext gpu_ctx(*gpu_place);
CopyFrom(src_tensor, *gpu_place, gpu_ctx, &gpu_tensor);
......@@ -104,8 +105,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to CPU Tensor
cpu_tensor.Resize(make_ddim({3, 3}));
auto cpu_place = new paddle::platform::CPUPlace();
CPUDeviceContext cpu_ctx(*cpu_place);
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
CopyFromVector<int>(src_vec, &cpu_tensor);
// Compare Tensors
const int* cpu_ptr = cpu_tensor.data<int>();
......@@ -117,7 +117,7 @@ TEST(CopyFromVector, Tensor) {
src_vec.erase(src_vec.begin(), src_vec.begin() + 5);
cpu_tensor.Resize(make_ddim({2, 2}));
CopyFromVector<int>(src_vec, cpu_ctx, &cpu_tensor);
CopyFromVector<int>(src_vec, &cpu_tensor);
cpu_ptr = cpu_tensor.data<int>();
src_ptr = src_vec.data();
ASSERT_NE(src_ptr, cpu_ptr);
......@@ -143,7 +143,7 @@ TEST(CopyFromVector, Tensor) {
// Copy to GPUTensor
gpu_tensor.Resize(make_ddim({3, 3}));
auto gpu_place = new paddle::platform::GPUPlace();
auto gpu_place = new paddle::platform::CUDAPlace();
CUDADeviceContext gpu_ctx(*gpu_place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
// Copy from GPU to CPU tensor for comparison
......@@ -198,9 +198,8 @@ TEST(CopyToVector, Tensor) {
}
CPUPlace place;
CPUDeviceContext cpu_ctx(place);
std::vector<int> dst;
CopyToVector<int>(src, cpu_ctx, &dst);
CopyToVector<int>(src, &dst);
for (int i = 0; i < 3 * 3; ++i) {
EXPECT_EQ(src_ptr[i], dst[i]);
......@@ -210,7 +209,7 @@ TEST(CopyToVector, Tensor) {
{
std::vector<int> src_vec = {1, 2, 3, 4, 5, 6, 7, 8, 9};
Tensor gpu_tensor;
GPUPlace place;
CUDAPlace place;
CUDADeviceContext gpu_ctx(place);
CopyFromVector<int>(src_vec, gpu_ctx, &gpu_tensor);
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <condition_variable>
#include <cstdio>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>
#include "paddle/platform/call_once.h"
#include "paddle/platform/enforce.h"
namespace paddle {
namespace framework {
typedef std::function<void()> Task;
class ThreadPool {
public:
/**
* @brief Get an instance of the thread pool; the number of threads
* is set to the number of hardware thread contexts
*/
static ThreadPool* GetInstance() {
std::call_once(init_flag, &ThreadPool::Init);
return threadpool.get();
}
~ThreadPool() {
{
// notify all threads to stop running
running_ = false;
scheduled_.notify_all();
}
for (auto& t : threads_) {
t->join();
t.reset(nullptr);
}
}
int GetNumThreads() const { return num_threads_; }
int GetAvailable() {
std::unique_lock<std::mutex> lock(mutex_);
return available_;
}
/**
* @brief Push a function into the task queue; it will be scheduled
* and executed once a thread is available.
* @param[in] fn The task to be pushed into the task queue.
*/
void Run(const Task& fn) {
std::unique_lock<std::mutex> lock(mutex_);
tasks_.push(fn);
lock.unlock();
scheduled_.notify_one();
}
/**
* @brief Wait until all the tasks are completed.
*/
void Wait() {
std::unique_lock<std::mutex> lock(mutex_);
completed_.wait(lock, [=] { return Done() == true; });
}
private:
ThreadPool& operator=(const ThreadPool&) = delete;
ThreadPool(const ThreadPool&) = delete;
ThreadPool(int num_threads)
: num_threads_(num_threads), available_(num_threads), running_(true) {
threads_.resize(num_threads);
for (auto& thread : threads_) {
// TODO(Yancey1989): bind the thread to a specified CPU core
thread.reset(new std::thread(std::bind(&ThreadPool::TaskLoop, this)));
}
}
/**
* @brief If the task queue is empty and available_
* equals the number of threads, then all tasks
* have been completed.
*
* Note: this function is not thread-safe.
*
* @return true if all tasks are completed.
*/
bool Done() { return tasks_.empty() && available_ == num_threads_; }
void TaskLoop() {
while (running_) {
std::unique_lock<std::mutex> lock(mutex_);
scheduled_.wait(lock, [=] { return !tasks_.empty() || !running_; });
if (!running_) {
break;
}
// pop a task from the task queue
auto task = tasks_.front();
tasks_.pop();
--available_;
lock.unlock();
// run the task
task();
{
std::unique_lock<std::mutex> lock(mutex_);
++available_;
if (Done()) {
completed_.notify_all();
}
}
}
}
static void Init() {
if (threadpool.get() == nullptr) {
// TODO(Yancey1989): specify the max threads number
int num_threads = std::thread::hardware_concurrency();
PADDLE_ENFORCE_GT(num_threads, 0);
threadpool.reset(new ThreadPool(num_threads));
}
}
private:
static std::unique_ptr<ThreadPool> threadpool;
static std::once_flag init_flag;
int num_threads_;
int available_;
bool running_;
std::queue<Task> tasks_;
std::vector<std::unique_ptr<std::thread>> threads_;
std::mutex mutex_;
std::condition_variable scheduled_;
std::condition_variable completed_;
};
std::unique_ptr<ThreadPool> ThreadPool::threadpool(nullptr);
std::once_flag ThreadPool::init_flag;
} // namespace framework
} // namespace paddle
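A brief usage sketch of the `ThreadPool` singleton introduced above; the same `Run`/`Wait` API is exercised by the test file that follows. The header path and the counter task are illustrative.

```cpp
#include <atomic>
#include "paddle/framework/threadpool.h"  // path assumed

void RunSomeTasks() {
  paddle::framework::ThreadPool* pool =
      paddle::framework::ThreadPool::GetInstance();
  std::atomic<int> counter(0);
  for (int i = 0; i < 100; ++i) {
    // Each task is queued and picked up by an idle worker thread.
    pool->Run([&counter]() { counter.fetch_add(1); });
  }
  // Block until the queue drains and all workers become idle again.
  pool->Wait();
}
```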
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "threadpool.h"
#include <gtest/gtest.h>
#include <atomic>
#include <chrono>
#include <map>
#include <thread>
namespace framework = paddle::framework;
void do_sum(framework::ThreadPool* pool, std::atomic<int>& sum, int cnt) {
for (int i = 0; i < cnt; ++i) {
pool->Run([&sum]() { sum.fetch_add(1); });
}
}
TEST(ThreadPool, ConcurrentInit) {
framework::ThreadPool* pool;
int concurrent_cnt = 50;
std::vector<std::thread> threads;
for (int i = 0; i < concurrent_cnt; ++i) {
std::thread t([&pool]() { pool = framework::ThreadPool::GetInstance(); });
threads.push_back(std::move(t));
}
for (auto& t : threads) {
t.join();
}
}
TEST(ThreadPool, ConcurrentStart) {
framework::ThreadPool* pool = framework::ThreadPool::GetInstance();
std::atomic<int> sum(0);
std::vector<std::thread> threads;
int concurrent_cnt = 50;
// sum = (n * (n + 1)) / 2
for (int i = 1; i <= concurrent_cnt; ++i) {
std::thread t(do_sum, pool, std::ref(sum), i);
threads.push_back(std::move(t));
}
for (auto& t : threads) {
t.join();
}
pool->Wait();
EXPECT_EQ(sum, ((concurrent_cnt + 1) * concurrent_cnt) / 2);
}
......@@ -12,13 +12,13 @@ p = memory::Alloc(platform::CPUPlace(), 4*1024);
To allocate 4KB memory on the 3rd GPU:
```cpp
p = memory::Alloc(platform::GPUPlace(2), 4*1024);
p = memory::Alloc(platform::CUDAPlace(2), 4*1024);
```
To free memory and check the amount of memory used so far on a place:
```cpp
auto pl = platform::GPUPlace(0);
auto pl = platform::CUDAPlace(0);
p = memory::Alloc(pl, 4*1024);
cout << memory::Used(pl);
memory::Free(pl, p);
......@@ -36,7 +36,7 @@ template <typename Place> size_t Used(Place);
} // namespace memory
```
These function templates have specializations on either `platform::CPUPlace` or `platform::GPUPlace`:
These function templates have specializations on either `platform::CPUPlace` or `platform::CUDAPlace`:
```cpp
template<>
......@@ -49,7 +49,7 @@ and
```cpp
template<>
void Alloc<GPUPlace>(GPUPlace p, size_t size) {
void Alloc<CUDAPlace>(CUDAPlace p, size_t size) {
return GetGPUBuddyAllocator(p.id)->Alloc(size);
}
```
......@@ -122,7 +122,7 @@ There are two implementations of `Context`:
1. [`CPUContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L105), whose [`New` method](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.h#L131) calls [`g_cpu_allocator.get()->New(size_t)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context.cc#L15) to allocate the memory.
1. [`CUDAContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L99), which has a data member [`int gpu_id_`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L202). This looks very similar to class `majel::GPUPlace`, which also has an `int id_` data member. `CUDAContext::New(size_t)` calls [`g_cub_allocator->DeviceAllocate(&ptr, nbytes)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.cu#L355) to allocate the memory.
1. [`CUDAContext`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L99), which has a data member [`int gpu_id_`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.h#L202). This looks very similar to class `majel::CUDAPlace`, which also has an `int id_` data member. `CUDAContext::New(size_t)` calls [`g_cub_allocator->DeviceAllocate(&ptr, nbytes)`](https://github.com/caffe2/caffe2/blob/v0.7.0/caffe2/core/context_gpu.cu#L355) to allocate the memory.
### Majel
......
......@@ -28,31 +28,25 @@ void Copy<platform::CPUPlace, platform::CPUPlace>(platform::CPUPlace, void* dst,
#ifdef PADDLE_WITH_CUDA
template <>
void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
void Copy<platform::CPUPlace, platform::CUDAPlace>(
platform::CPUPlace dst_place, void* dst, platform::CUDAPlace src_place,
const void* src, size_t num, cudaStream_t stream) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToHost, stream);
}
template <>
void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::CPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
void Copy<platform::CUDAPlace, platform::CPUPlace>(
platform::CUDAPlace dst_place, void* dst, platform::CPUPlace src_place,
const void* src, size_t num, cudaStream_t stream) {
platform::SetDeviceId(dst_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyHostToDevice, stream);
}
template <>
void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
void* dst,
platform::GPUPlace src_place,
const void* src, size_t num,
cudaStream_t stream) {
void Copy<platform::CUDAPlace, platform::CUDAPlace>(
platform::CUDAPlace dst_place, void* dst, platform::CUDAPlace src_place,
const void* src, size_t num, cudaStream_t stream) {
if (dst_place == src_place) {
platform::SetDeviceId(src_place.device);
platform::GpuMemcpyAsync(dst, src, num, cudaMemcpyDeviceToDevice, stream);
......
......@@ -83,12 +83,12 @@ BuddyAllocator* GetGPUBuddyAllocator(int gpu_id) {
}
template <>
size_t Used<platform::GPUPlace>(platform::GPUPlace place) {
size_t Used<platform::CUDAPlace>(platform::CUDAPlace place) {
return GetGPUBuddyAllocator(place.device)->Used();
}
template <>
void* Alloc<platform::GPUPlace>(platform::GPUPlace place, size_t size) {
void* Alloc<platform::CUDAPlace>(platform::CUDAPlace place, size_t size) {
auto* buddy_allocator = GetGPUBuddyAllocator(place.device);
auto* ptr = buddy_allocator->Alloc(size);
if (ptr == nullptr) {
......@@ -101,14 +101,14 @@ void* Alloc<platform::GPUPlace>(platform::GPUPlace place, size_t size) {
LOG(WARNING) << "total " << total;
LOG(WARNING) << "GpuMinChunkSize " << platform::GpuMinChunkSize();
LOG(WARNING) << "GpuMaxChunkSize " << platform::GpuMaxChunkSize();
LOG(WARNING) << "GPU memory used: " << Used<platform::GPUPlace>(place);
LOG(WARNING) << "GPU memory used: " << Used<platform::CUDAPlace>(place);
platform::SetDeviceId(cur_dev);
}
return ptr;
}
template <>
void Free<platform::GPUPlace>(platform::GPUPlace place, void* p) {
void Free<platform::CUDAPlace>(platform::CUDAPlace place, void* p) {
GetGPUBuddyAllocator(place.device)->Free(p);
}
......
......@@ -82,7 +82,7 @@ TEST(BuddyAllocator, CPUMultAlloc) {
#ifdef PADDLE_WITH_CUDA
size_t align(size_t size, paddle::platform::GPUPlace place) {
size_t align(size_t size, paddle::platform::CUDAPlace place) {
size += sizeof(paddle::memory::detail::Metadata);
size_t alignment = paddle::platform::GpuMinChunkSize();
size_t remaining = size % alignment;
......@@ -94,7 +94,7 @@ TEST(BuddyAllocator, GPUAllocation) {
EXPECT_EQ(p, nullptr);
paddle::platform::GPUPlace gpu(0);
paddle::platform::CUDAPlace gpu(0);
p = paddle::memory::Alloc(gpu, 4096);
EXPECT_NE(p, nullptr);
......@@ -103,7 +103,7 @@ TEST(BuddyAllocator, GPUAllocation) {
}
TEST(BuddyAllocator, GPUMultAlloc) {
paddle::platform::GPUPlace gpu;
paddle::platform::CUDAPlace gpu;
std::unordered_map<void *, size_t> ps;
......
......@@ -56,7 +56,7 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto* inference = ctx.Input<Tensor>("Out");
auto* indices = ctx.Input<Tensor>("Indices");
auto* label = ctx.Input<Tensor>("Label");
......
......@@ -53,7 +53,7 @@ class BatchNormKernel<platform::CUDADeviceContext, T>
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const float momentum = ctx.Attr<float>("momentum");
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -179,7 +179,7 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
double epsilon = static_cast<double>(ctx.Attr<float>("epsilon"));
const std::string data_layout_str = ctx.Attr<std::string>("data_layout");
const DataLayout data_layout =
......
......@@ -36,7 +36,7 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
......@@ -130,7 +130,7 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
handle, cudnn_input_desc, cudnn_filter_desc, cudnn_conv_desc,
cudnn_output_desc, algo, &workspace_size_in_bytes));
// Allocate on GPU memory
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward ---------------------
T alpha = 1.0f, beta = 0.0f;
......@@ -151,7 +151,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto input = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
......@@ -277,7 +277,7 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
T alpha = 1.0f, beta = 0.0f;
......
......@@ -35,7 +35,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* output = ctx.Output<Tensor>("Output");
......@@ -100,7 +100,7 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
cudnn_output_desc, algo, &workspace_size_in_bytes));
// Allocate on GPU memory
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv transpose forward ---------------------
......@@ -120,7 +120,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto input = ctx.Input<Tensor>("Input");
auto filter = ctx.Input<Tensor>("Filter");
auto output_grad = ctx.Input<Tensor>(framework::GradVarName("Output"));
......@@ -201,7 +201,7 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn conv workspace ---------------------
// Already on GPU
void* cudnn_workspace = nullptr;
platform::GPUPlace gpu = boost::get<platform::GPUPlace>(ctx.GetPlace());
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data ---------------------
// FIXME(typhoonzero): template type T may not be the same as cudnn call.
......
......@@ -35,7 +35,7 @@ struct StridedMemcpyFunctor<T, 1> {
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
} else {
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::GPUPlace>(place);
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head,
......
......@@ -219,8 +219,8 @@ class LinearChainCRFOpKernel : public framework::OpKernel<T> {
// operators runs on GPU device.
auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor& src,
Tensor* dst) {
dst->mutable_data<T>(platform::GPUPlace());
framework::CopyFrom(src, platform::GPUPlace(), ctx, dst);
dst->mutable_data<T>(platform::CUDAPlace());
framework::CopyFrom(src, platform::CUDAPlace(), ctx, dst);
};
copyTensor(ctx, emission_exps_src, emission_exps_dst);
copyTensor(ctx, transition_exps_src, transition_exps_dst);
......@@ -433,8 +433,8 @@ class LinearChainCRFGradOpKernel : public framework::OpKernel<T> {
auto copyTensor = [](const platform::DeviceContext& ctx, const Tensor* src,
Tensor* dst) {
if (src && dst) {
dst->mutable_data<T>(platform::GPUPlace());
framework::CopyFrom(*src, platform::GPUPlace(), ctx, dst);
dst->mutable_data<T>(platform::CUDAPlace());
framework::CopyFrom(*src, platform::CUDAPlace(), ctx, dst);
}
};
copyTensor(ctx, emission_grad_src, emission_grad_dst);
......
......@@ -101,7 +101,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
// copy GPU memory to CPU pinned memory
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_dim[0]);
auto gpu_place = boost::get<platform::GPUPlace>(context.GetPlace());
auto gpu_place = boost::get<platform::CUDAPlace>(context.GetPlace());
memory::Copy(platform::CPUPlace(), new_rows.data(), gpu_place, ids_data,
ids_dim[0] * sizeof(int64_t), stream);
......
......@@ -98,7 +98,7 @@ class LstmUnitOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto* x_tensor = ctx.Input<framework::Tensor>("X");
auto* c_prev_tensor = ctx.Input<framework::Tensor>("C_prev");
......@@ -129,7 +129,7 @@ class LstmUnitGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto x_tensor = ctx.Input<Tensor>("X");
auto c_prev_tensor = ctx.Input<Tensor>("C_prev");
......
......@@ -159,6 +159,7 @@ void testIm2col() {
TEST(math, im2col) {
testIm2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
testIm2col<paddle::platform::CUDADeviceContext, paddle::platform::GPUPlace>();
testIm2col<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
#endif
}
......@@ -105,7 +105,7 @@ void matmul<platform::CUDADeviceContext, float>(
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
......@@ -134,7 +134,7 @@ void matmul<platform::CUDADeviceContext, double>(
PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) &&
platform::is_gpu_place(matrix_b.place()) &&
platform::is_gpu_place(matrix_out->place()),
"Matrix must all be in GPUPlace");
"Matrix must all be in CUDAPlace");
int M = dim_out[0];
int N = dim_out[1];
......@@ -266,7 +266,7 @@ struct TensorSetConstantGPU {
};
template <>
void set_constant_with_place<platform::GPUPlace>(
void set_constant_with_place<platform::CUDAPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
framework::VisitDataType(framework::ToDataType(tensor->type()),
......@@ -277,7 +277,7 @@ template <>
void set_constant_with_place<platform::CUDNNPlace>(
const platform::DeviceContext& context, framework::Tensor* tensor,
float value) {
set_constant_with_place<platform::GPUPlace>(context, tensor, value);
set_constant_with_place<platform::CUDAPlace>(context, tensor, value);
}
template struct RowwiseAdd<platform::CUDADeviceContext, float>;
......
......@@ -13,7 +13,7 @@ TEST(math_function, notrans_mul_trans) {
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
......@@ -47,7 +47,7 @@ TEST(math_function, trans_mul_notrans) {
float arr[6] = {0, 1, 2, 3, 4, 5};
memcpy(input1_ptr, arr, 6 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
......@@ -96,7 +96,7 @@ TEST(math_function, gemm_notrans_cublas) {
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
......@@ -151,7 +151,7 @@ TEST(math_function, gemm_trans_cublas) {
float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7};
memcpy(input3_ptr, arr3, 8 * sizeof(float));
auto* gpu_place = new paddle::platform::GPUPlace(0);
auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceContext context(*gpu_place);
paddle::framework::CopyFrom(input1, *gpu_place, context, &input1_gpu);
......@@ -189,7 +189,7 @@ void GemvTest(int m, int n, bool trans) {
T* data_b = vec_b.mutable_data<T>({trans ? m : n}, *cpu_place);
T* data_c = vec_c.mutable_data<T>({trans ? n : m}, *cpu_place);
auto* gpu_place = new paddle::platform::GPUPlace(0);
auto* gpu_place = new paddle::platform::CUDAPlace(0);
paddle::framework::Tensor g_mat_a;
paddle::framework::Tensor g_vec_b;
paddle::framework::Tensor g_vec_c;
......
......@@ -58,15 +58,15 @@ struct SelectedRowsAdd<platform::CUDADeviceContext, T> {
PADDLE_ENFORCE(platform::is_gpu_place(out_place));
memory::Copy(
boost::get<platform::GPUPlace>(out_place), out_data,
boost::get<platform::GPUPlace>(in1_place), in1_data,
boost::get<platform::CUDAPlace>(out_place), out_data,
boost::get<platform::CUDAPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T),
reinterpret_cast<const platform::CUDADeviceContext&>(context).stream());
auto* in2_data = in2_value.data<T>();
memory::Copy(boost::get<platform::GPUPlace>(out_place),
memory::Copy(boost::get<platform::CUDAPlace>(out_place),
out_data + in1_value.numel(),
boost::get<platform::GPUPlace>(in2_place), in2_data,
boost::get<platform::CUDAPlace>(in2_place), in2_data,
in2_value.numel() * sizeof(T), context.stream());
}
};
......@@ -160,9 +160,9 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
auto* in1_data = in1_value.data<T>();
auto* in2_data = in2_value->data<T>();
memory::Copy(boost::get<platform::GPUPlace>(in2_place),
memory::Copy(boost::get<platform::CUDAPlace>(in2_place),
in2_data + input2_offset,
boost::get<platform::GPUPlace>(in1_place), in1_data,
boost::get<platform::CUDAPlace>(in1_place), in1_data,
in1_value.numel() * sizeof(T), context.stream());
}
};
......
......@@ -21,7 +21,7 @@ TEST(selected_rows_functor, gpu_add) {
using namespace paddle::platform;
using namespace paddle::operators::math;
GPUPlace gpu_place(0);
CUDAPlace gpu_place(0);
CPUPlace cpu_place;
CUDADeviceContext ctx(gpu_place);
SetConstant<CUDADeviceContext, float> functor;
......@@ -119,7 +119,7 @@ TEST(selected_rows_functor, gpu_add_to) {
using namespace paddle::platform;
using namespace paddle::operators::math;
GPUPlace gpu_place(0);
CUDAPlace gpu_place(0);
CPUPlace cpu_place;
CUDADeviceContext ctx(gpu_place);
SetConstant<CUDADeviceContext, float> functor;
......
......@@ -122,6 +122,6 @@ TEST(math, vol2col) {
testVol2col<paddle::platform::CPUDeviceContext, paddle::platform::CPUPlace>();
#ifdef PADDLE_WITH_CUDA
testVol2col<paddle::platform::CUDADeviceContext,
paddle::platform::GPUPlace>();
paddle::platform::CUDAPlace>();
#endif // PADDLE_WITH_CUDA
}
......@@ -36,7 +36,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
CopyFrom(*ids, platform::CPUPlace(), ctx.device_context(), &index_t_cpu);
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
int32_t k = index[i];
PADDLE_ENFORCE_GE(k, 0, "index must be nonnegative.");
......@@ -73,7 +73,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
auto* index = index_t_cpu.data<int32_t>();
auto stream = ctx.cuda_device_context().stream();
platform::GPUPlace place = boost::get<platform::GPUPlace>(ctx.GetPlace());
platform::CUDAPlace place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
for (auto i = 0; i < rows; i++) {
size_t k = static_cast<size_t>(index[i]);
if (d_ins[k]) {
......
......@@ -67,7 +67,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
auto stream = ctx.cuda_device_context().stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
for (size_t i = 0; i < ins.size(); ++i) {
......@@ -120,7 +120,7 @@ class NCCLReduceKernel : public framework::OpKernel<T> {
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
auto ins_names = ctx.Inputs("X");
......@@ -164,7 +164,7 @@ class NCCLBcastKernel : public framework::OpKernel<T> {
ctx.device_context())
.stream();
// device id
int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
int gpu_id = boost::get<platform::CUDAPlace>(ctx.GetPlace()).GetDeviceId();
int idx = comm->GetCommId(gpu_id);
if (idx == root) {
......
......@@ -52,7 +52,7 @@ class NCCLTester : public ::testing::Test {
virtual void SetUp() override {
paddle::platform::CPUPlace cpu_place;
for (size_t i = 0; i < gpu_list.size(); ++i) {
p::GPUPlace place(i);
p::CUDAPlace place(i);
dev_ctxs.emplace_back(new p::CUDADeviceContext(place));
}
......@@ -87,7 +87,7 @@ class NCCLTester : public ::testing::Test {
std::unique_lock<std::mutex> lk(mu);
const f::OpDesc *op1 = &op_desc;
p::GPUPlace place(gpu_id);
p::CUDAPlace place(gpu_id);
auto &ctx = dev_ctxs.at(gpu_id);
auto *send_tensor = scope->Var("st")->GetMutable<f::LoDTensor>();
......@@ -171,7 +171,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
for (size_t i = 0; i < dev_scopes.size(); ++i) {
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[i]);
p::CUDAPlace gpu_place(gpu_list[i]);
auto &recv_tensor = dev_scopes[i]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -180,7 +180,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[i]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list[i]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[i])->stream());
......@@ -219,7 +219,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
float result = std::accumulate(gpu_list.begin(), gpu_list.end(), 0);
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[kRoot]);
p::CUDAPlace gpu_place(gpu_list[kRoot]);
auto &recv_tensor = dev_scopes[kRoot]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -229,7 +229,7 @@ TEST_F(NCCLTester, ncclReduceOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[kRoot]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list[kRoot]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[kRoot])->stream());
......@@ -268,7 +268,7 @@ TEST_F(NCCLTester, ncclBcastOp) {
float result = kRoot;
p::CPUPlace cpu_place;
p::GPUPlace gpu_place(gpu_list[idx]);
p::CUDAPlace gpu_place(gpu_list[idx]);
auto &recv_tensor = dev_scopes[idx]->FindVar("rt")->Get<f::LoDTensor>();
auto *rt = recv_tensor.data<float>();
......@@ -277,7 +277,7 @@ TEST_F(NCCLTester, ncclBcastOp) {
auto *ct = result_tensor->mutable_data<float>(cpu_place);
paddle::memory::Copy(
cpu_place, ct, p::GPUPlace(gpu_list[idx]), rt,
cpu_place, ct, p::CUDAPlace(gpu_list[idx]), rt,
recv_tensor.numel() * sizeof(float),
static_cast<p::CUDADeviceContext *>(dev_ctxs[idx])->stream());
......@@ -300,7 +300,7 @@ int main(int argc, char **argv) {
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
places.emplace_back(paddle::platform::CUDAPlace(i));
gpu_list.emplace_back(i);
}
......
......@@ -29,7 +29,7 @@ class PoolCudnnOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
Tensor *output = ctx.Output<Tensor>("Out");
......@@ -90,7 +90,7 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
const Tensor *input = ctx.Input<Tensor>("X");
const Tensor *output = ctx.Input<Tensor>("Out");
......
......@@ -16,7 +16,7 @@
REGISTER_OP_CUDA_KERNEL(
reshape,
paddle::operators::ReshapeKernel<paddle::platform::GPUPlace, float>);
paddle::operators::ReshapeKernel<paddle::platform::CUDAPlace, float>);
REGISTER_OP_CUDA_KERNEL(
reshape_grad,
paddle::operators::ReshapeGradKernel<paddle::platform::GPUPlace, float>);
paddle::operators::ReshapeGradKernel<paddle::platform::CUDAPlace, float>);
......@@ -82,7 +82,7 @@ TEST(StridedMemcpy, GPUCrop) {
};
// clang-format on
platform::GPUPlace gpu0(0);
platform::CUDAPlace gpu0(0);
platform::CPUPlace cpu;
platform::CUDADeviceContext ctx(gpu0);
......@@ -121,7 +121,7 @@ TEST(StridedMemcpy, GPUConcat) {
};
// clang-format on
platform::GPUPlace gpu0(0);
platform::CUDAPlace gpu0(0);
platform::CPUPlace cpu;
platform::CUDADeviceContext ctx(gpu0);
......
......@@ -283,7 +283,7 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
"It must use CUDAPlace.");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
......
......@@ -22,23 +22,7 @@ namespace paddle {
namespace platform {
void CudaProfilerInit(std::string output_file, std::string output_mode,
std::vector<std::string> config_flags) {
std::array<char, 128> buf;
std::string tmpl = "/tmp/cuda_profile_config.XXXXXX";
PADDLE_ENFORCE_LT(tmpl.size(), buf.size());
memcpy(buf.data(), tmpl.data(), tmpl.size());
auto result = mktemp(buf.data());
PADDLE_ENFORCE(strlen(result) != 0);
std::string config_file = result;
{
std::ofstream ofs(config_file, std::ios::out | std::ios::trunc);
PADDLE_ENFORCE(ofs.is_open(), "ofstream: ", ofs.rdstate());
for (const auto& line : config_flags) {
ofs << line << std::endl;
}
}
std::string config_file) {
PADDLE_ENFORCE(output_mode == "kvp" || output_mode == "csv");
cudaOutputMode_t mode = output_mode == "csv" ? cudaCSV : cudaKeyValuePair;
PADDLE_ENFORCE(
......
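A hedged sketch of calling the simplified profiler entry point after this change: the caller now writes the config file itself and passes its path, instead of handing over `config_flags` for the function to serialize into a temp file. File names, option strings, and the header path below are illustrative assumptions.

```cpp
#include <fstream>
#include <string>
#include "paddle/platform/cuda_profiler.h"  // header path assumed

void ProfileOnce() {
  // Prepare a profiler config file up front (previously done inside
  // CudaProfilerInit via mktemp).
  std::string config_file = "/tmp/cuda_profile.cfg";  // illustrative path
  std::ofstream ofs(config_file, std::ios::out | std::ios::trunc);
  ofs << "gpustarttimestamp" << std::endl;
  ofs.close();
  // output_mode must be "kvp" or "csv", as enforced above.
  paddle::platform::CudaProfilerInit("cuda_profile.kvp", "kvp", config_file);
}
```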
......@@ -58,10 +58,10 @@ DeviceContextPool::DeviceContextPool(
#ifdef PADDLE_WITH_CUDA
device_contexts_.emplace(places[i],
new platform::CUDADeviceContext(
boost::get<platform::GPUPlace>(places[i])));
boost::get<platform::CUDAPlace>(places[i])));
#else
PADDLE_THROW(
"'GPUPlace' is not supported, Please re-compile with WITH_GPU "
"'CUDAPlace' is not supported, Please re-compile with WITH_GPU "
"option");
#endif
}
......@@ -91,7 +91,7 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
}
~EigenCudaStreamDevice() override {}
void Reinitialize(const cudaStream_t* cuda_stream, GPUPlace place) {
void Reinitialize(const cudaStream_t* cuda_stream, CUDAPlace place) {
stream_ = cuda_stream;
place_ = place;
device_prop_ = &Eigen::m_deviceProperties[place.device];
......@@ -130,14 +130,14 @@ class EigenCudaStreamDevice : public Eigen::StreamInterface {
}
private:
GPUPlace place_;
CUDAPlace place_;
const cudaStream_t* stream_; // not owned;
const cudaDeviceProp* device_prop_; // not owned;
mutable void* scratch_;
mutable unsigned int* semaphore_;
};
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
CUDADeviceContext::CUDADeviceContext(CUDAPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new EigenCudaStreamDevice());
......
......@@ -58,7 +58,7 @@ class EigenCudaStreamDevice;
class CUDADeviceContext : public DeviceContext {
public:
explicit CUDADeviceContext(GPUPlace place);
explicit CUDADeviceContext(CUDAPlace place);
virtual ~CUDADeviceContext();
/*! \brief Wait for all operations completion in the stream. */
......@@ -80,7 +80,7 @@ class CUDADeviceContext : public DeviceContext {
cudaStream_t stream() const;
private:
GPUPlace place_;
CUDAPlace place_;
std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
......@@ -143,7 +143,7 @@ class DeviceContextPool {
size_t operator()(const platform::Place& place) const {
int pre_hash = place.which() + (1 << LEFT_SHIFT);
if (platform::is_gpu_place(place)) {
pre_hash += boost::get<platform::GPUPlace>(place).GetDeviceId();
pre_hash += boost::get<platform::CUDAPlace>(place).GetDeviceId();
}
return hash_(pre_hash);
}
......
......@@ -20,11 +20,11 @@ limitations under the License. */
TEST(Device, Init) {
using paddle::platform::DeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
using paddle::platform::CUDAPlace;
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; i++) {
CUDADeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
CUDADeviceContext* device_context = new CUDADeviceContext(CUDAPlace(i));
Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device);
delete device_context;
......@@ -33,11 +33,11 @@ TEST(Device, Init) {
TEST(Device, CUDADeviceContext) {
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
using paddle::platform::CUDAPlace;
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; i++) {
CUDADeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
CUDADeviceContext* device_context = new CUDADeviceContext(CUDAPlace(i));
Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
......@@ -70,7 +70,7 @@ TEST(Device, DeviceContextPool) {
using paddle::platform::CUDADeviceContext;
using paddle::platform::Place;
using paddle::platform::CPUPlace;
using paddle::platform::GPUPlace;
using paddle::platform::CUDAPlace;
DeviceContextPool& pool = DeviceContextPool::Get();
auto cpu_dev_ctx1 = pool.Borrow(CPUPlace());
......@@ -80,14 +80,14 @@ TEST(Device, DeviceContextPool) {
std::vector<Place> gpu_places;
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
gpu_places.emplace_back(GPUPlace(i));
gpu_places.emplace_back(CUDAPlace(i));
}
auto dev_ctxs = pool.Borrow(gpu_places);
for (size_t i = 0; i < dev_ctxs.size(); ++i) {
auto* dev_ctx = static_cast<const CUDADeviceContext*>(dev_ctxs[i]);
// check same as GPUPlace(i)
GPUPlace place = boost::get<GPUPlace>(dev_ctx->GetPlace());
// check same as CUDAPlace(i)
CUDAPlace place = boost::get<CUDAPlace>(dev_ctx->GetPlace());
EXPECT_EQ(place.GetDeviceId(), static_cast<int>(i));
}
}
......@@ -106,7 +106,7 @@ int main(int argc, char** argv) {
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
places.emplace_back(paddle::platform::CUDAPlace(i));
}
VLOG(0) << " DeviceCount " << count;
......
......@@ -50,7 +50,7 @@ struct PerThreadData {
T* RecvBuff() { return thrust::raw_pointer_cast(recv_buff.data()); }
PerThreadData(int gpu_id, size_t size) : dev_ctx(GPUPlace(gpu_id)) {
PerThreadData(int gpu_id, size_t size) : dev_ctx(CUDAPlace(gpu_id)) {
send_buff.resize(size);
for (size_t i = 0; i < size; ++i) {
send_buff[i] = static_cast<T>(i);
......@@ -140,7 +140,7 @@ int main(int argc, char** argv) {
places.emplace_back(paddle::platform::CPUPlace());
int count = paddle::platform::GetCUDADeviceCount();
for (int i = 0; i < count; ++i) {
places.emplace_back(paddle::platform::GPUPlace(i));
places.emplace_back(paddle::platform::CUDAPlace(i));
}
VLOG(0) << " DeviceCount " << count;
......
......@@ -24,7 +24,9 @@ class PlacePrinter : public boost::static_visitor<> {
explicit PlacePrinter(std::ostream &os) : os_(os) {}
void operator()(const CPUPlace &) { os_ << "CPUPlace"; }
void operator()(const MKLDNNPlace &) { os_ << "MKLDNNPlace"; }
void operator()(const GPUPlace &p) { os_ << "GPUPlace(" << p.device << ")"; }
void operator()(const CUDAPlace &p) {
os_ << "CUDAPlace(" << p.device << ")";
}
private:
std::ostream &os_;
......@@ -37,12 +39,12 @@ static Place the_default_place;
void set_place(const Place &place) { the_default_place = place; }
const Place &get_place() { return the_default_place; }
const GPUPlace default_gpu() { return GPUPlace(0); }
const CUDAPlace default_gpu() { return CUDAPlace(0); }
const CPUPlace default_cpu() { return CPUPlace(); }
const MKLDNNPlace default_mkldnn() { return MKLDNNPlace(); }
bool is_gpu_place(const Place &p) {
return boost::apply_visitor(IsGPUPlace(), p);
return boost::apply_visitor(IsCUDAPlace(), p);
}
bool is_cpu_place(const Place &p) {
return !is_gpu_place(p) && !is_mkldnn_place(p);
......
......@@ -39,43 +39,45 @@ struct MKLDNNPlace {
inline bool operator!=(const MKLDNNPlace &) const { return false; }
};
struct GPUPlace {
GPUPlace() : GPUPlace(0) {}
explicit GPUPlace(int d) : device(d) {}
struct CUDAPlace {
CUDAPlace() : CUDAPlace(0) {}
explicit CUDAPlace(int d) : device(d) {}
inline int GetDeviceId() const { return device; }
// needed for variant equality comparison
inline bool operator==(const GPUPlace &o) const { return device == o.device; }
inline bool operator!=(const GPUPlace &o) const { return !(*this == o); }
inline bool operator==(const CUDAPlace &o) const {
return device == o.device;
}
inline bool operator!=(const CUDAPlace &o) const { return !(*this == o); }
int device;
};
struct CUDNNPlace : public GPUPlace {
CUDNNPlace() : GPUPlace() {}
explicit CUDNNPlace(int d) : GPUPlace(d) {}
struct CUDNNPlace : public CUDAPlace {
CUDNNPlace() : CUDAPlace() {}
explicit CUDNNPlace(int d) : CUDAPlace(d) {}
};
struct IsGPUPlace : public boost::static_visitor<bool> {
struct IsCUDAPlace : public boost::static_visitor<bool> {
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const MKLDNNPlace &) const { return false; }
bool operator()(const GPUPlace &gpu) const { return true; }
bool operator()(const CUDAPlace &gpu) const { return true; }
bool operator()(const CUDNNPlace &) const { return true; }
};
struct IsMKLDNNPlace : public boost::static_visitor<bool> {
bool operator()(const MKLDNNPlace &) const { return true; }
bool operator()(const CPUPlace &) const { return false; }
bool operator()(const GPUPlace &) const { return false; }
bool operator()(const CUDAPlace &) const { return false; }
bool operator()(const CUDNNPlace &) const { return false; }
};
typedef boost::variant<CUDNNPlace, GPUPlace, CPUPlace, MKLDNNPlace> Place;
typedef boost::variant<CUDNNPlace, CUDAPlace, CPUPlace, MKLDNNPlace> Place;
void set_place(const Place &);
const Place &get_place();
const GPUPlace default_gpu();
const CUDAPlace default_gpu();
const CPUPlace default_cpu();
const MKLDNNPlace default_mkldnn();
......
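A small sketch of the renamed place types in use, consistent with the `place.h`/`place.cc` changes above and the equality/print tests below; only members visible in this diff are used, and the header path is assumed.

```cpp
#include <sstream>
#include "paddle/platform/place.h"  // path assumed

void InspectPlace() {
  paddle::platform::Place p = paddle::platform::CUDAPlace(1);
  // is_gpu_place() now dispatches through the IsCUDAPlace visitor.
  if (paddle::platform::is_gpu_place(p)) {
    int dev = boost::get<paddle::platform::CUDAPlace>(p).GetDeviceId();  // 1
    (void)dev;
  }
  std::stringstream ss;
  ss << paddle::platform::CUDAPlace(1);  // "CUDAPlace(1)", as in the print test
}
```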
......@@ -4,7 +4,7 @@
TEST(Place, Equality) {
paddle::platform::CPUPlace cpu;
paddle::platform::GPUPlace g0(0), g1(1), gg0(0);
paddle::platform::CUDAPlace g0(0), g1(1), gg0(0);
paddle::platform::CUDNNPlace d0(0), d1(1), dd0(0);
EXPECT_EQ(cpu, cpu);
......@@ -41,8 +41,8 @@ TEST(Place, Default) {
TEST(Place, Print) {
{
std::stringstream ss;
ss << paddle::platform::GPUPlace(1);
EXPECT_EQ("GPUPlace(1)", ss.str());
ss << paddle::platform::CUDAPlace(1);
EXPECT_EQ("CUDAPlace(1)", ss.str());
}
{
std::stringstream ss;
......
......@@ -49,7 +49,7 @@ TEST(Transform, CPUUnary) {
TEST(Transform, GPUUnary) {
using namespace paddle::platform;
using namespace paddle::memory;
GPUPlace gpu0(0);
CUDAPlace gpu0(0);
CUDADeviceContext ctx(gpu0);
float cpu_buf[4] = {0.1, 0.2, 0.3, 0.4};
float* gpu_buf = static_cast<float*>(Alloc(gpu0, sizeof(float) * 4));
......@@ -80,7 +80,7 @@ TEST(Transform, GPUBinary) {
using namespace paddle::platform;
using namespace paddle::memory;
int buf[4] = {1, 2, 3, 4};
GPUPlace gpu0(0);
CUDAPlace gpu0(0);
CUDADeviceContext ctx(gpu0);
int* gpu_buf = static_cast<int*>(Alloc(gpu0, sizeof(buf)));
Copy(gpu0, gpu_buf, CPUPlace(), buf, sizeof(buf), ctx.stream());
......
......@@ -79,7 +79,7 @@ PYBIND11_PLUGIN(core) {
self.Resize(make_ddim(dim));
})
.def("alloc_float",
[](Tensor &self, paddle::platform::GPUPlace &place) {
[](Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<float>(place);
})
.def("alloc_float",
......@@ -91,7 +91,7 @@ PYBIND11_PLUGIN(core) {
self.mutable_data<int>(place);
})
.def("alloc_int",
[](Tensor &self, paddle::platform::GPUPlace &place) {
[](Tensor &self, paddle::platform::CUDAPlace &place) {
self.mutable_data<int>(place);
})
.def("set", PyCPUTensorSetFromArray<float>)
......@@ -310,10 +310,10 @@ All parameter, weight, gradient are variables in Paddle.
return new paddle::platform::CPUDeviceContext();
})
.def_static("create",
[](paddle::platform::GPUPlace& place)
[](paddle::platform::CUDAPlace& place)
-> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
PADDLE_THROW("GPUPlace is not supported in CPU device.");
PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#else
return new paddle::platform::CUDADeviceContext(place);
#endif
......@@ -323,9 +323,9 @@ All parameter, weight, gradient are variables in Paddle.
#ifdef PADDLE_WITH_CUDA
py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif
py::class_<platform::GPUPlace>(m, "GPUPlace")
py::class_<platform::CUDAPlace>(m, "CUDAPlace")
.def(py::init<int>())
.def("__str__", string::to_string<const platform::GPUPlace &>);
.def("__str__", string::to_string<const platform::CUDAPlace &>);
py::class_<paddle::platform::CPUPlace>(m, "CPUPlace")
.def(py::init<>())
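With this binding renamed, Python code constructs device places under the new name. A minimal sketch (the device id 0 is illustrative; the printed form follows the `Place` printing test above):

```python
import paddle.v2.fluid.core as core

# What used to be core.GPUPlace(0) is now core.CUDAPlace(0).
cpu_place = core.CPUPlace()
gpu_place = core.CUDAPlace(0)  # illustrative device id

# __str__ is bound through string::to_string, so this should print "CUDAPlace(0)".
print(str(gpu_place))
```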
......@@ -338,7 +338,7 @@ All parameter, weight, gradient are variables in Paddle.
self = cpu_place;
})
.def("set_place",
[](platform::Place &self, const platform::GPUPlace &gpu_place) {
[](platform::Place &self, const platform::CUDAPlace &gpu_place) {
self = gpu_place;
});
......@@ -363,7 +363,7 @@ All parameter, weight, gradient are variables in Paddle.
const platform::CPUPlace &place) { self.Run(scope, place); })
.def("run",
[](OperatorBase &self, const Scope &scope,
const platform::GPUPlace &place) { self.Run(scope, place); })
const platform::CUDAPlace &place) { self.Run(scope, place); })
.def("type",
[](const OperatorBase &op) -> std::string { return op.Type(); })
.def("outputs",
......
......@@ -71,7 +71,7 @@ struct CastToPyBufferImpl<true, I, ARGS...> {
dst_ptr, src_ptr, sizeof(CUR_TYPE) * tensor.numel(),
cudaMemcpyDeviceToHost, dev_ctx->stream());
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
PADDLE_THROW("'CUDAPlace' is not supported in CPU only device.");
#endif
} else if (paddle::platform::is_cpu_place(tensor.place())) {
dst_tensor = tensor;
......@@ -127,7 +127,7 @@ template <typename T>
void PyCUDATensorSetFromArray(
framework::Tensor &self,
py::array_t<T, py::array::c_style | py::array::forcecast> array,
paddle::platform::GPUPlace &place) {
paddle::platform::CUDAPlace &place) {
std::vector<int64_t> dims;
dims.reserve(array.ndim());
for (size_t i = 0; i < array.ndim(); ++i) {
......
......@@ -36,7 +36,7 @@ int main(int argc, char** argv) {
paddle::memory::Used(paddle::platform::CPUPlace());
std::vector<std::string> devs = {"CPU"};
#ifdef PADDLE_WITH_CUDA
paddle::memory::Used(paddle::platform::GPUPlace(0));
paddle::memory::Used(paddle::platform::CUDAPlace(0));
devs.push_back("GPU:0");
#endif
paddle::framework::InitDevices(devs);
......
......@@ -15,14 +15,14 @@ import backward
import regularizer
from param_attr import ParamAttr
from data_feeder import DataFeeder
from core import LoDTensor, CPUPlace, GPUPlace
from core import LoDTensor, CPUPlace, CUDAPlace
from distribute_transpiler import DistributeTranspiler
import clip
Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + [
'io', 'initializer', 'layers', 'nets', 'optimizer', 'backward',
'regularizer', 'LoDTensor', 'CPUPlace', 'GPUPlace', 'Tensor', 'ParamAttr'
'regularizer', 'LoDTensor', 'CPUPlace', 'CUDAPlace', 'Tensor', 'ParamAttr',
'DataFeeder', 'clip', 'DistributeTranspiler'
]
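Because `fluid` now exports `CUDAPlace` instead of `GPUPlace`, user scripts select a device roughly as below. This is only a sketch; `is_compile_gpu` comes from `core`, as in the updated tests further down:

```python
import paddle.v2.fluid as fluid
import paddle.v2.fluid.core as core

# Fall back to CPU when the package was built without CUDA support.
place = fluid.CUDAPlace(0) if core.is_compile_gpu() else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
```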
......
......@@ -47,7 +47,7 @@ class Executor(object):
act_places.append(p)
# TODO(dzhwinter) : consider that our fluid tests all written in
# GPUPlace(gpu_id), this will be changed in the future
# CUDAPlace(gpu_id), this will be changed in the future
if core.is_compile_gpu():
core.init_devices(["CPU", "GPU:0"])
else:
......
......@@ -163,8 +163,9 @@ def embedding(input, size, is_sparse=False, param_attr=None, dtype='float32'):
Examples:
.. code-block:: python
dict_size = len(dataset.ids)
data = fluid.layers.data(name='ids', shape=[32, 32], dtype='float32')
fc = fluid.layers.embedding(input=data, size=16)
fc = fluid.layers.embedding(input=data, size=[dict_size, 16])
"""
helper = LayerHelper('embedding', **locals())
......
import paddle.v2.fluid.core as core
from contextlib import contextmanager
import os
__all__ = ['CudaProfiler']
......@@ -30,17 +31,21 @@ def cuda_profiler(output_file, output_mode=None, config=None):
written into this file.
output_mode (string) : The output mode has Key-Value pair format and
Comma separated values format. It should be 'kvp' or 'csv'.
config (string) : The profiler options and counters can refer to
"Compute Command Line Profiler User Guide".
config (list of string) : The profiler options and counters can refer
to "Compute Command Line Profiler User Guide".
"""
if output_mode is None:
output_mode = 'csv'
if output_mode not in ['kvp', 'csv']:
raise ValueError("The output mode must be 'kvp' or 'csv'.")
config = NVPROF_CONFIG if config is None else config
core.nvprof_init(output_file, output_mode, config)
config_file = 'nvprof_config_file'
with open(config_file, 'wb') as fp:
fp.writelines(["%s\n" % item for item in config])
core.nvprof_init(output_file, output_mode, config_file)
# Enables profiler collection by the active CUDA profiling tool.
core.nvprof_start()
yield
# Disables profiler collection.
core.nvprof_stop()
os.remove(config_file)
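After this change `config` is a list of profiler option strings; the helper writes them to a temporary `nvprof_config_file`, hands that path to `core.nvprof_init`, and deletes the file when the context exits. A minimal usage sketch (the output filename and the two counter names are illustrative, and the `with` body stands in for whatever executor calls are being profiled):

```python
import paddle.v2.fluid.profiler as profiler

# Passing config=None would fall back to the built-in NVPROF_CONFIG list.
with profiler.cuda_profiler("cuda_profiler.txt", 'csv',
                            config=["gpustarttimestamp", "streamid"]):
    pass  # placeholder: run the fluid.Executor steps to be profiled here
```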
......@@ -142,7 +142,7 @@ def main():
opts = sgd_optimizer.minimize(cost)
if USE_GPU:
place = core.GPUPlace(0)
place = core.CUDAPlace(0)
else:
place = core.CPUPlace()
......
......@@ -316,7 +316,7 @@ class OpTest(unittest.TestCase):
def check_output(self, atol=1e-5):
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu(self.op_type):
places.append(core.GPUPlace(0))
places.append(core.CUDAPlace(0))
for place in places:
self.check_output_with_place(place, atol)
......@@ -379,7 +379,7 @@ class OpTest(unittest.TestCase):
"Gradient Check On %s" % str(cpu_place))
if core.is_compile_gpu() and self.op.support_gpu():
gpu_place = core.GPUPlace(0)
gpu_place = core.CUDAPlace(0)
gpu_analytic_grads = self._get_gradient(inputs_to_check, gpu_place,
output_names, no_grad_set)
......
......@@ -167,7 +167,7 @@ class TestSparseAdagradOp(unittest.TestCase):
def test_sparse_adagrad(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
......
......@@ -304,7 +304,7 @@ class TestBatchNormOp(OpTest):
self.__assert_close(saved_variance_tensor, saved_variance,
"saved_variance")
self.__assert_close(mean_out_tensor, mean_out, "mean_out")
if isinstance(place, core.GPUPlace):
if isinstance(place, core.CUDAPlace):
atol = 5e-2
else:
atol = 1e-4
......@@ -339,7 +339,7 @@ class TestBatchNormOp(OpTest):
places = [core.CPUPlace()]
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"):
places.append(core.GPUPlace(0))
places.append(core.CUDAPlace(0))
core.init_devices(["CPU", "GPU:0"])
else:
......
......@@ -20,7 +20,7 @@ class TestGaussianRandomOp(unittest.TestCase):
def test_gpu(self):
if core.is_compile_gpu():
self.gaussian_random_test(place=fluid.GPUPlace(0))
self.gaussian_random_test(place=fluid.CUDAPlace(0))
def gaussian_random_test(self, place):
......
......@@ -3,6 +3,7 @@ import numpy as np
import paddle.v2.fluid as fluid
import paddle.v2.fluid.profiler as profiler
import paddle.v2.fluid.layers as layers
import os
class TestProfiler(unittest.TestCase):
......@@ -14,14 +15,16 @@ class TestProfiler(unittest.TestCase):
data = layers.data(name='data', shape=[3, 28, 28], dtype='float32')
conv = layers.conv2d(data, 20, 3, stride=[1, 1], padding=[1, 1])
place = fluid.GPUPlace(0)
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
with profiler.cuda_profiler("cuda_profiler.txt", 'csv') as nvprof:
output_file = 'cuda_profiler.txt'
with profiler.cuda_profiler(output_file, 'csv') as nvprof:
for i in range(epoc):
input = np.random.random(dshape).astype('float32')
exe.run(fluid.default_main_program(), feed={'data': input})
os.remove(output_file)
if __name__ == '__main__':
......
......@@ -78,7 +78,7 @@ class TestSparseSGDOp(unittest.TestCase):
def test_sparse_sgd(self):
places = [core.CPUPlace()]
if core.is_compile_gpu():
places.append(core.GPUPlace(0))
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
......
......@@ -23,7 +23,7 @@ class TestUniformRandomOp(unittest.TestCase):
def test_gpu(self):
if core.is_compile_gpu():
self.uniform_random_test(place=core.GPUPlace(0))
self.uniform_random_test(place=core.CUDAPlace(0))
def uniform_random_test(self, place):
program = fluid.Program()
......