Unverified commit 26da689d, authored by huangjiyi, committed by GitHub

move fusion_group kernel to phi (#53781)
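This change migrates the runtime-compiled fusion_group kernel from fluid to the phi backend: device_code.{h,cc} move from paddle/fluid/platform to paddle/phi/backends, paddle::platform::CUDADeviceCode is renamed phi::GPUDeviceCode, paddle::platform::DeviceCodePool becomes phi::DeviceCodePool, the standalone device_code CMake target is folded into phi_backends, and the fluid operator is reduced to an argument-mapping stub that forwards to the new phi kernel.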

Parent commit: 0bed2203
@@ -6,13 +6,13 @@ if(WITH_GPU OR WITH_ROCM)
   cc_test(
     test_code_generator
     SRCS code_generator_tester.cc
-    DEPS code_generator device_code lod_tensor graph_viz_pass)
+    DEPS code_generator phi_backends lod_tensor graph_viz_pass)
 endif()
 
 cc_library(
   fusion_group_pass
   SRCS fusion_group_pass.cc elementwise_group_detector.cc
-  DEPS subgraph_detector fuse_pass_base code_generator device_code)
+  DEPS subgraph_detector fuse_pass_base code_generator phi_backends)
 cc_test(
   test_fusion_group_pass
   SRCS fusion_group_pass_tester.cc
...
@@ -20,8 +20,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
 #include "paddle/fluid/framework/ir/fusion_group/operation.h"
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
-#include "paddle/fluid/platform/device_code.h"
 #include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/backends/device_code.h"
 
 namespace phi {
 class DenseTensor;
@@ -182,7 +182,7 @@ void TestMainImpl(std::string func_name,
       std::type_index(typeid(paddle::platform::float16));
   paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
-  paddle::platform::CUDADeviceCode device_code(place, func_name, code_str);
+  phi::GPUDeviceCode device_code(place, func_name, code_str);
 #ifdef PADDLE_WITH_HIP
   device_code.Compile(true);
 #else
...
@@ -19,12 +19,10 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
-#include "paddle/fluid/platform/device_code.h"
+#include "paddle/phi/backends/device_code.h"
 
-namespace paddle {
-namespace platform {
+namespace phi {
 class DeviceCodePool;
-}  // namespace platform
-}  // namespace paddle
+}  // namespace phi
 
 namespace paddle {
 namespace framework {
@@ -36,7 +34,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
   FusePassBase::Init("fusion_group_pass", graph);
   if (Get<bool>("use_gpu")) {
     // TODO(liuyiqun): open this check.
-    // if (!platform::CUDADeviceCode::IsAvailable()) {
+    // if (!phi::GPUDeviceCode::IsAvailable()) {
     //   LOG(WARNING)
     //       << "Disable fusion_group because CUDA Driver or NVRTC is not
     //       avaiable.";
@@ -54,7 +52,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
 int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
   // TODO(liuyiqun): supported different places
   platform::CUDAPlace place = platform::CUDAPlace(0);
-  int index = platform::DeviceCodePool::Init({place}).size(place);
+  int index = phi::DeviceCodePool::Init({place}).size(place);
 
   std::vector<std::vector<Node*>> subgraphs =
       fusion_group::ElementwiseGroupDetector()(graph);
@@ -88,11 +86,11 @@ bool FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
   // TODO(liuyiqun): supported different places
   platform::CUDAPlace place = platform::CUDAPlace(0);
-  std::unique_ptr<platform::CUDADeviceCode> device_code(
-      new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str));
+  std::unique_ptr<phi::GPUDeviceCode> device_code(
+      new phi::GPUDeviceCode(place, subgraph->GetFuncName(), code_str));
   bool is_compiled = device_code->Compile();
   if (is_compiled) {
-    platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place});
+    phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
     pool.Set(std::move(device_code));
   }
   return is_compiled;
...
@@ -73,7 +73,7 @@ if(WITH_GPU OR WITH_ROCM)
   op_library(fused_gate_attention_op)
   # fusion_group
   if(NOT APPLE AND NOT WIN32)
-    op_library(fusion_group_op DEPS device_code)
+    op_library(fusion_group_op)
   endif()
   # fused_bn_add_activation
   # HIP not support bn act fuse in MIOPEN
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/fused/fusion_group_op.h"
+#include "paddle/fluid/framework/op_registry.h"
 
 namespace paddle {
 namespace operators {
...
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"

namespace paddle {
namespace operators {

template <typename DeviceContext>
static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var,
                                 const std::vector<int>& data_type,
                                 const DeviceContext& dev_ctx,
                                 const platform::Place& place) {
  for (size_t i = 0; i < var->size(); i++) {
    if (data_type[i] == framework::proto::VarType::FP32) {
      dev_ctx.template Alloc<float>((*var)[i],
                                    (*var)[i]->numel() * sizeof(float));
    } else if (data_type[i] == framework::proto::VarType::FP16) {
      dev_ctx.template Alloc<paddle::platform::float16>(
          (*var)[i], (*var)[i]->numel() * sizeof(paddle::platform::float16));
    } else if (data_type[i] == framework::proto::VarType::FP64) {
      dev_ctx.template Alloc<double>((*var)[i],
                                     (*var)[i]->numel() * sizeof(double));
    }
  }
}

template <typename T, typename DeviceContext>
class FusionGroupKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto ins = ctx.MultiInput<phi::DenseTensor>("Inputs");
    auto outs = ctx.MultiOutput<phi::DenseTensor>("Outs");
    int type = ctx.Attr<int>("type");
    const auto& outs_dtype = ctx.Attr<std::vector<int>>("outs_dtype");
    const auto& inputs_dtype = ctx.Attr<std::vector<int>>("inputs_dtype");

    size_t num_ins = ins.size();
    size_t num_outs = outs.size();

    auto place = ctx.GetPlace();
    auto& dev_ctx = ctx.template device_context<DeviceContext>();

    MutableMultiTypeData(&outs, outs_dtype, dev_ctx, place);

    std::string func_name = ctx.Attr<std::string>("func_name");
    platform::DeviceCode* dev_code =
        platform::DeviceCodePool::Instance().Get(place, func_name);
    VLOG(3) << "func_name: " << func_name;

    if (type == 0) {
      size_t n = ins[0]->numel();
      std::vector<void*> args;
      args.push_back(&n);
      std::vector<const void*> ptrs(num_ins + num_outs);
      for (size_t i = 0; i < num_ins; ++i) {
        if (inputs_dtype[i] == framework::proto::VarType::FP16) {
          ptrs[i] = ins[i]->data<paddle::platform::float16>();
        } else if (inputs_dtype[i] == framework::proto::VarType::FP32) {
          ptrs[i] = ins[i]->data<float>();
        } else if (inputs_dtype[i] == framework::proto::VarType::FP64) {
          ptrs[i] = ins[i]->data<double>();
        }
        args.push_back(&ptrs[i]);
      }
      for (size_t j = 0; j < num_outs; ++j) {
        if (outs_dtype[j] == framework::proto::VarType::FP16) {
          ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
        } else if (outs_dtype[j] == framework::proto::VarType::FP32) {
          ptrs[num_ins + j] = outs[j]->data<float>();
        } else if (outs_dtype[j] == framework::proto::VarType::FP64) {
          ptrs[num_ins + j] = outs[j]->data<double>();
        }
        args.push_back(&ptrs[num_ins + j]);
      }
      dev_code->Launch(n, &args);
    }
  }
};

}  // namespace operators
}  // namespace paddle
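Note: the commit deletes this fluid-side header (paddle/fluid/operators/fused/fusion_group_op.h); its MutableMultiTypeData helper and the kernel body reappear below as a phi kernel in fusion/gpu/fusion_group_kernel.cu, with the framework::proto::VarType dtype checks rewritten as phi::TransToProtoVarType comparisons.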
@@ -356,15 +356,11 @@ if(WITH_ROCM)
 endif()
 
 if(NOT APPLE AND NOT WIN32)
-  cc_library(
-    device_code
-    SRCS device_code.cc
-    DEPS device_context)
   if(WITH_GPU OR WITH_ROCM)
     cc_test(
       device_code_test
       SRCS device_code_test.cc
-      DEPS device_code lod_tensor)
+      DEPS phi_backends lod_tensor)
   endif()
 endif()
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/platform/device_code.h"
+#include "paddle/phi/backends/device_code.h"
 
 #include <utility>
 
@@ -47,14 +47,13 @@ void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
 TEST(DeviceCode, cuda) {
-  if (!paddle::platform::dynload::HasNVRTC() ||
-      !paddle::platform::dynload::HasCUDADriver()) {
+  if (!phi::dynload::HasNVRTC() || !phi::dynload::HasCUDADriver()) {
     return;
   }
 
   paddle::framework::InitDevices({0});
-  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
-  paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
+  phi::GPUPlace place = phi::GPUPlace(0);
+  phi::GPUDeviceCode code(place, "saxpy_kernel", saxpy_code);
 
   phi::DenseTensor cpu_x;
   phi::DenseTensor cpu_y;
@@ -63,8 +62,12 @@ TEST(DeviceCode, cuda) {
   float scale = 2;
   auto dims =
       phi::make_ddim({static_cast<int64_t>(256), static_cast<int64_t>(1024)});
-  cpu_x.mutable_data<float>(dims, paddle::platform::CPUPlace());
-  cpu_y.mutable_data<float>(dims, paddle::platform::CPUPlace());
+  phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
+  auto* cpu_ctx = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
+  cpu_x.Resize(dims);
+  cpu_ctx->template Alloc<float>(&cpu_x);
+  cpu_y.Resize(dims);
+  cpu_ctx->template Alloc<float>(&cpu_y);
 
   size_t n = cpu_x.numel();
   for (size_t i = 0; i < n; ++i) {
@@ -78,9 +81,13 @@ TEST(DeviceCode, cuda) {
   phi::DenseTensor y;
   phi::DenseTensor z;
 
-  float* x_data = x.mutable_data<float>(dims, place);
-  float* y_data = y.mutable_data<float>(dims, place);
-  float* z_data = z.mutable_data<float>(dims, place);
+  auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(place));
+  x.Resize(dims);
+  float* x_data = dev_ctx->template Alloc<float>(&x);
+  y.Resize(dims);
+  float* y_data = dev_ctx->template Alloc<float>(&y);
+  z.Resize(dims);
+  float* z_data = dev_ctx->template Alloc<float>(&z);
 
   paddle::framework::TensorCopySync(cpu_x, place, &x);
   paddle::framework::TensorCopySync(cpu_y, place, &y);
@@ -92,36 +99,33 @@ TEST(DeviceCode, cuda) {
   code.SetWorkloadPerThread(1);
   code.Launch(n, &args);
 
-  auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
   dev_ctx->Wait();
 
-  paddle::framework::TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
+  paddle::framework::TensorCopySync(z, phi::CPUPlace(), &cpu_z);
 
   for (size_t i = 0; i < n; i++) {
     EXPECT_EQ(cpu_z.data<float>()[i], static_cast<float>(i) * scale + 0.5);
   }
 }
 
 TEST(DeviceCodePool, cuda) {
-  if (!paddle::platform::dynload::HasNVRTC() ||
-      !paddle::platform::dynload::HasCUDADriver()) {
+  if (!phi::dynload::HasNVRTC() || !phi::dynload::HasCUDADriver()) {
     return;
   }
 
   paddle::framework::InitDevices({0});
-  paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
-  paddle::platform::DeviceCodePool& pool =
-      paddle::platform::DeviceCodePool::Init({place});
+  phi::GPUPlace place = phi::GPUPlace(0);
+  phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
   size_t num_device_codes_before = pool.size(place);
   EXPECT_EQ(num_device_codes_before, 0UL);
 
-  std::unique_ptr<paddle::platform::DeviceCode> code(
-      new paddle::platform::CUDADeviceCode(place, "saxpy_kernel", saxpy_code));
+  std::unique_ptr<phi::DeviceCode> code(
+      new phi::GPUDeviceCode(place, "saxpy_kernel", saxpy_code));
   LOG(INFO) << "origin ptr: " << code.get();
   pool.Set(std::move(code));
 
   size_t num_device_codes_after = pool.size(place);
   EXPECT_EQ(num_device_codes_after, 1UL);
 
-  paddle::platform::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
+  phi::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
   LOG(INFO) << "get ptr: " << code_get;
 }
 #endif
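Note: besides the namespace renames, the allocation idiom in the test changes: tensor.mutable_data<T>(dims, place) becomes Resize(dims) followed by Alloc<T> on a device context fetched from phi::DeviceContextPool, the phi-native way to allocate DenseTensor storage.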
@@ -14,6 +14,10 @@ if(WITH_XBYAK)
   list(APPEND BACKENDS_DEPS xbyak)
 endif()
 
+if(NOT APPLE AND NOT WIN32)
+  list(APPEND BACKENDS_SRCS device_code.cc)
+endif()
+
 if(WITH_GPU OR WITH_ROCM)
   list(APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc
        gpu/gpu_resources.cc)
...
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/platform/device_code.h"
+#include "paddle/phi/backends/device_code.h"
 
+#include <glog/logging.h>
 #include <sys/stat.h>
 
 #include <algorithm>
 #include <set>
 #include <utility>
 
-#include "paddle/fluid/platform/enforce.h"
+#include "paddle/phi/backends/context_pool.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/flags.h"
 
 PHI_DECLARE_string(cuda_dir);
 
-namespace paddle {
-namespace platform {
+namespace phi {
 
 DeviceCodePool* DeviceCodePool::pool = nullptr;
@@ -35,7 +37,7 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
   auto iter = device_codes_.find(place);
   if (iter == device_codes_.end()) {
-    PADDLE_THROW(platform::errors::NotFound(
+    PADDLE_THROW(phi::errors::NotFound(
         "Place %s is not supported for runtime compiling.", place));
   }
@@ -43,18 +45,18 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
   codes_map.emplace(name, std::move(code));
 }
 
-platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
-                                          const std::string& name) {
+DeviceCode* DeviceCodePool::Get(const phi::Place& place,
+                                const std::string& name) {
   auto iter = device_codes_.find(place);
   if (iter == device_codes_.end()) {
-    PADDLE_THROW(platform::errors::NotFound(
+    PADDLE_THROW(phi::errors::NotFound(
        "Place %s is not supported for runtime compiling.", place));
   }
 
   auto& codes_map = iter->second;
   auto code_iter = codes_map.find(name);
   if (code_iter == codes_map.end()) {
-    PADDLE_THROW(platform::errors::NotFound(
+    PADDLE_THROW(phi::errors::NotFound(
         "Device code named %s for place %s does not exist.",
         name.c_str(),
         place));
@@ -63,7 +65,7 @@ platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
   return code_iter->second.get();
 }
 
-DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
+DeviceCodePool::DeviceCodePool(const std::vector<phi::Place>& places) {
   PADDLE_ENFORCE_GT(places.size(),
                     0,
                     errors::InvalidArgument(
@@ -75,11 +77,11 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
     set.insert(p);
   }
   for (auto& p : set) {
-    if (is_gpu_place(p)) {
+    if (p.GetType() == phi::AllocationType::GPU) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       device_codes_.emplace(p, DeviceCodeMap());
 #else
-      PADDLE_THROW(platform::errors::PreconditionNotMet(
+      PADDLE_THROW(phi::errors::PreconditionNotMet(
          "CUDAPlace or HIPPlace is not supported, please re-compile with "
          "WITH_GPU=ON or WITH_ROCM=ON."));
 #endif
@@ -87,7 +89,7 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
   }
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-  CUDADeviceCode::CheckAvailableStatus();
+  GPUDeviceCode::CheckAvailableStatus();
 #endif
 }
@@ -114,8 +116,8 @@ static bool CheckCUDADriverResult(CUresult result,
   return true;
 }
 
-bool CUDADeviceCode::available_ = false;
-void CUDADeviceCode::CheckAvailableStatus() {
+bool GPUDeviceCode::available_ = false;
+void GPUDeviceCode::CheckAvailableStatus() {
   available_ = false;
   if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
     LOG_FIRST_N(WARNING, 1)
@@ -215,12 +217,12 @@ static std::string FindCUDAIncludePath() {
   return "";
 }
 
-CUDADeviceCode::CUDADeviceCode(const Place& place,
-                               const std::string& name,
-                               const std::string& kernel) {
-  if (!is_gpu_place(place)) {
-    PADDLE_THROW(platform::errors::PermissionDenied(
-        "CUDADeviceCode can only launch on GPU place."));
+GPUDeviceCode::GPUDeviceCode(const Place& place,
+                             const std::string& name,
+                             const std::string& kernel) {
+  if (place.GetType() != phi::AllocationType::GPU) {
+    PADDLE_THROW(phi::errors::PermissionDenied(
+        "GPUDeviceCode can only launch on GPU place."));
   }
 
   place_ = place;
@@ -232,7 +234,7 @@ CUDADeviceCode::CUDADeviceCode(const Place& place,
 #endif
 }
 
-bool CUDADeviceCode::Compile(bool include_path) {
+bool GPUDeviceCode::Compile(bool include_path) {
   is_compiled_ = false;
   if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
     LOG_FIRST_N(WARNING, 1)
@@ -403,7 +405,7 @@ bool CUDADeviceCode::Compile(bool include_path) {
   return true;
 }
 
-void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
+void GPUDeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
   PADDLE_ENFORCE_EQ(
       is_compiled_,
       true,
@@ -454,8 +456,8 @@ void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
 }
 
 #ifdef PADDLE_WITH_HIP
-bool CUDADeviceCode::CheckNVRTCResult(hiprtcResult result,
-                                      std::string function) {
+bool GPUDeviceCode::CheckNVRTCResult(hiprtcResult result,
+                                     std::string function) {
   if (result != HIPRTC_SUCCESS) {
     LOG_FIRST_N(WARNING, 1)
         << "Call " << function << " for < " << name_
@@ -463,8 +465,7 @@ bool CUDADeviceCode::CheckNVRTCResult(hiprtcResult result,
     return false;
   }
 #else
-bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
-                                      std::string function) {
+bool GPUDeviceCode::CheckNVRTCResult(nvrtcResult result, std::string function) {
   if (result != NVRTC_SUCCESS) {
     LOG_FIRST_N(WARNING, 1)
        << "Call " << function << " for < " << name_
@@ -476,5 +477,4 @@ bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
 }
 #endif
 
-}  // namespace platform
-}  // namespace paddle
+}  // namespace phi
@@ -20,18 +20,18 @@ limitations under the License. */
 #include <unordered_map>
 #include <vector>
 
-#include "paddle/fluid/platform/device_context.h"
+#include "paddle/phi/common/place.h"
+#include "paddle/phi/core/enforce.h"
 #ifdef PADDLE_WITH_CUDA
-#include "paddle/fluid/platform/dynload/cuda_driver.h"
-#include "paddle/fluid/platform/dynload/nvrtc.h"
+#include "paddle/phi/backends/dynload/cuda_driver.h"
+#include "paddle/phi/backends/dynload/nvrtc.h"
 #endif
 #ifdef PADDLE_WITH_HIP
-#include "paddle/fluid/platform/dynload/hiprtc.h"
-#include "paddle/fluid/platform/dynload/rocm_driver.h"
+#include "paddle/phi/backends/dynload/hiprtc.h"
+#include "paddle/phi/backends/dynload/rocm_driver.h"
 #endif
 
-namespace paddle {
-namespace platform {
+namespace phi {
 
 class DeviceCode {
  public:
@@ -49,11 +49,11 @@ class DeviceCode {
 };
 
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
-class CUDADeviceCode : public DeviceCode {
+class GPUDeviceCode : public DeviceCode {
  public:
-  explicit CUDADeviceCode(const Place& place,
-                          const std::string& name,
-                          const std::string& kernel);
+  explicit GPUDeviceCode(const Place& place,
+                         const std::string& name,
+                         const std::string& kernel);
   bool Compile(bool include_path = false) override;
   void Launch(const size_t n, std::vector<void*>* args) const override;
@@ -94,7 +94,7 @@ class DeviceCodePool {
   using DeviceCodeMap =
       std::unordered_map<std::string, std::unique_ptr<DeviceCode>>;
 
-  explicit DeviceCodePool(const std::vector<platform::Place>& places);
+  explicit DeviceCodePool(const std::vector<Place>& places);
 
   static DeviceCodePool& Instance() {
     PADDLE_ENFORCE_NOT_NULL(
@@ -104,7 +104,7 @@ class DeviceCodePool {
     return *pool;
   }
 
-  static DeviceCodePool& Init(const std::vector<platform::Place>& places) {
+  static DeviceCodePool& Init(const std::vector<Place>& places) {
     if (pool == nullptr) {
       pool = new DeviceCodePool(places);
     }
@@ -113,10 +113,9 @@ class DeviceCodePool {
 
   void Set(std::unique_ptr<DeviceCode>&& code);
 
-  platform::DeviceCode* Get(const platform::Place& place,
-                            const std::string& name);
+  DeviceCode* Get(const Place& place, const std::string& name);
 
-  size_t size(const platform::Place& place) const {
+  size_t size(const Place& place) const {
     auto iter = device_codes_.find(place);
     if (iter == device_codes_.end()) {
       return 0;
@@ -130,5 +129,4 @@ class DeviceCodePool {
   DISABLE_COPY_AND_ASSIGN(DeviceCodePool);
 };
 
-}  // namespace platform
-}  // namespace paddle
+}  // namespace phi
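For readers unfamiliar with the renamed API, a minimal sketch (not part of the diff) of compiling and caching a JIT kernel with the phi types, based on the pool and GPUDeviceCode usage in this commit; the kernel name and source string are placeholders:

#include <memory>
#include <string>

#include "paddle/phi/backends/device_code.h"

void CompileAndCache(const std::string& kernel_src) {
  phi::GPUPlace place(0);
  // Init() creates the singleton pool for the given places on first use.
  phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
  std::unique_ptr<phi::DeviceCode> code(
      new phi::GPUDeviceCode(place, "my_kernel", kernel_src));
  // Compile() returns false when NVRTC or the CUDA driver is unavailable.
  if (code->Compile()) {
    pool.Set(std::move(code));
  }
}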
@@ -153,6 +153,11 @@ if(WITH_CUTLASS)
   list(APPEND kernel_cu ${cutlass_cu})
 endif()
 
+if(APPLE OR WIN32)
+  list(REMOVE_ITEM kernel_cu
+       "${CMAKE_CURRENT_SOURCE_DIR}/fusion/gpu/fusion_group_kernel.cu")
+endif()
+
 if(WITH_MKLDNN)
   file(
     GLOB
...
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "glog/logging.h"

#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"

namespace phi {
namespace fusion {

template <typename DeviceContext>
static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var,
                                 const std::vector<int>& data_type,
                                 const DeviceContext& dev_ctx) {
  for (size_t i = 0; i < var->size(); i++) {
    if (data_type[i] == phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
      dev_ctx.template Alloc<float>((*var)[i],
                                    (*var)[i]->numel() * sizeof(float));
    } else if (data_type[i] ==
               phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
      dev_ctx.template Alloc<phi::dtype::float16>(
          (*var)[i], (*var)[i]->numel() * sizeof(phi::dtype::float16));
    } else if (data_type[i] ==
               phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
      dev_ctx.template Alloc<double>((*var)[i],
                                     (*var)[i]->numel() * sizeof(double));
    }
  }
}

template <typename T, typename Context>
void FusionGroupKernel(const Context& dev_ctx,
                       const std::vector<const DenseTensor*>& ins,
                       const std::vector<int>& outs_dtype,
                       const std::vector<int>& inputs_dtype,
                       const std::string& func_name,
                       int type,
                       std::vector<DenseTensor*> outs) {
  size_t num_ins = ins.size();
  size_t num_outs = outs.size();

  MutableMultiTypeData(&outs, outs_dtype, dev_ctx);

  phi::DeviceCode* dev_code =
      phi::DeviceCodePool::Instance().Get(dev_ctx.GetPlace(), func_name);
  VLOG(3) << "func_name: " << func_name;

  if (type == 0) {
    size_t n = ins[0]->numel();
    std::vector<void*> args;
    args.push_back(&n);
    std::vector<const void*> ptrs(num_ins + num_outs);
    for (size_t i = 0; i < num_ins; ++i) {
      if (inputs_dtype[i] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
        ptrs[i] = ins[i]->data<phi::dtype::float16>();
      } else if (inputs_dtype[i] ==
                 phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
        ptrs[i] = ins[i]->data<float>();
      } else if (inputs_dtype[i] ==
                 phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
        ptrs[i] = ins[i]->data<double>();
      }
      args.push_back(&ptrs[i]);
    }
    for (size_t j = 0; j < num_outs; ++j) {
      if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
        ptrs[num_ins + j] = outs[j]->data<phi::dtype::float16>();
      } else if (outs_dtype[j] ==
                 phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
        ptrs[num_ins + j] = outs[j]->data<float>();
      } else if (outs_dtype[j] ==
                 phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
        ptrs[num_ins + j] = outs[j]->data<double>();
      }
      args.push_back(&ptrs[num_ins + j]);
    }
    dev_code->Launch(n, &args);
  }
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(fusion_group,
                   GPU,
                   ALL_LAYOUT,
                   phi::fusion::FusionGroupKernel,
                   float,
                   double,
                   phi::dtype::float16) {}
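A hypothetical follow-up sketch (names assumed, not part of the diff) of fetching a cached kernel from the pool and launching it, following the argument convention FusionGroupKernel uses above: the address of the element count first, then the address of each tensor's device pointer:

phi::GPUPlace place(0);
phi::DeviceCode* dev_code =
    phi::DeviceCodePool::Instance().Get(place, "fused_elementwise_0");

size_t n = x.numel();  // x, y: device-resident phi::DenseTensor (assumed)
std::vector<void*> args;
args.push_back(&n);  // passed by address, read by value in the kernel
const void* x_ptr = x.data<float>();
const void* y_ptr = y.data<float>();
args.push_back(&x_ptr);  // driver-style launch takes pointers to the arguments
args.push_back(&y_ptr);
dev_code->Launch(n, &args);  // Launch computes its own grid size from n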
-/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/fused/fusion_group_op.h"
-#include "paddle/fluid/platform/float16.h"
+#include "paddle/phi/core/compat/op_utils.h"
 
-namespace ops = paddle::operators;
-namespace plat = paddle::platform;
+namespace phi {
 
-PD_REGISTER_STRUCT_KERNEL(fusion_group,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::FusionGroupKernel,
-                          float,
-                          double,
-                          plat::float16) {}
+KernelSignature FusionGroupOpArgumentMapping(
+    const ArgumentMappingContext& ctx) {
+  return KernelSignature("fusion_group",
+                         {"Inputs"},
+                         {"outs_dtype", "inputs_dtype", "func_name", "type"},
+                         {"Outs"});
+}
+
+}  // namespace phi
+
+PD_REGISTER_ARG_MAPPING_FN(fusion_group, phi::FusionGroupOpArgumentMapping);
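Note: this argument-mapping stub replaces the fluid PD_REGISTER_STRUCT_KERNEL registration; it routes the legacy op's Inputs input, its outs_dtype/inputs_dtype/func_name/type attributes, and its Outs output to the phi kernel registered in fusion_group_kernel.cu, so programs built against the fluid operator resolve to the phi implementation.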
@@ -17,8 +17,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/platform/device_code.h"
 #include "paddle/fluid/platform/init.h"
+#include "paddle/phi/backends/device_code.h"
 
 namespace paddle {
 namespace operators {
@@ -93,11 +93,10 @@ framework::OpDesc* CreateFusionGroupOp(
 void PrepareDeviceCode(platform::Place place,
                        std::string func_name,
                        std::string cuda_kernel_str) {
-  paddle::platform::DeviceCodePool& pool =
-      paddle::platform::DeviceCodePool::Init({place});
+  phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
 
-  std::unique_ptr<paddle::platform::DeviceCode> code(
-      new paddle::platform::CUDADeviceCode(place, func_name, cuda_kernel_str));
+  std::unique_ptr<phi::DeviceCode> code(
+      new phi::GPUDeviceCode(place, func_name, cuda_kernel_str));
   code->Compile();
   pool.Set(std::move(code));
 }
@@ -183,7 +182,7 @@ void TestMain(const std::vector<std::string>& input_names,
 }
 
 TEST(FusionGroupOp, elementwise) {
-  if (!platform::dynload::HasNVRTC() || !platform::dynload::HasCUDADriver()) {
+  if (!phi::dynload::HasNVRTC() || !phi::dynload::HasCUDADriver()) {
     return;
   }
...