Unverified commit d4832077, authored by Yiqun Liu, committed by GitHub

Add the first implementation of fusion_group op (#19621)

* Add the dynamic load of nvrtc, and support runtime compiling of CUDA kernels using nvrtc.
test=develop

* Call the CUDA driver API to launch the kernel compiled by nvrtc.
test=develop

* Disable for Mac and Windows.
test=develop

* Refine the codes to support manually specified num_threads and workload_per_thread.
test=develop

* Refine the CUDA kernel to support large dims.
test=develop

* Add DeviceCodePool to manage all device codes (see the usage sketch below).

* Add the first implementation of fusion_group op.

* Add unit-test for fusion_group op.

* Add the check of result.

* Add the check of nvrtc in unit-test.
test=develop

* Add comment to explain the inputs, outputs and features of fusion_group op.
test=develop

* Disable fusion_group op for Mac and Windows.
test=develop

* Make the compiling of device code return status instead of hanging up.
test=develop

* Add a check for the presence of the CUDA driver library, and do not core dump when failing to call the CUDA driver API.

* Unify fusion_group_op's input and output names.
test=develop

* Add the check of CUDA driver library in unittest.
test=develop

* Refine the calling of PADDLE_ENFORCE.
test=develop
Parent 61921084
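Taken together, the platform-side changes in this commit introduce a small JIT pipeline: CUDADeviceCode wraps NVRTC compilation to PTX plus the CUDA driver calls that load and launch the resulting kernel, and DeviceCodePool keeps a per-place map of compiled kernels. Below is a minimal usage sketch adapted from the unit tests added in this commit; the helper function JitCompileAndLaunch and the kernel name "saxpy_kernel" are illustrative, not part of the change.
// Sketch adapted from device_code_test.cc; the helper name and the
// "saxpy_kernel" identifier are illustrative only.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/platform/device_code.h"
namespace plat = paddle::platform;
void JitCompileAndLaunch(const plat::CUDAPlace& place,
                         const std::string& kernel_src, size_t n,
                         std::vector<void*>* args) {
  // The pool is created once for the places that support JIT compilation.
  plat::DeviceCodePool& pool = plat::DeviceCodePool::Init({place});
  std::unique_ptr<plat::DeviceCode> code(
      new plat::CUDADeviceCode(place, "saxpy_kernel", kernel_src));
  // Compile() now returns false instead of aborting when NVRTC or the CUDA
  // driver library is missing, or when the JIT compilation itself fails.
  if (!code->Compile()) {
    return;
  }
  pool.Set(std::move(code));
  // A compiled kernel can later be fetched by name and launched over n
  // elements; args must follow the kernel's own parameter list.
  plat::DeviceCode* dev_code = pool.Get(place, "saxpy_kernel");
  dev_code->Launch(n, args);
}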
......@@ -116,7 +116,9 @@ function(op_library TARGET)
# Define operators that don't need pybind here.
foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
"tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "multihead_matmul_op")
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"multihead_matmul_op" "fusion_group_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -82,7 +82,7 @@ void FusionGroupPass::InsertFusionGroupOp(
input_names.push_back(n->Name());
external_nodes.insert(n);
}
op_desc.SetInput("Xs", input_names);
op_desc.SetInput("Inputs", input_names);
std::vector<std::string> output_names;
for (auto* n : output_vars_of_subgraph) {
......
......@@ -4,7 +4,8 @@ register_operators(EXCLUDES
fusion_transpose_flatten_concat_op
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op)
multihead_matmul_op
fusion_group_op)
if (WITH_GPU)
# conv_fusion_op needs cudnn 7 above
......@@ -26,4 +27,10 @@ if (WITH_GPU)
# multihead_matmul_op
op_library(multihead_matmul_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
# fusion_group
if(NOT APPLE AND NOT WIN32)
op_library(fusion_group_op DEPS device_code)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_group);\n")
cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
endif()
endif()
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_group_op.h"
namespace paddle {
namespace operators {
class FusionGroupOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
const size_t num_ins = ctx->Inputs("Inputs").size();
const size_t num_outs = ctx->Outputs("Outs").size();
PADDLE_ENFORCE_GE(
num_ins, 1UL,
platform::errors::InvalidArgument(
"Expected the number of inputs >= 1. Received %d.", num_ins));
PADDLE_ENFORCE_GE(
num_outs, 1UL,
platform::errors::InvalidArgument(
"Expected the number of outputs >= 1. Recived %d.", num_outs));
int type = ctx->Attrs().Get<int>("type");
PADDLE_ENFORCE_EQ(type, 0UL,
platform::errors::InvalidArgument(
"Only support fusion of elementwise operations."));
std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
if (type == 0) {
for (size_t i = 1; i < num_ins; ++i) {
PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i],
platform::errors::InvalidArgument(
"All the inputs' dims should be the same."));
}
std::vector<framework::DDim> out_dims;
for (size_t j = 0; j < num_outs; ++j) {
out_dims.push_back(x_dims[0]);
}
ctx->SetOutputsDim("Outs", out_dims);
}
// Only lod of Inputs[0] would be shared with Outs.
for (size_t j = 0; j < num_outs; ++j) {
ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
}
}
};
class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Inputs",
"(std::vector<LoDTensor>) The inputs of fusion_group op.")
.AsDuplicable();
AddOutput("Outs",
"(std::vector<LoDTensor>) The outputs of fusion_group op.")
.AsDuplicable();
AddAttr<int>("type", "Fusion type.").SetDefault(0);
AddAttr<std::string>("func_name", "Name of the generated functions.")
.SetDefault("");
AddComment(R"DOC(
fusion_group Operator.
It is used to execute a generated CUDA kernel which fuses the computation of
multiple operators into one. It supports several types:
0, fused computation of elementwise operations in which all the dims of inputs
and outputs should be exactly the same.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker);
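The operator does not generate CUDA code itself: it looks up a previously compiled kernel by func_name in DeviceCodePool and launches it. For type 0 (elementwise fusion), judging from the kernel wrapper and the unit test in this commit, the generated kernel takes the element count followed by the input and output pointers and iterates with a grid-stride loop; the kernel below, computing z = relu(x + y), is only an illustrative instance of that pattern, not code emitted by this PR.
// Illustrative type-0 kernel: parameter order is (n, inputs..., outputs...),
// matching the argument vector built by FusionGroupKernel::Compute.
extern "C" __global__ void elementwise_cuda_kernel_0(size_t n, float* x,
                                                     float* y, float* z) {
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
       tid += blockDim.x * gridDim.x) {
    float tmp = x[tid] + y[tid];
    z[tid] = tmp > 0 ? tmp : 0;  // z = relu(x + y)
  }
}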
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_group_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
fusion_group,
ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, double>,
ops::FusionGroupKernel<paddle::platform::CUDADeviceContext, float>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class FusionGroupKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
int type = ctx.Attr<int>("type");
size_t num_ins = ins.size();
size_t num_outs = outs.size();
auto place = ctx.GetPlace();
for (size_t i = 0; i < num_outs; ++i) {
outs[i]->mutable_data<T>(place);
}
std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code =
platform::DeviceCodePool::Instance().Get(place, func_name);
VLOG(3) << "func_name: " << func_name;
if (type == 0) {
size_t n = ins[0]->numel();
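      // Pack the launch arguments as {&n, in_0, ..., in_{m-1}, out_0, ...},
      // matching the parameter order of the generated elementwise kernel.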
std::vector<void*> args;
args.push_back(&n);
std::vector<const T*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
ptrs[i] = ins[i]->data<T>();
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
ptrs[num_ins + j] = outs[j]->data<T>();
args.push_back(&ptrs[num_ins + j]);
}
dev_code->Launch(n, &args);
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"
namespace paddle {
namespace operators {
using CPUKernelFunc = std::function<void(size_t n, std::vector<void*> args)>;
template <typename T>
framework::Tensor* CreateTensor(framework::Scope* scope,
const platform::Place& place,
const std::string& name,
const std::vector<int64_t>& shape) {
auto* var = scope->Var(name);
auto* tensor = var->GetMutable<framework::LoDTensor>();
if (shape.size() > 0) {
tensor->mutable_data<T>(framework::make_ddim(shape), place);
}
return tensor;
}
template <typename T>
void SetupRandomCPUTensor(framework::Tensor* tensor,
const std::vector<int64_t>& shape) {
static unsigned int seed = 100;
std::mt19937 rng(seed++);
std::uniform_real_distribution<double> uniform_dist(0, 1);
T* ptr = tensor->mutable_data<T>(framework::make_ddim(shape),
platform::CPUPlace());
for (int64_t i = 0; i < tensor->numel(); ++i) {
ptr[i] = static_cast<T>(uniform_dist(rng)) - static_cast<T>(0.5);
}
}
framework::OpDesc* CreateFusionGroupOp(
framework::ProgramDesc* program,
const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type,
std::string func_name) {
EXPECT_EQ(input_names.size(), input_shapes.size());
for (size_t i = 0; i < input_names.size(); ++i) {
auto* var = program->MutableBlock(0)->Var(input_names[i]);
var->SetType(framework::proto::VarType::LOD_TENSOR);
var->SetDataType(framework::proto::VarType::FP32);
var->SetShape(input_shapes[i]);
}
for (size_t j = 0; j < output_names.size(); ++j) {
auto* var = program->MutableBlock(0)->Var(output_names[j]);
var->SetType(framework::proto::VarType::LOD_TENSOR);
var->SetDataType(framework::proto::VarType::FP32);
}
auto* op = program->MutableBlock(0)->AppendOp();
op->SetType("fusion_group");
op->SetInput("Inputs", input_names);
op->SetOutput("Outs", output_names);
op->SetAttr("type", type);
op->SetAttr("func_name", func_name);
op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
static_cast<int>(framework::OpRole::kForward));
return op;
}
void PrepareDeviceCode(platform::Place place, std::string func_name,
std::string cuda_kernel_str) {
paddle::platform::DeviceCodePool& pool =
paddle::platform::DeviceCodePool::Init({place});
std::unique_ptr<paddle::platform::DeviceCode> code(
new paddle::platform::CUDADeviceCode(place, func_name, cuda_kernel_str));
code->Compile();
pool.Set(std::move(code));
}
void CheckOutputs(framework::Scope* scope,
const std::vector<std::string>& output_names,
std::vector<framework::Tensor>* cpu_tensors,
size_t num_inputs, CPUKernelFunc cpu_kernel_func) {
std::vector<framework::Tensor> cpu_outputs;
cpu_outputs.resize(output_names.size());
for (size_t j = 0; j < output_names.size(); ++j) {
auto* var = scope->Var(output_names[j]);
const auto& dev_tensor = var->Get<framework::LoDTensor>();
TensorCopySync(dev_tensor, platform::CPUPlace(), &(cpu_outputs[j]));
cpu_tensors->at(num_inputs + j)
.mutable_data<float>(dev_tensor.dims(), platform::CPUPlace());
}
size_t n = cpu_tensors->at(0).numel();
std::vector<void*> args;
for (size_t i = 0; i < cpu_tensors->size(); ++i) {
args.push_back(cpu_tensors->at(i).data<float>());
}
cpu_kernel_func(n, args);
for (size_t j = 0; j < output_names.size(); ++j) {
auto* dev_ptr = cpu_outputs[j].data<float>();
auto* cpu_ptr = cpu_tensors->at(num_inputs + j).data<float>();
int64_t length = cpu_outputs[j].numel();
LOG(INFO) << "Check the " << j << "th output...";
for (int64_t i = 0; i < length; ++i) {
EXPECT_NEAR(dev_ptr[i], cpu_ptr[i], 1.E-05);
}
}
}
void TestMain(const std::vector<std::string>& input_names,
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::string>& output_names, int type,
std::string func_name, std::string cuda_kernel_str,
CPUKernelFunc cpu_kernel_func) {
// Compile the device code
paddle::framework::InitDevices(false, {0});
platform::CUDAPlace place = platform::CUDAPlace(0);
PrepareDeviceCode(place, func_name, cuda_kernel_str);
// Create a ProgramDesc that has a fusion_group_op.
framework::ProgramDesc program;
framework::OpDesc* op_desc = CreateFusionGroupOp(
&program, input_names, input_shapes, output_names, type, func_name);
auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
framework::Scope scope;
// Prepare input tensors.
std::vector<framework::Tensor> cpu_tensors;
cpu_tensors.resize(input_names.size() + output_names.size());
for (size_t i = 0; i < input_names.size(); ++i) {
SetupRandomCPUTensor<float>(&(cpu_tensors[i]), input_shapes[i]);
framework::Tensor* dev_tensor =
CreateTensor<float>(&scope, place, input_names[i], input_shapes[i]);
TensorCopySync(cpu_tensors[i], place, dev_tensor);
}
// Create output tensors.
std::vector<int64_t> empty_shape;
for (size_t j = 0; j < output_names.size(); ++j) {
CreateTensor<float>(&scope, place, output_names[j], empty_shape);
}
fusion_group_op->Run(scope, place);
auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
dev_ctx->Wait();
// Check the output.
CheckOutputs(&scope, output_names, &cpu_tensors, input_names.size(),
cpu_kernel_func);
}
TEST(FusionGroupOp, elementwise) {
if (!platform::dynload::HasNVRTC() || !platform::dynload::HasCUDADriver()) {
return;
}
// z = relu(x + y)
std::vector<std::string> input_names = {"x", "y"};
std::vector<std::string> output_names = {"z"};
std::vector<std::vector<int64_t>> input_shapes = {{256, 256}, {256, 256}};
constexpr auto kernel = R"(
static inline __device__ float relu(float x) {
return x * (x > 0);
}
extern "C" __global__
void elementwise_cuda_kernel_0(size_t n, float *x, float* y, float* z) {
for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
tid += blockDim.x * gridDim.x) {
float tmp_0 = x[tid];
float tmp_1 = y[tid];
float tmp_2 = tmp_0 + tmp_1;
float tmp_3 = relu(tmp_2);
z[tid] = tmp_3;
}
})";
auto elementwise_cpu_kernel_0 = [](size_t n,
std::vector<void*> args) -> void {
float* x = static_cast<float*>(args[0]);
float* y = static_cast<float*>(args[1]);
float* z = static_cast<float*>(args[2]);
for (size_t i = 0; i < n; ++i) {
float tmp_0 = x[i];
float tmp_1 = y[i];
float tmp_2 = tmp_0 + tmp_1;
float tmp_3 = tmp_2 > 0 ? tmp_2 : 0;
z[i] = tmp_3;
}
};
TestMain(input_names, input_shapes, output_names, 0,
"elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
}
} // namespace operators
} // namespace paddle
USE_CUDA_ONLY_OP(fusion_group);
......@@ -14,26 +14,76 @@ limitations under the License. */
#include "paddle/fluid/platform/device_code.h"
#include <algorithm>
#include <set>
#include <utility>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
#ifdef PADDLE_WITH_CUDA
inline bool is_error(nvrtcResult stat) { return stat != NVRTC_SUCCESS; }
DeviceCodePool* DeviceCodePool::pool = nullptr;
void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
Place place = code->GetPlace();
std::string name = code->GetName();
auto iter = device_codes_.find(place);
if (iter == device_codes_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Place %s is not supported for runtime compiling.", place));
}
auto& codes_map = iter->second;
codes_map.emplace(name, std::move(code));
}
platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
const std::string& name) {
auto iter = device_codes_.find(place);
if (iter == device_codes_.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Place %s is not supported for runtime compiling.", place));
}
auto& codes_map = iter->second;
auto code_iter = codes_map.find(name);
if (code_iter == codes_map.end()) {
PADDLE_THROW(platform::errors::NotFound(
"Device code named %s for place %s does not exist.", name.c_str(),
place));
}
return code_iter->second.get();
}
inline void throw_on_error(nvrtcResult stat, const std::string& msg) {
#ifndef REPLACE_ENFORCE_GLOG
throw std::runtime_error(dynload::nvrtcGetErrorString(stat) + msg);
DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
PADDLE_ENFORCE_GT(
places.size(), 0,
errors::InvalidArgument(
"Expected the number of places >= 1. Expected %d.", places.size()));
// Remove the duplicated places
std::set<Place> set;
for (auto& p : places) {
set.insert(p);
}
for (auto& p : set) {
if (is_gpu_place(p)) {
#ifdef PADDLE_WITH_CUDA
device_codes_.emplace(p, DeviceCodeMap());
#else
LOG(FATAL) << dynload::nvrtcGetErrorString(stat) << msg;
PADDLE_THROW(platform::errors::PreconditionNotMet(
"CUDAPlace is not supported, please re-compile with WITH_GPU=ON."));
#endif
}
}
}
#ifdef PADDLE_WITH_CUDA
CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
const std::string& kernel) {
if (!is_gpu_place(place)) {
PADDLE_THROW("CUDADeviceCode can only launch on GPU place.");
PADDLE_THROW(platform::errors::PermissionDenied(
"CUDADeviceCode can only launch on GPU place."));
}
place_ = place;
......@@ -41,16 +91,24 @@ CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
kernel_ = kernel;
}
void CUDADeviceCode::Compile() {
bool CUDADeviceCode::Compile() {
is_compiled_ = false;
if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
LOG(WARNING)
<< "NVRTC and CUDA driver are need for JIT compiling of CUDA code.";
return false;
}
nvrtcProgram program;
PADDLE_ENFORCE_EQ(dynload::nvrtcCreateProgram(&program,
kernel_.c_str(), // buffer
name_.c_str(), // name
0, // numHeaders
nullptr, // headers
nullptr), // includeNames
NVRTC_SUCCESS,
"nvrtcCreateProgram failed.");
if (!CheckNVRTCResult(dynload::nvrtcCreateProgram(&program,
kernel_.c_str(), // buffer
name_.c_str(), // name
0, // numHeaders
nullptr, // headers
nullptr), // includeNames
"nvrtcCreateProgram")) {
return false;
}
// Compile the program for specified compute_capability
auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
......@@ -67,38 +125,62 @@ void CUDADeviceCode::Compile() {
if (compile_result == NVRTC_ERROR_COMPILATION) {
// Obtain compilation log from the program
size_t log_size;
PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLogSize(program, &log_size),
NVRTC_SUCCESS, "nvrtcGetProgramLogSize failed.");
if (!CheckNVRTCResult(dynload::nvrtcGetProgramLogSize(program, &log_size),
"nvrtcGetProgramLogSize")) {
return false;
}
std::vector<char> log;
log.resize(log_size + 1);
PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLog(program, log.data()),
NVRTC_SUCCESS, "nvrtcGetProgramLog failed.");
LOG(FATAL) << "JIT compiling of CUDA code failed:\n" << log.data();
if (!CheckNVRTCResult(dynload::nvrtcGetProgramLog(program, log.data()),
"nvrtcGetProgramLog")) {
return false;
}
LOG(WARNING) << "JIT compiling of CUDA code failed:"
<< "\n Kernel name: " << name_ << "\n Kernel body:\n"
<< kernel_ << "\n Compiling log: " << log.data();
return false;
}
// Obtain PTX from the program
size_t ptx_size;
PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTXSize(program, &ptx_size), NVRTC_SUCCESS,
"nvrtcGetPTXSize failed.");
if (!CheckNVRTCResult(dynload::nvrtcGetPTXSize(program, &ptx_size),
"nvrtcGetPTXSize")) {
return false;
}
ptx_.resize(ptx_size + 1);
PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTX(program, ptx_.data()), NVRTC_SUCCESS,
"nvrtcGetPTX failed.");
if (!CheckNVRTCResult(dynload::nvrtcGetPTX(program, ptx_.data()),
"nvrtcGetPTX")) {
return false;
}
PADDLE_ENFORCE_EQ(dynload::nvrtcDestroyProgram(&program), NVRTC_SUCCESS,
"nvrtcDestroyProgram failed.");
if (!CheckNVRTCResult(dynload::nvrtcDestroyProgram(&program),
"nvrtcDestroyProgram")) {
return false;
}
PADDLE_ENFORCE_EQ(
dynload::cuModuleLoadData(&module_, ptx_.data()), CUDA_SUCCESS,
"Fail to load PTX of %s (in cuModuleLoadData.)", name_.c_str());
PADDLE_ENFORCE_EQ(
dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
CUDA_SUCCESS, "Fail to get function of %s (in cuModuleGetFunction.)",
name_.c_str());
if (!CheckCUDADriverResult(dynload::cuModuleLoadData(&module_, ptx_.data()),
"cuModuleLoadData")) {
return false;
}
if (!CheckCUDADriverResult(
dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
"cuModuleGetFunction")) {
return false;
}
max_threads_ = dev_ctx->GetMaxPhysicalThreadCount();
is_compiled_ = true;
return true;
}
void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
PADDLE_ENFORCE_EQ(
is_compiled_, true,
errors::PreconditionNotMet(
"Please compile the code before launching the kernel."));
int max_blocks = std::max(max_threads_ / num_threads_, 1);
int workload_per_block = workload_per_thread_ * num_threads_;
int num_blocks =
......@@ -114,8 +196,30 @@ void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
dev_ctx->stream(), // stream
args->data(), // arguments
nullptr),
CUDA_SUCCESS, "Fail to launch kernel %s (in cuLaunchKernel.)",
name_.c_str());
CUDA_SUCCESS,
errors::External("Fail to launch kernel %s (in cuLaunchKernel.)",
name_.c_str()));
}
bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
std::string function) {
if (result != NVRTC_SUCCESS) {
LOG(WARNING) << "Call " << function
<< " failed: " << dynload::nvrtcGetErrorString(result);
return false;
}
return true;
}
bool CUDADeviceCode::CheckCUDADriverResult(CUresult result,
std::string function) {
if (result != CUDA_SUCCESS) {
    const char* error = nullptr;
    dynload::cuGetErrorString(result, &error);
    LOG(WARNING) << "Call " << function << " failed: " << error;
return false;
}
return true;
}
#endif
......
......@@ -14,7 +14,10 @@ limitations under the License. */
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_CUDA
......@@ -28,9 +31,12 @@ namespace platform {
class DeviceCode {
public:
virtual ~DeviceCode() {}
virtual void Compile() = 0;
virtual bool Compile() = 0;
virtual void Launch(const size_t n, std::vector<void*>* args) const = 0;
Place GetPlace() const { return place_; }
std::string GetName() const { return name_; }
protected:
Place place_;
std::string name_;
......@@ -42,7 +48,7 @@ class CUDADeviceCode : public DeviceCode {
public:
explicit CUDADeviceCode(const Place& place, const std::string& name,
const std::string& kernel);
void Compile() override;
bool Compile() override;
void Launch(const size_t n, std::vector<void*>* args) const override;
void SetNumThreads(int num_threads) { num_threads_ = num_threads; }
......@@ -51,6 +57,10 @@ class CUDADeviceCode : public DeviceCode {
}
private:
bool CheckNVRTCResult(nvrtcResult result, std::string function);
bool CheckCUDADriverResult(CUresult result, std::string function);
bool is_compiled_{false};
int max_threads_{0};
int num_threads_{1024};
int workload_per_thread_{1};
......@@ -60,5 +70,46 @@ class CUDADeviceCode : public DeviceCode {
};
#endif
class DeviceCodePool {
public:
using DeviceCodeMap =
std::unordered_map<std::string, std::unique_ptr<DeviceCode>>;
explicit DeviceCodePool(const std::vector<platform::Place>& places);
static DeviceCodePool& Instance() {
PADDLE_ENFORCE_NOT_NULL(
pool,
errors::NotFound("Need to create DeviceCodePool first, by calling "
"DeviceCodePool::Init(places)!"));
return *pool;
}
static DeviceCodePool& Init(const std::vector<platform::Place>& places) {
if (pool == nullptr) {
pool = new DeviceCodePool(places);
}
return *pool;
}
void Set(std::unique_ptr<DeviceCode>&& code);
platform::DeviceCode* Get(const platform::Place& place,
const std::string& name);
size_t size(const platform::Place& place) const {
auto iter = device_codes_.find(place);
if (iter == device_codes_.end()) {
return 0;
}
return iter->second.size();
}
private:
static DeviceCodePool* pool;
std::map<Place, DeviceCodeMap> device_codes_;
DISABLE_COPY_AND_ASSIGN(DeviceCodePool);
};
} // namespace platform
} // namespace paddle
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_code.h"
#include <utility>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/init.h"
......@@ -28,7 +29,12 @@ void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
)";
#ifdef PADDLE_WITH_CUDA
TEST(device_code, cuda) {
TEST(DeviceCode, cuda) {
if (!paddle::platform::dynload::HasNVRTC() ||
!paddle::platform::dynload::HasCUDADriver()) {
return;
}
paddle::framework::InitDevices(false, {0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
......@@ -62,17 +68,42 @@ TEST(device_code, cuda) {
TensorCopySync(cpu_x, place, &x);
TensorCopySync(cpu_y, place, &y);
code.Compile();
EXPECT_EQ(code.Compile(), true);
std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
code.SetNumThreads(1024);
code.SetWorkloadPerThread(1);
code.Launch(n, &args);
auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
dev_ctx->Wait();
TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
for (size_t i = 0; i < n; i++) {
PADDLE_ENFORCE_EQ(cpu_z.data<float>()[i],
static_cast<float>(i) * scale + 0.5);
EXPECT_EQ(cpu_z.data<float>()[i], static_cast<float>(i) * scale + 0.5);
}
}
TEST(DeviceCodePool, cuda) {
if (!paddle::platform::dynload::HasNVRTC()) {
return;
}
paddle::framework::InitDevices(false, {0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::DeviceCodePool& pool =
paddle::platform::DeviceCodePool::Init({place});
size_t num_device_codes_before = pool.size(place);
EXPECT_EQ(num_device_codes_before, 0UL);
std::unique_ptr<paddle::platform::DeviceCode> code(
new paddle::platform::CUDADeviceCode(place, "saxpy_kernel", saxpy_code));
LOG(INFO) << "origin ptr: " << code.get();
pool.Set(std::move(code));
size_t num_device_codes_after = pool.size(place);
EXPECT_EQ(num_device_codes_after, 1UL);
paddle::platform::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
LOG(INFO) << "get ptr: " << code_get;
}
#endif
......@@ -25,6 +25,15 @@ void* cuda_dso_handle = nullptr;
CUDA_ROUTINE_EACH(DEFINE_WRAP);
#ifdef PADDLE_USE_DSO
bool HasCUDADriver() {
std::call_once(cuda_dso_flag, []() { cuda_dso_handle = GetCUDADsoHandle(); });
return cuda_dso_handle != nullptr;
}
#else
bool HasCUDADriver() { return false; }
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -25,6 +25,7 @@ namespace dynload {
extern std::once_flag cuda_dso_flag;
extern void* cuda_dso_handle;
extern bool HasCUDADriver();
#ifdef PADDLE_USE_DSO
......
......@@ -25,6 +25,16 @@ void* nvrtc_dso_handle = nullptr;
NVRTC_ROUTINE_EACH(DEFINE_WRAP);
#ifdef PADDLE_USE_DSO
bool HasNVRTC() {
std::call_once(nvrtc_dso_flag,
[]() { nvrtc_dso_handle = GetNVRTCDsoHandle(); });
return nvrtc_dso_handle != nullptr;
}
#else
bool HasNVRTC() { return false; }
#endif
} // namespace dynload
} // namespace platform
} // namespace paddle
......@@ -25,6 +25,7 @@ namespace dynload {
extern std::once_flag nvrtc_dso_flag;
extern void* nvrtc_dso_handle;
extern bool HasNVRTC();
#ifdef PADDLE_USE_DSO
......