/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/platform/device_code.h"

#include <memory>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/init.h"

// CUDA C source for `saxpy_kernel`, compiled at runtime by the tests below.
// The kernel computes z[i] = a * x[i] + y[i] for every i < n, with each
// thread advancing by blockDim.x * gridDim.x per iteration.
// `extern "C"` keeps the symbol name unmangled so it can be looked up as
// "saxpy_kernel".
constexpr const char* saxpy_code = R"(
extern "C" __global__
void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
       tid += blockDim.x * gridDim.x) {
    z[tid] = a * x[tid] + y[tid];
  }
}
)";

#ifdef PADDLE_WITH_CUDA
32 33 34 35 36 37
TEST(DeviceCode, cuda) {
  if (!paddle::platform::dynload::HasNVRTC() ||
      !paddle::platform::dynload::HasCUDADriver()) {
    return;
  }

38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70
  paddle::framework::InitDevices(false, {0});
  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
  paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);

  paddle::framework::Tensor cpu_x;
  paddle::framework::Tensor cpu_y;
  paddle::framework::Tensor cpu_z;

  float scale = 2;
  auto dims = paddle::framework::make_ddim(
      {static_cast<int64_t>(256), static_cast<int64_t>(1024)});
  cpu_x.mutable_data<float>(dims, paddle::platform::CPUPlace());
  cpu_y.mutable_data<float>(dims, paddle::platform::CPUPlace());

  size_t n = cpu_x.numel();
  for (size_t i = 0; i < n; ++i) {
    cpu_x.data<float>()[i] = static_cast<float>(i);
  }
  for (size_t i = 0; i < n; ++i) {
    cpu_y.data<float>()[i] = static_cast<float>(0.5);
  }

  paddle::framework::Tensor x;
  paddle::framework::Tensor y;
  paddle::framework::Tensor z;

  float* x_data = x.mutable_data<float>(dims, place);
  float* y_data = y.mutable_data<float>(dims, place);
  float* z_data = z.mutable_data<float>(dims, place);

  TensorCopySync(cpu_x, place, &x);
  TensorCopySync(cpu_y, place, &y);

71
  EXPECT_EQ(code.Compile(), true);
72 73 74 75 76 77

  std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
  code.SetNumThreads(1024);
  code.SetWorkloadPerThread(1);
  code.Launch(n, &args);

78 79 80
  auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
  dev_ctx->Wait();

81 82
  TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
  for (size_t i = 0; i < n; i++) {
83 84 85 86 87 88 89
    EXPECT_EQ(cpu_z.data<float>()[i], static_cast<float>(i) * scale + 0.5);
  }
}

// Checks DeviceCodePool bookkeeping:
//   - a freshly initialised pool holds no code for the place;
//   - Set() registers exactly one entry;
//   - Get() returns the very object that was stored.
// The kernel is never compiled or launched here; the test only requires
// NVRTC to be available.
TEST(DeviceCodePool, cuda) {
  if (!paddle::platform::dynload::HasNVRTC()) {
    return;
  }

  paddle::framework::InitDevices(false, {0});
  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
  paddle::platform::DeviceCodePool& pool =
      paddle::platform::DeviceCodePool::Init({place});
  size_t num_device_codes_before = pool.size(place);
  EXPECT_EQ(num_device_codes_before, 0UL);

  std::unique_ptr<paddle::platform::DeviceCode> code(
      new paddle::platform::CUDADeviceCode(place, "saxpy_kernel", saxpy_code));
  // Remember the raw address before ownership moves into the pool, so the
  // lookup below can be verified against it (not merely logged).
  paddle::platform::DeviceCode* origin_ptr = code.get();
  LOG(INFO) << "origin ptr: " << origin_ptr;
  pool.Set(std::move(code));
  size_t num_device_codes_after = pool.size(place);
  EXPECT_EQ(num_device_codes_after, 1UL);

  paddle::platform::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
  LOG(INFO) << "get ptr: " << code_get;
  EXPECT_EQ(code_get, origin_ptr);
}
#endif