diff --git a/cmake/operators.cmake b/cmake/operators.cmake
index f77240e3177f4d186d33fd62cd1ac5c1d39c758b..afad22ca2b4eef0480ba2c454dc4875673a8860d 100644
--- a/cmake/operators.cmake
+++ b/cmake/operators.cmake
@@ -116,7 +116,9 @@ function(op_library TARGET)
   # Define operators that don't need pybind here.
   foreach(manual_pybind_op "compare_op" "logical_op" "nccl_op"
 "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op"
-"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" "multihead_matmul_op")
+"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
+"sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
+"multihead_matmul_op" "fusion_group_op")
     if ("${TARGET}" STREQUAL "${manual_pybind_op}")
       set(pybind_flag 1)
     endif()
diff --git a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
index 1f06e46c4744c3f8cf545b0baca8b228beb07f04..4999acbf7daf2999ffbd00283889efbf00351c6d 100644
--- a/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
+++ b/paddle/fluid/framework/ir/fusion_group/fusion_group_pass.cc
@@ -82,7 +82,7 @@ void FusionGroupPass::InsertFusionGroupOp(
     input_names.push_back(n->Name());
     external_nodes.insert(n);
   }
-  op_desc.SetInput("Xs", input_names);
+  op_desc.SetInput("Inputs", input_names);
 
   std::vector<std::string> output_names;
   for (auto* n : output_vars_of_subgraph) {
diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 42529f02920f91f55ef846108710292af93e1e01..db3ff0883fd37409d083c0086de71b28ce30ee80 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -4,7 +4,8 @@ register_operators(EXCLUDES
     fusion_transpose_flatten_concat_op
     fusion_conv_inception_op
     fused_fc_elementwise_layernorm_op
-    multihead_matmul_op)
+    multihead_matmul_op
+    fusion_group_op)
 
 if (WITH_GPU)
     # conv_fusion_op needs cudnn 7 above
@@ -26,4 +27,10 @@ if (WITH_GPU)
     # multihead_matmul_op
     op_library(multihead_matmul_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(multihead_matmul);\n")
+    # fusion_group
+    if(NOT APPLE AND NOT WIN32)
+        op_library(fusion_group_op DEPS device_code)
+        file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fusion_group);\n")
+        cc_test(test_fusion_group_op SRCS fusion_group_op_test.cc DEPS fusion_group_op)
+    endif()
 endif()
diff --git a/paddle/fluid/operators/fused/fusion_group_op.cc b/paddle/fluid/operators/fused/fusion_group_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5880c3b317e6d5ed2a5b5cef80186928755e746a
--- /dev/null
+++ b/paddle/fluid/operators/fused/fusion_group_op.cc
@@ -0,0 +1,90 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fusion_group_op.h"
+
+namespace paddle {
+namespace operators {
+
+class FusionGroupOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    const size_t num_ins = ctx->Inputs("Inputs").size();
+    const size_t num_outs = ctx->Outputs("Outs").size();
+
+    PADDLE_ENFORCE_GE(
+        num_ins, 1UL,
+        platform::errors::InvalidArgument(
+            "Expected the number of inputs >= 1. Received %d.", num_ins));
+    PADDLE_ENFORCE_GE(
+        num_outs, 1UL,
+        platform::errors::InvalidArgument(
+            "Expected the number of outputs >= 1. Received %d.", num_outs));
+
+    int type = ctx->Attrs().Get<int>("type");
+    PADDLE_ENFORCE_EQ(type, 0UL,
+                      platform::errors::InvalidArgument(
+                          "Only fusion of elementwise operations is "
+                          "supported."));
+
+    std::vector<framework::DDim> x_dims = ctx->GetInputsDim("Inputs");
+    if (type == 0) {
+      for (size_t i = 1; i < num_ins; ++i) {
+        PADDLE_ENFORCE_EQ(x_dims[0], x_dims[i],
+                          platform::errors::InvalidArgument(
+                              "All the inputs' dims should be the same."));
+      }
+      std::vector<framework::DDim> out_dims;
+      for (size_t j = 0; j < num_outs; ++j) {
+        out_dims.push_back(x_dims[0]);
+      }
+      ctx->SetOutputsDim("Outs", out_dims);
+    }
+
+    // Only the LoD of Inputs[0] is shared with Outs.
+    for (size_t j = 0; j < num_outs; ++j) {
+      ctx->ShareLoD("Inputs", /*->*/ "Outs", 0, j);
+    }
+  }
+};
+
+class FusionGroupOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("Inputs",
+             "(std::vector<LoDTensor>) The inputs of fusion_group op.")
+        .AsDuplicable();
+    AddOutput("Outs",
+              "(std::vector<LoDTensor>) The outputs of fusion_group op.")
+        .AsDuplicable();
+    AddAttr<int>("type", "Fusion type.").SetDefault(0);
+    AddAttr<std::string>("func_name", "Name of the generated function.")
+        .SetDefault("");
+    AddComment(R"DOC(
+fusion_group Operator.
+
+It is used to execute a generated CUDA kernel which fuses the computation of
+multiple operators into one. It supports several types:
+0, fused computation of elementwise operations in which all the dims of inputs
+   and outputs should be exactly the same.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(fusion_group, ops::FusionGroupOp, ops::FusionGroupOpMaker);
diff --git a/paddle/fluid/operators/fused/fusion_group_op.cu.cc b/paddle/fluid/operators/fused/fusion_group_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..63c243beafb9cdfd20f6d8343a1e2bd9b716c66e
--- /dev/null
+++ b/paddle/fluid/operators/fused/fusion_group_op.cu.cc
@@ -0,0 +1,22 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/fused/fusion_group_op.h"
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(
+    fusion_group,
+    ops::FusionGroupKernel<plat::CUDADeviceContext, float>,
+    ops::FusionGroupKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/fused/fusion_group_op.h b/paddle/fluid/operators/fused/fusion_group_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..cc8af48792f27c425db4a097283d5af5535d3730
--- /dev/null
+++ b/paddle/fluid/operators/fused/fusion_group_op.h
@@ -0,0 +1,65 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/device_code.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class FusionGroupKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto ins = ctx.MultiInput<framework::LoDTensor>("Inputs");
+    auto outs = ctx.MultiOutput<framework::LoDTensor>("Outs");
+    int type = ctx.Attr<int>("type");
+
+    size_t num_ins = ins.size();
+    size_t num_outs = outs.size();
+
+    auto place = ctx.GetPlace();
+    for (size_t i = 0; i < num_outs; ++i) {
+      outs[i]->mutable_data<T>(place);
+    }
+
+    std::string func_name = ctx.Attr<std::string>("func_name");
+    platform::DeviceCode* dev_code =
+        platform::DeviceCodePool::Instance().Get(place, func_name);
+    VLOG(3) << "func_name: " << func_name;
+
+    if (type == 0) {
+      size_t n = ins[0]->numel();
+      std::vector<void*> args;
+      args.push_back(&n);
+      std::vector<const T*> ptrs(num_ins + num_outs);
+      for (size_t i = 0; i < num_ins; ++i) {
+        ptrs[i] = ins[i]->data<T>();
+        args.push_back(&ptrs[i]);
+      }
+      for (size_t j = 0; j < num_outs; ++j) {
+        ptrs[num_ins + j] = outs[j]->data<T>();
+        args.push_back(&ptrs[num_ins + j]);
+      }
+      dev_code->Launch(n, &args);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
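Note on the launch contract: the argument packing in `FusionGroupKernel::Compute` above fixes the signature of the JIT-compiled kernel. `args` holds `&n` followed by the address of each input and output pointer, and `cuLaunchKernel` binds them positionally. A minimal sketch of a matching type-0 kernel follows; the kernel name and the fused expression are illustrative, not part of the patch:

```cuda
// Illustrative only: the shape of a kernel compiled through CUDADeviceCode
// that matches the packing args = { &n, &in0, &in1, &out0 } above.
extern "C" __global__ void fused_elementwise_example(
    size_t n, const float* in0, const float* in1, float* out0) {
  // Grid-stride loop, matching the launch configuration computed in
  // CUDADeviceCode::Launch (num_threads_ threads per block).
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
       tid += blockDim.x * gridDim.x) {
    out0[tid] = in0[tid] + in1[tid];  // any fused elementwise expression
  }
}
```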
diff --git a/paddle/fluid/operators/fused/fusion_group_op_test.cc b/paddle/fluid/operators/fused/fusion_group_op_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..81acb0791c61c6b9861abd46914fb4d5fa52de3c
--- /dev/null
+++ b/paddle/fluid/operators/fused/fusion_group_op_test.cc
@@ -0,0 +1,220 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <functional>
+#include <memory>
+#include <random>
+#include <string>
+#include <vector>
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_desc.h"
+#include "paddle/fluid/framework/op_proto_maker.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/platform/device_code.h"
+#include "paddle/fluid/platform/init.h"
+
+namespace paddle {
+namespace operators {
+
+using CPUKernelFunc = std::function<void(size_t n, std::vector<void*> args)>;
+
+template <typename T>
+framework::Tensor* CreateTensor(framework::Scope* scope,
+                                const platform::Place& place,
+                                const std::string& name,
+                                const std::vector<int64_t>& shape) {
+  auto* var = scope->Var(name);
+  auto* tensor = var->GetMutable<framework::LoDTensor>();
+  if (shape.size() > 0) {
+    tensor->mutable_data<T>(framework::make_ddim(shape), place);
+  }
+  return tensor;
+}
+
+template <typename T>
+void SetupRandomCPUTensor(framework::Tensor* tensor,
+                          const std::vector<int64_t>& shape) {
+  static unsigned int seed = 100;
+  std::mt19937 rng(seed++);
+  std::uniform_real_distribution<double> uniform_dist(0, 1);
+
+  T* ptr = tensor->mutable_data<T>(framework::make_ddim(shape),
+                                   platform::CPUPlace());
+  for (int64_t i = 0; i < tensor->numel(); ++i) {
+    ptr[i] = static_cast<T>(uniform_dist(rng)) - static_cast<T>(0.5);
+  }
+}
+
+framework::OpDesc* CreateFusionGroupOp(
+    framework::ProgramDesc* program,
+    const std::vector<std::string>& input_names,
+    const std::vector<std::vector<int64_t>>& input_shapes,
+    const std::vector<std::string>& output_names, int type,
+    std::string func_name) {
+  EXPECT_EQ(input_names.size(), input_shapes.size());
+
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto* var = program->MutableBlock(0)->Var(input_names[i]);
+    var->SetType(framework::proto::VarType::LOD_TENSOR);
+    var->SetDataType(framework::proto::VarType::FP32);
+    var->SetShape(input_shapes[i]);
+  }
+  for (size_t j = 0; j < output_names.size(); ++j) {
+    auto* var = program->MutableBlock(0)->Var(output_names[j]);
+    var->SetType(framework::proto::VarType::LOD_TENSOR);
+    var->SetDataType(framework::proto::VarType::FP32);
+  }
+
+  auto* op = program->MutableBlock(0)->AppendOp();
+  op->SetType("fusion_group");
+  op->SetInput("Inputs", input_names);
+  op->SetOutput("Outs", output_names);
+  op->SetAttr("type", type);
+  op->SetAttr("func_name", func_name);
+  op->SetAttr(framework::OpProtoAndCheckerMaker::OpRoleAttrName(),
+              static_cast<int>(framework::OpRole::kForward));
+  return op;
+}
+
+void PrepareDeviceCode(platform::Place place, std::string func_name,
+                       std::string cuda_kernel_str) {
+  paddle::platform::DeviceCodePool& pool =
+      paddle::platform::DeviceCodePool::Init({place});
+
+  std::unique_ptr<paddle::platform::CUDADeviceCode> code(
+      new paddle::platform::CUDADeviceCode(place, func_name, cuda_kernel_str));
+  code->Compile();
+  pool.Set(std::move(code));
+}
+
+void CheckOutputs(framework::Scope* scope,
+                  const std::vector<std::string>& output_names,
+                  std::vector<framework::Tensor>* cpu_tensors,
+                  size_t num_inputs, CPUKernelFunc cpu_kernel_func) {
+  std::vector<framework::Tensor> cpu_outputs;
+  cpu_outputs.resize(output_names.size());
+  for (size_t j = 0; j < output_names.size(); ++j) {
+    auto* var = scope->Var(output_names[j]);
+    const auto& dev_tensor = var->Get<framework::LoDTensor>();
+    TensorCopySync(dev_tensor, platform::CPUPlace(), &(cpu_outputs[j]));
+
+    cpu_tensors->at(num_inputs + j)
+        .mutable_data<float>(dev_tensor.dims(), platform::CPUPlace());
+  }
+
+  size_t n = cpu_tensors->at(0).numel();
+  std::vector<void*> args;
+  for (size_t i = 0; i < cpu_tensors->size(); ++i) {
+    args.push_back(cpu_tensors->at(i).data<float>());
+  }
+  cpu_kernel_func(n, args);
+
+  for (size_t j = 0; j < output_names.size(); ++j) {
+    auto* dev_ptr = cpu_outputs[j].data<float>();
+    auto* cpu_ptr = cpu_tensors->at(num_inputs + j).data<float>();
+    int64_t length = cpu_outputs[j].numel();
+    LOG(INFO) << "Check the " << j << "th output...";
+    for (int64_t i = 0; i < length; ++i) {
+      EXPECT_NEAR(dev_ptr[i], cpu_ptr[i], 1.E-05);
+    }
+  }
+}
+
+void TestMain(const std::vector<std::string>& input_names,
+              const std::vector<std::vector<int64_t>>& input_shapes,
+              const std::vector<std::string>& output_names, int type,
+              std::string func_name, std::string cuda_kernel_str,
+              CPUKernelFunc cpu_kernel_func) {
+  // Compile the device code
+  paddle::framework::InitDevices(false, {0});
+  platform::CUDAPlace place = platform::CUDAPlace(0);
+  PrepareDeviceCode(place, func_name, cuda_kernel_str);
+
+  // Create a ProgramDesc that has a fusion_group_op.
+  framework::ProgramDesc program;
+  framework::OpDesc* op_desc = CreateFusionGroupOp(
+      &program, input_names, input_shapes, output_names, type, func_name);
+  auto fusion_group_op = framework::OpRegistry::CreateOp(*op_desc);
+
+  framework::Scope scope;
+
+  // Prepare input tensors.
+  std::vector<framework::Tensor> cpu_tensors;
+  cpu_tensors.resize(input_names.size() + output_names.size());
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    SetupRandomCPUTensor<float>(&(cpu_tensors[i]), input_shapes[i]);
+    framework::Tensor* dev_tensor =
+        CreateTensor<float>(&scope, place, input_names[i], input_shapes[i]);
+    TensorCopySync(cpu_tensors[i], place, dev_tensor);
+  }
+  // Create output tensors.
+  std::vector<int64_t> empty_shape;
+  for (size_t j = 0; j < output_names.size(); ++j) {
+    CreateTensor<float>(&scope, place, output_names[j], empty_shape);
+  }
+
+  fusion_group_op->Run(scope, place);
+
+  auto* dev_ctx = platform::DeviceContextPool::Instance().Get(place);
+  dev_ctx->Wait();
+
+  // Check the output.
+  CheckOutputs(&scope, output_names, &cpu_tensors, input_names.size(),
+               cpu_kernel_func);
+}
+
+TEST(FusionGroupOp, elementwise) {
+  if (!platform::dynload::HasNVRTC() || !platform::dynload::HasCUDADriver()) {
+    return;
+  }
+
+  // z = relu(x + y)
+  std::vector<std::string> input_names = {"x", "y"};
+  std::vector<std::string> output_names = {"z"};
+  std::vector<std::vector<int64_t>> input_shapes = {{256, 256}, {256, 256}};
+  constexpr auto kernel = R"(
+static inline __device__ float relu(float x) {
+  return x * (x > 0);
+}
+
+extern "C" __global__
+void elementwise_cuda_kernel_0(size_t n, float* x, float* y, float* z) {
+  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < n;
+       tid += blockDim.x * gridDim.x) {
+    float tmp_0 = x[tid];
+    float tmp_1 = y[tid];
+    float tmp_2 = tmp_0 + tmp_1;
+    float tmp_3 = relu(tmp_2);
+    z[tid] = tmp_3;
+  }
+})";
+
+  auto elementwise_cpu_kernel_0 = [](size_t n,
+                                     std::vector<void*> args) -> void {
+    float* x = static_cast<float*>(args[0]);
+    float* y = static_cast<float*>(args[1]);
+    float* z = static_cast<float*>(args[2]);
+    for (size_t i = 0; i < n; ++i) {
+      float tmp_0 = x[i];
+      float tmp_1 = y[i];
+      float tmp_2 = tmp_0 + tmp_1;
+      float tmp_3 = tmp_2 > 0 ? tmp_2 : 0;
+      z[i] = tmp_3;
+    }
+  };
+
+  TestMain(input_names, input_shapes, output_names, 0,
+           "elementwise_cuda_kernel_0", kernel, elementwise_cpu_kernel_0);
+}
+
+}  // namespace operators
+}  // namespace paddle
+
+USE_CUDA_ONLY_OP(fusion_group);
diff --git a/paddle/fluid/platform/device_code.cc b/paddle/fluid/platform/device_code.cc
index 24421b5c3c99bd341c562f4c35df55ad749bdc50..b26de68395d50190f40aa8797f9b0cc9b9dc8ecd 100644
--- a/paddle/fluid/platform/device_code.cc
+++ b/paddle/fluid/platform/device_code.cc
@@ -14,26 +14,76 @@ limitations under the License. */
 
 #include "paddle/fluid/platform/device_code.h"
 #include <algorithm>
+#include <set>
+#include <utility>
 #include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace platform {
 
-#ifdef PADDLE_WITH_CUDA
-inline bool is_error(nvrtcResult stat) { return stat != NVRTC_SUCCESS; }
+DeviceCodePool* DeviceCodePool::pool = nullptr;
+
+void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
+  Place place = code->GetPlace();
+  std::string name = code->GetName();
+
+  auto iter = device_codes_.find(place);
+  if (iter == device_codes_.end()) {
+    PADDLE_THROW(platform::errors::NotFound(
+        "Place %s is not supported for runtime compiling.", place));
+  }
+
+  auto& codes_map = iter->second;
+  codes_map.emplace(name, std::move(code));
+}
+
+platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
+                                          const std::string& name) {
+  auto iter = device_codes_.find(place);
+  if (iter == device_codes_.end()) {
+    PADDLE_THROW(platform::errors::NotFound(
+        "Place %s is not supported for runtime compiling.", place));
+  }
+
+  auto& codes_map = iter->second;
+  auto code_iter = codes_map.find(name);
+  if (code_iter == codes_map.end()) {
+    PADDLE_THROW(platform::errors::NotFound(
+        "Device code named %s for place %s does not exist.", name.c_str(),
+        place));
+  }
+
+  return code_iter->second.get();
+}
 
-inline void throw_on_error(nvrtcResult stat, const std::string& msg) {
-#ifndef REPLACE_ENFORCE_GLOG
-  throw std::runtime_error(dynload::nvrtcGetErrorString(stat) + msg);
+DeviceCodePool::DeviceCodePool(const std::vector<Place>& places) {
+  PADDLE_ENFORCE_GT(
+      places.size(), 0,
+      errors::InvalidArgument(
+          "Expected the number of places >= 1. Received %d.", places.size()));
+  // Remove the duplicated places
+  std::set<Place> set;
+  for (auto& p : places) {
+    set.insert(p);
+  }
+  for (auto& p : set) {
+    if (is_gpu_place(p)) {
+#ifdef PADDLE_WITH_CUDA
+      device_codes_.emplace(p, DeviceCodeMap());
 #else
-  LOG(FATAL) << dynload::nvrtcGetErrorString(stat) << msg;
+      PADDLE_THROW(platform::errors::PreconditionNotMet(
+          "CUDAPlace is not supported, please re-compile with WITH_GPU=ON."));
 #endif
+    }
+  }
 }
 
+#ifdef PADDLE_WITH_CUDA
 CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
                                const std::string& kernel) {
   if (!is_gpu_place(place)) {
-    PADDLE_THROW("CUDADeviceCode can only launch on GPU place.");
+    PADDLE_THROW(platform::errors::PermissionDenied(
+        "CUDADeviceCode can only launch on GPU place."));
   }
 
   place_ = place;
@@ -41,16 +91,24 @@ CUDADeviceCode::CUDADeviceCode(const Place& place, const std::string& name,
   kernel_ = kernel;
 }
 
-void CUDADeviceCode::Compile() {
+bool CUDADeviceCode::Compile() {
+  is_compiled_ = false;
+  if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
+    LOG(WARNING) << "NVRTC and the CUDA driver are needed for JIT compiling "
+                    "of CUDA code.";
+    return false;
+  }
+
   nvrtcProgram program;
-  PADDLE_ENFORCE_EQ(dynload::nvrtcCreateProgram(&program,
-                                                kernel_.c_str(),  // buffer
-                                                name_.c_str(),    // name
-                                                0,          // numHeaders
-                                                nullptr,    // headers
-                                                nullptr),   // includeNames
-                    NVRTC_SUCCESS,
-                    "nvrtcCreateProgram failed.");
+  if (!CheckNVRTCResult(dynload::nvrtcCreateProgram(&program,
+                                                    kernel_.c_str(),  // buffer
+                                                    name_.c_str(),    // name
+                                                    0,         // numHeaders
+                                                    nullptr,   // headers
+                                                    nullptr),  // includeNames
+                        "nvrtcCreateProgram")) {
+    return false;
+  }
 
   // Compile the program for specified compute_capability
   auto* dev_ctx = reinterpret_cast<CUDADeviceContext*>(
@@ -67,38 +125,62 @@ void CUDADeviceCode::Compile() {
   if (compile_result == NVRTC_ERROR_COMPILATION) {
     // Obtain compilation log from the program
     size_t log_size;
-    PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLogSize(program, &log_size),
-                      NVRTC_SUCCESS, "nvrtcGetProgramLogSize failed.");
+    if (!CheckNVRTCResult(dynload::nvrtcGetProgramLogSize(program, &log_size),
+                          "nvrtcGetProgramLogSize")) {
+      return false;
+    }
     std::vector<char> log;
     log.resize(log_size + 1);
-    PADDLE_ENFORCE_EQ(dynload::nvrtcGetProgramLog(program, log.data()),
-                      NVRTC_SUCCESS, "nvrtcGetProgramLog failed.");
-    LOG(FATAL) << "JIT compiling of CUDA code failed:\n" << log.data();
+    if (!CheckNVRTCResult(dynload::nvrtcGetProgramLog(program, log.data()),
+                          "nvrtcGetProgramLog")) {
+      return false;
+    }
+    LOG(WARNING) << "JIT compiling of CUDA code failed:"
+                 << "\n  Kernel name: " << name_ << "\n  Kernel body:\n"
+                 << kernel_ << "\n  Compiling log: " << log.data();
+
+    return false;
   }
 
   // Obtain PTX from the program
   size_t ptx_size;
-  PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTXSize(program, &ptx_size), NVRTC_SUCCESS,
-                    "nvrtcGetPTXSize failed.");
+  if (!CheckNVRTCResult(dynload::nvrtcGetPTXSize(program, &ptx_size),
+                        "nvrtcGetPTXSize")) {
+    return false;
+  }
   ptx_.resize(ptx_size + 1);
-  PADDLE_ENFORCE_EQ(dynload::nvrtcGetPTX(program, ptx_.data()), NVRTC_SUCCESS,
-                    "nvrtcGetPTX failed.");
+  if (!CheckNVRTCResult(dynload::nvrtcGetPTX(program, ptx_.data()),
+                        "nvrtcGetPTX")) {
+    return false;
+  }
 
-  PADDLE_ENFORCE_EQ(dynload::nvrtcDestroyProgram(&program), NVRTC_SUCCESS,
-                    "nvrtcDestroyProgram failed.");
+  if (!CheckNVRTCResult(dynload::nvrtcDestroyProgram(&program),
+                        "nvrtcDestroyProgram")) {
+    return false;
+  }
 
-  PADDLE_ENFORCE_EQ(
-      dynload::cuModuleLoadData(&module_, ptx_.data()), CUDA_SUCCESS,
-      "Fail to load PTX of %s (in cuModuleLoadData.)", name_.c_str());
-  PADDLE_ENFORCE_EQ(
-      dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
-      CUDA_SUCCESS, "Fail to get function of %s (in cuModuleGetFunction.)",
-      name_.c_str());
+  if (!CheckCUDADriverResult(dynload::cuModuleLoadData(&module_, ptx_.data()),
+                             "cuModuleLoadData")) {
+    return false;
+  }
+
+  if (!CheckCUDADriverResult(
+          dynload::cuModuleGetFunction(&function_, module_, name_.c_str()),
+          "cuModuleGetFunction")) {
+    return false;
+  }
 
   max_threads_ = dev_ctx->GetMaxPhysicalThreadCount();
+  is_compiled_ = true;
+  return true;
 }
 
 void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
+  PADDLE_ENFORCE_EQ(
+      is_compiled_, true,
+      errors::PreconditionNotMet(
+          "Please compile the code before launching the kernel."));
+
   int max_blocks = std::max(max_threads_ / num_threads_, 1);
   int workload_per_block = workload_per_thread_ * num_threads_;
   int num_blocks =
@@ -114,8 +196,30 @@ void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
                         dev_ctx->stream(),  // stream
                         args->data(),       // arguments
                         nullptr),
-      CUDA_SUCCESS, "Fail to launch kernel %s (in cuLaunchKernel.)",
-      name_.c_str());
+      CUDA_SUCCESS,
+      errors::External("Fail to launch kernel %s (in cuLaunchKernel.)",
+                       name_.c_str()));
+}
+
+bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
+                                      std::string function) {
+  if (result != NVRTC_SUCCESS) {
+    LOG(WARNING) << "Call " << function
+                 << " failed: " << dynload::nvrtcGetErrorString(result);
+    return false;
+  }
+  return true;
+}
+
+bool CUDADeviceCode::CheckCUDADriverResult(CUresult result,
+                                           std::string function) {
+  if (result != CUDA_SUCCESS) {
+    const char* error = nullptr;
+    dynload::cuGetErrorString(result, &error);
+    LOG(WARNING) << "Call " << function << " failed: " << error;
+    return false;
+  }
+  return true;
 }
 #endif
diff --git a/paddle/fluid/platform/device_code.h b/paddle/fluid/platform/device_code.h
index 19adb0707f1742e9a41c4eaec549f7ccd5101acb..2895c568b6e8de2d04d4a328c56661e850e44eda 100644
--- a/paddle/fluid/platform/device_code.h
+++ b/paddle/fluid/platform/device_code.h
@@ -14,7 +14,10 @@ limitations under the License. */
 
 #pragma once
 
+#include <map>
+#include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/platform/device_context.h"
 #ifdef PADDLE_WITH_CUDA
@@ -28,9 +31,12 @@ namespace platform {
 class DeviceCode {
  public:
   virtual ~DeviceCode() {}
-  virtual void Compile() = 0;
+  virtual bool Compile() = 0;
   virtual void Launch(const size_t n, std::vector<void*>* args) const = 0;
 
+  Place GetPlace() const { return place_; }
+  std::string GetName() const { return name_; }
+
  protected:
   Place place_;
   std::string name_;
@@ -42,7 +48,7 @@ class CUDADeviceCode : public DeviceCode {
  public:
   explicit CUDADeviceCode(const Place& place, const std::string& name,
                           const std::string& kernel);
-  void Compile() override;
+  bool Compile() override;
   void Launch(const size_t n, std::vector<void*>* args) const override;
 
   void SetNumThreads(int num_threads) { num_threads_ = num_threads; }
@@ -51,6 +57,10 @@ class CUDADeviceCode : public DeviceCode {
   }
 
  private:
+  bool CheckNVRTCResult(nvrtcResult result, std::string function);
+  bool CheckCUDADriverResult(CUresult result, std::string function);
+
+  bool is_compiled_{false};
   int max_threads_{0};
   int num_threads_{1024};
   int workload_per_thread_{1};
@@ -60,5 +70,46 @@ class CUDADeviceCode : public DeviceCode {
 };
 #endif
 
+class DeviceCodePool {
+ public:
+  using DeviceCodeMap =
+      std::unordered_map<std::string, std::unique_ptr<DeviceCode>>;
+
+  explicit DeviceCodePool(const std::vector<Place>& places);
+
+  static DeviceCodePool& Instance() {
+    PADDLE_ENFORCE_NOT_NULL(
+        pool,
+        errors::NotFound("Need to create DeviceCodePool first, by calling "
+                         "DeviceCodePool::Init(places)!"));
+    return *pool;
+  }
+
+  static DeviceCodePool& Init(const std::vector<Place>& places) {
+    if (pool == nullptr) {
+      pool = new DeviceCodePool(places);
+    }
+    return *pool;
+  }
+
+  void Set(std::unique_ptr<DeviceCode>&& code);
+
+  platform::DeviceCode* Get(const platform::Place& place,
+                            const std::string& name);
+
+  size_t size(const platform::Place& place) const {
+    auto iter = device_codes_.find(place);
+    if (iter == device_codes_.end()) {
+      return 0;
+    }
+    return iter->second.size();
+  }
+
+ private:
+  static DeviceCodePool* pool;
+  std::map<Place, DeviceCodeMap> device_codes_;
+  DISABLE_COPY_AND_ASSIGN(DeviceCodePool);
+};
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/device_code_test.cc b/paddle/fluid/platform/device_code_test.cc
index 3b63ed4e369c7c9ccecf8a6b7e2272973a44e266..aa6bce6f1e54f86098b6df729ff00373ff0a638f 100644
--- a/paddle/fluid/platform/device_code_test.cc
+++ b/paddle/fluid/platform/device_code_test.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/platform/device_code.h"
+#include <utility>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/platform/init.h"
@@ -28,7 +29,12 @@ void saxpy_kernel(float a, float* x, float* y, float* z, size_t n) {
 )";
 
 #ifdef PADDLE_WITH_CUDA
-TEST(device_code, cuda) {
+TEST(DeviceCode, cuda) {
+  if (!paddle::platform::dynload::HasNVRTC() ||
+      !paddle::platform::dynload::HasCUDADriver()) {
+    return;
+  }
+
   paddle::framework::InitDevices(false, {0});
   paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
   paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
@@ -62,17 +68,42 @@ TEST(DeviceCode, cuda) {
   TensorCopySync(cpu_x, place, &x);
   TensorCopySync(cpu_y, place, &y);
 
-  code.Compile();
+  EXPECT_EQ(code.Compile(), true);
 
   std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
   code.SetNumThreads(1024);
   code.SetWorkloadPerThread(1);
   code.Launch(n, &args);
 
+  auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
+  dev_ctx->Wait();
+
   TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
   for (size_t i = 0; i < n; i++) {
-    PADDLE_ENFORCE_EQ(cpu_z.data<float>()[i],
-                      static_cast<float>(i) * scale + 0.5);
+    EXPECT_EQ(cpu_z.data<float>()[i], static_cast<float>(i) * scale + 0.5);
+  }
+}
+
+TEST(DeviceCodePool, cuda) {
+  if (!paddle::platform::dynload::HasNVRTC()) {
+    return;
   }
+
+  paddle::framework::InitDevices(false, {0});
+  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
+  paddle::platform::DeviceCodePool& pool =
+      paddle::platform::DeviceCodePool::Init({place});
+  size_t num_device_codes_before = pool.size(place);
+  EXPECT_EQ(num_device_codes_before, 0UL);
+
+  std::unique_ptr<paddle::platform::CUDADeviceCode> code(
+      new paddle::platform::CUDADeviceCode(place, "saxpy_kernel", saxpy_code));
+  LOG(INFO) << "origin ptr: " << code.get();
+  pool.Set(std::move(code));
+  size_t num_device_codes_after = pool.size(place);
+  EXPECT_EQ(num_device_codes_after, 1UL);
+
+  paddle::platform::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
+  LOG(INFO) << "get ptr: " << code_get;
 }
 #endif
diff --git a/paddle/fluid/platform/dynload/cuda_driver.cc b/paddle/fluid/platform/dynload/cuda_driver.cc
index 2c2edb2ccef9720f0b31b3734c3a775337b5e1ce..017e887bc7da53b6721af42a0a1fcc29b09f2565 100644
--- a/paddle/fluid/platform/dynload/cuda_driver.cc
+++ b/paddle/fluid/platform/dynload/cuda_driver.cc
@@ -25,6 +25,15 @@ void* cuda_dso_handle = nullptr;
 
 CUDA_ROUTINE_EACH(DEFINE_WRAP);
 
+#ifdef PADDLE_USE_DSO
+bool HasCUDADriver() {
+  std::call_once(cuda_dso_flag, []() { cuda_dso_handle = GetCUDADsoHandle(); });
+  return cuda_dso_handle != nullptr;
+}
+#else
+bool HasCUDADriver() { return false; }
+#endif
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/cuda_driver.h b/paddle/fluid/platform/dynload/cuda_driver.h
index 894797728bb1c3794082bc0ba3094a6748c5a0c4..a37a47b7900d7ac50b277197b42d4f431bde7179 100644
--- a/paddle/fluid/platform/dynload/cuda_driver.h
+++ b/paddle/fluid/platform/dynload/cuda_driver.h
@@ -25,6 +25,7 @@ namespace dynload {
 
 extern std::once_flag cuda_dso_flag;
 extern void* cuda_dso_handle;
+extern bool HasCUDADriver();
 
 #ifdef PADDLE_USE_DSO
diff --git a/paddle/fluid/platform/dynload/nvrtc.cc b/paddle/fluid/platform/dynload/nvrtc.cc
index 793b5b8d149daa89d7a570e7d7519a3e9aebf584..f95d4b6ab521d22aeadca1214e265761c0652d33 100644
--- a/paddle/fluid/platform/dynload/nvrtc.cc
+++ b/paddle/fluid/platform/dynload/nvrtc.cc
@@ -25,6 +25,16 @@ void* nvrtc_dso_handle = nullptr;
 
 NVRTC_ROUTINE_EACH(DEFINE_WRAP);
 
+#ifdef PADDLE_USE_DSO
+bool HasNVRTC() {
+  std::call_once(nvrtc_dso_flag,
+                 []() { nvrtc_dso_handle = GetNVRTCDsoHandle(); });
+  return nvrtc_dso_handle != nullptr;
+}
+#else
+bool HasNVRTC() { return false; }
+#endif
+
 }  // namespace dynload
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/fluid/platform/dynload/nvrtc.h b/paddle/fluid/platform/dynload/nvrtc.h
index 20647affabc807ed5a570f09daa241e4389007e4..b4437099baa63b46bb646242c45b34de524da916 100644
--- a/paddle/fluid/platform/dynload/nvrtc.h
+++ b/paddle/fluid/platform/dynload/nvrtc.h
@@ -25,6 +25,7 @@ namespace dynload {
 
 extern std::once_flag nvrtc_dso_flag;
 extern void* nvrtc_dso_handle;
+extern bool HasNVRTC();
 
 #ifdef PADDLE_USE_DSO
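Taken together, the pieces above are meant to be used as follows. This is a minimal usage sketch, not part of the patch: the helper `JitCompileAndRun` is hypothetical, and it assumes a CUDA build where `HasNVRTC()` and `HasCUDADriver()` report the libraries as available at runtime.

```cpp
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/platform/device_code.h"

// Hypothetical helper showing the intended DeviceCodePool workflow:
// Init() once, Compile() with a fallback path, then Set()/Get()/Launch().
void JitCompileAndRun(const std::string& name, const std::string& source,
                      size_t n, std::vector<void*>* args) {
  namespace plat = paddle::platform;
  plat::CUDAPlace place(0);
  // Init() creates the singleton on the first call and returns the existing
  // pool afterwards; Instance() would throw if called before Init().
  plat::DeviceCodePool& pool = plat::DeviceCodePool::Init({place});

  std::unique_ptr<plat::CUDADeviceCode> code(
      new plat::CUDADeviceCode(place, name, source));
  // Compile() now returns false instead of aborting, so a caller can fall
  // back to the unfused execution path when NVRTC or the driver is missing.
  if (code->Compile()) {
    pool.Set(std::move(code));
    pool.Get(place, name)->Launch(n, args);
  }
}
```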