提交 7aac1218 编写于 作者: Y Yi Wang

Fix bugs

上级 a40b755b
...@@ -15,24 +15,28 @@ limitations under the License. */ ...@@ -15,24 +15,28 @@ limitations under the License. */
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
using DEVICE_GPU = Eigen::GpuDevice;
TEST(Device, Init) { TEST(Device, Init) {
using paddle::platform::DeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
int count = paddle::platform::GetDeviceCount(); int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
paddle::platform::DeviceContext* device_context = DeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice* gpu_device = Eigen::GpuDevice* gpu_device =
device_context->template get_eigen_device<DEVICE_GPU>(); device_context->template get_eigen_device<Eigen::GpuDevice>();
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
delete device_context; delete device_context;
} }
} }
TEST(Device, CUDADeviceContext) { TEST(Device, CUDADeviceContext) {
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
int count = paddle::platform::GetDeviceCount(); int count = paddle::platform::GetDeviceCount();
for (int i = 0; i < count; i++) { for (int i = 0; i < count; i++) {
paddle::platform::CUDADeviceContext* device_context = CUDADeviceContext* device_context = new CUDADeviceContext(GPUPlace(i));
new paddle::platform::CUDADeviceContext(i);
Eigen::GpuDevice* gpu_device = device_context->eigen_device(); Eigen::GpuDevice* gpu_device = device_context->eigen_device();
ASSERT_NE(nullptr, gpu_device); ASSERT_NE(nullptr, gpu_device);
cudnnHandle_t cudnn_handle = device_context->cudnn_handle(); cudnnHandle_t cudnn_handle = device_context->cudnn_handle();
......
...@@ -39,8 +39,8 @@ public: ...@@ -39,8 +39,8 @@ public:
// size_ is 0. // size_ is 0.
Piece(); Piece();
Piece(const char* d, size_t n); Piece(const char* d, size_t n);
explicit Piece(const char* d); Piece(const char* d); // NOLINT: accept C string into Piece.
explicit Piece(const std::string& s); Piece(const std::string& s); // NOLINT: accept C++ string into Piece.
const char* data() const { return data_; } const char* data() const { return data_; }
size_t len() const { return size_; } size_t len() const { return size_; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册