Unverified commit 26da689d, authored by huangjiyi, committed by GitHub

move fusion_group kernel to phi (#53781)

Parent 0bed2203
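For orientation, the core of this change is a rename and relocation of the runtime-compilation API: paddle::platform::CUDADeviceCode becomes phi::GPUDeviceCode, paddle::platform::DeviceCodePool becomes phi::DeviceCodePool, and the header moves from paddle/fluid/platform/device_code.h to paddle/phi/backends/device_code.h. Below is a minimal sketch of the updated usage, assuming only the interfaces visible in this diff; the helper name CompileAndCache is illustrative, not part of the commit.

// Sketch: runtime-compile a CUDA kernel string and cache it in the global
// pool, mirroring FusionGroupPass::GenerateCode after this change.
#include <memory>
#include <string>

#include "paddle/phi/backends/device_code.h"

void CompileAndCache(const std::string& func_name, const std::string& code_str) {
  phi::GPUPlace place(0);  // formerly paddle::platform::CUDAPlace
  auto device_code = std::make_unique<phi::GPUDeviceCode>(
      place, func_name, code_str);  // formerly platform::CUDADeviceCode
  if (device_code->Compile()) {
    // formerly platform::DeviceCodePool
    phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
    pool.Set(std::move(device_code));
  }
}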
@@ -6,13 +6,13 @@ if(WITH_GPU OR WITH_ROCM)
cc_test(
test_code_generator
SRCS code_generator_tester.cc
DEPS code_generator device_code lod_tensor graph_viz_pass)
DEPS code_generator phi_backends lod_tensor graph_viz_pass)
endif()
cc_library(
fusion_group_pass
SRCS fusion_group_pass.cc elementwise_group_detector.cc
DEPS subgraph_detector fuse_pass_base code_generator device_code)
DEPS subgraph_detector fuse_pass_base code_generator phi_backends)
cc_test(
test_fusion_group_pass
SRCS fusion_group_pass_tester.cc
......
@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/fusion_group/code_generator.h"
#include "paddle/fluid/framework/ir/fusion_group/operation.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/device_code.h"
namespace phi {
class DenseTensor;
@@ -182,7 +182,7 @@ void TestMainImpl(std::string func_name,
std::type_index(typeid(paddle::platform::float16));
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode device_code(place, func_name, code_str);
phi::GPUDeviceCode device_code(place, func_name, code_str);
#ifdef PADDLE_WITH_HIP
device_code.Compile(true);
#else
......
@@ -19,12 +19,10 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/platform/device_code.h"
namespace paddle {
namespace platform {
#include "paddle/phi/backends/device_code.h"
namespace phi {
class DeviceCodePool;
} // namespace platform
} // namespace paddle
} // namespace phi
namespace paddle {
namespace framework {
@@ -36,7 +34,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init("fusion_group_pass", graph);
if (Get<bool>("use_gpu")) {
// TODO(liuyiqun): open this check.
// if (!platform::CUDADeviceCode::IsAvailable()) {
// if (!phi::GPUDeviceCode::IsAvailable()) {
// LOG(WARNING)
// << "Disable fusion_group because CUDA Driver or NVRTC is not
// available.";
@@ -54,7 +52,7 @@ void FusionGroupPass::ApplyImpl(ir::Graph* graph) const {
int FusionGroupPass::DetectFusionGroup(Graph* graph, int type) const {
// TODO(liuyiqun): supported different places
platform::CUDAPlace place = platform::CUDAPlace(0);
int index = platform::DeviceCodePool::Init({place}).size(place);
int index = phi::DeviceCodePool::Init({place}).size(place);
std::vector<std::vector<Node*>> subgraphs =
fusion_group::ElementwiseGroupDetector()(graph);
@@ -88,11 +86,11 @@ bool FusionGroupPass::GenerateCode(fusion_group::SubGraph* subgraph) const {
// TODO(liuyiqun): supported different places
platform::CUDAPlace place = platform::CUDAPlace(0);
std::unique_ptr<platform::CUDADeviceCode> device_code(
new platform::CUDADeviceCode(place, subgraph->GetFuncName(), code_str));
std::unique_ptr<phi::GPUDeviceCode> device_code(
new phi::GPUDeviceCode(place, subgraph->GetFuncName(), code_str));
bool is_compiled = device_code->Compile();
if (is_compiled) {
platform::DeviceCodePool& pool = platform::DeviceCodePool::Init({place});
phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
pool.Set(std::move(device_code));
}
return is_compiled;
......
@@ -73,7 +73,7 @@ if(WITH_GPU OR WITH_ROCM)
op_library(fused_gate_attention_op)
# fusion_group
if(NOT APPLE AND NOT WIN32)
op_library(fusion_group_op DEPS device_code)
op_library(fusion_group_op)
endif()
# fused_bn_add_activation
# HIP not support bn act fuse in MIOPEN
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_group_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_code.h"
namespace paddle {
namespace operators {
template <typename DeviceContext>
static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var,
const std::vector<int>& data_type,
const DeviceContext& dev_ctx,
const platform::Place& place) {
for (size_t i = 0; i < var->size(); i++) {
if (data_type[i] == framework::proto::VarType::FP32) {
dev_ctx.template Alloc<float>((*var)[i],
(*var)[i]->numel() * sizeof(float));
} else if (data_type[i] == framework::proto::VarType::FP16) {
dev_ctx.template Alloc<paddle::platform::float16>(
(*var)[i], (*var)[i]->numel() * sizeof(paddle::platform::float16));
} else if (data_type[i] == framework::proto::VarType::FP64) {
dev_ctx.template Alloc<double>((*var)[i],
(*var)[i]->numel() * sizeof(double));
}
}
}
template <typename T, typename DeviceContext>
class FusionGroupKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<phi::DenseTensor>("Inputs");
auto outs = ctx.MultiOutput<phi::DenseTensor>("Outs");
int type = ctx.Attr<int>("type");
const auto& outs_dtype = ctx.Attr<std::vector<int>>("outs_dtype");
const auto& inputs_dtype = ctx.Attr<std::vector<int>>("inputs_dtype");
size_t num_ins = ins.size();
size_t num_outs = outs.size();
auto place = ctx.GetPlace();
auto& dev_ctx = ctx.template device_context<DeviceContext>();
MutableMultiTypeData(&outs, outs_dtype, dev_ctx, place);
std::string func_name = ctx.Attr<std::string>("func_name");
platform::DeviceCode* dev_code =
platform::DeviceCodePool::Instance().Get(place, func_name);
VLOG(3) << "func_name: " << func_name;
if (type == 0) {
size_t n = ins[0]->numel();
std::vector<void*> args;
args.push_back(&n);
std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
if (inputs_dtype[i] == framework::proto::VarType::FP16) {
ptrs[i] = ins[i]->data<paddle::platform::float16>();
} else if (inputs_dtype[i] == framework::proto::VarType::FP32) {
ptrs[i] = ins[i]->data<float>();
} else if (inputs_dtype[i] == framework::proto::VarType::FP64) {
ptrs[i] = ins[i]->data<double>();
}
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
if (outs_dtype[j] == framework::proto::VarType::FP16) {
ptrs[num_ins + j] = outs[j]->data<paddle::platform::float16>();
} else if (outs_dtype[j] == framework::proto::VarType::FP32) {
ptrs[num_ins + j] = outs[j]->data<float>();
} else if (outs_dtype[j] == framework::proto::VarType::FP64) {
ptrs[num_ins + j] = outs[j]->data<double>();
}
args.push_back(&ptrs[num_ins + j]);
}
dev_code->Launch(n, &args);
}
}
};
} // namespace operators
} // namespace paddle
@@ -356,15 +356,11 @@ if(WITH_ROCM)
endif()
if(NOT APPLE AND NOT WIN32)
cc_library(
device_code
SRCS device_code.cc
DEPS device_context)
if(WITH_GPU OR WITH_ROCM)
cc_test(
device_code_test
SRCS device_code_test.cc
DEPS device_code lod_tensor)
DEPS phi_backends lod_tensor)
endif()
endif()
......
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_code.h"
#include "paddle/phi/backends/device_code.h"
#include <utility>
@@ -47,14 +47,13 @@ void saxpy_kernel(float a, float *x, float* y, float* z, size_t n) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
TEST(DeviceCode, cuda) {
if (!paddle::platform::dynload::HasNVRTC() ||
!paddle::platform::dynload::HasCUDADriver()) {
if (!phi::dynload::HasNVRTC() || !phi::dynload::HasCUDADriver()) {
return;
}
paddle::framework::InitDevices({0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::CUDADeviceCode code(place, "saxpy_kernel", saxpy_code);
phi::GPUPlace place = phi::GPUPlace(0);
phi::GPUDeviceCode code(place, "saxpy_kernel", saxpy_code);
phi::DenseTensor cpu_x;
phi::DenseTensor cpu_y;
@@ -63,8 +62,12 @@ TEST(DeviceCode, cuda) {
float scale = 2;
auto dims =
phi::make_ddim({static_cast<int64_t>(256), static_cast<int64_t>(1024)});
cpu_x.mutable_data<float>(dims, paddle::platform::CPUPlace());
cpu_y.mutable_data<float>(dims, paddle::platform::CPUPlace());
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* cpu_ctx = reinterpret_cast<phi::CPUContext*>(pool.Get(phi::CPUPlace()));
cpu_x.Resize(dims);
cpu_ctx->template Alloc<float>(&cpu_x);
cpu_y.Resize(dims);
cpu_ctx->template Alloc<float>(&cpu_y);
size_t n = cpu_x.numel();
for (size_t i = 0; i < n; ++i) {
@@ -78,9 +81,13 @@ TEST(DeviceCode, cuda) {
phi::DenseTensor y;
phi::DenseTensor z;
float* x_data = x.mutable_data<float>(dims, place);
float* y_data = y.mutable_data<float>(dims, place);
float* z_data = z.mutable_data<float>(dims, place);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(pool.Get(place));
x.Resize(dims);
float* x_data = dev_ctx->template Alloc<float>(&x);
y.Resize(dims);
float* y_data = dev_ctx->template Alloc<float>(&y);
z.Resize(dims);
float* z_data = dev_ctx->template Alloc<float>(&z);
paddle::framework::TensorCopySync(cpu_x, place, &x);
paddle::framework::TensorCopySync(cpu_y, place, &y);
@@ -92,36 +99,33 @@ TEST(DeviceCode, cuda) {
code.SetWorkloadPerThread(1);
code.Launch(n, &args);
auto* dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
dev_ctx->Wait();
paddle::framework::TensorCopySync(z, paddle::platform::CPUPlace(), &cpu_z);
paddle::framework::TensorCopySync(z, phi::CPUPlace(), &cpu_z);
for (size_t i = 0; i < n; i++) {
EXPECT_EQ(cpu_z.data<float>()[i], static_cast<float>(i) * scale + 0.5);
}
}
TEST(DeviceCodePool, cuda) {
if (!paddle::platform::dynload::HasNVRTC() ||
!paddle::platform::dynload::HasCUDADriver()) {
if (!phi::dynload::HasNVRTC() || !phi::dynload::HasCUDADriver()) {
return;
}
paddle::framework::InitDevices({0});
paddle::platform::CUDAPlace place = paddle::platform::CUDAPlace(0);
paddle::platform::DeviceCodePool& pool =
paddle::platform::DeviceCodePool::Init({place});
phi::GPUPlace place = phi::GPUPlace(0);
phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
size_t num_device_codes_before = pool.size(place);
EXPECT_EQ(num_device_codes_before, 0UL);
std::unique_ptr<paddle::platform::DeviceCode> code(
new paddle::platform::CUDADeviceCode(place, "saxpy_kernel", saxpy_code));
std::unique_ptr<phi::DeviceCode> code(
new phi::GPUDeviceCode(place, "saxpy_kernel", saxpy_code));
LOG(INFO) << "origin ptr: " << code.get();
pool.Set(std::move(code));
size_t num_device_codes_after = pool.size(place);
EXPECT_EQ(num_device_codes_after, 1UL);
paddle::platform::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
phi::DeviceCode* code_get = pool.Get(place, "saxpy_kernel");
LOG(INFO) << "get ptr: " << code_get;
}
#endif
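Condensed from the updated saxpy test above, the end-to-end flow under the new phi API looks roughly as follows. This is a sketch, not code from the commit: the wrapper LaunchSaxpy is hypothetical, saxpy_code is the kernel source string used by the test, and the argument order follows the saxpy_kernel signature shown earlier (float a, float* x, float* y, float* z, size_t n).

// Sketch: compile a kernel string at runtime and launch it via
// phi::GPUDeviceCode, then synchronize before reading results back.
#include <string>
#include <vector>

#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"

void LaunchSaxpy(const std::string& saxpy_code, float scale,
                 float* x_data, float* y_data, float* z_data, size_t n,
                 phi::GPUContext* dev_ctx) {
  phi::GPUPlace place(0);
  phi::GPUDeviceCode code(place, "saxpy_kernel", saxpy_code);
#ifdef PADDLE_WITH_HIP
  code.Compile(true);  // HIP builds pass include_path = true, as in the tests
#else
  code.Compile();
#endif
  // Driver-API style launch: args holds pointers to each kernel argument.
  std::vector<void*> args = {&scale, &x_data, &y_data, &z_data, &n};
  code.SetWorkloadPerThread(1);
  code.Launch(n, &args);
  dev_ctx->Wait();  // synchronize before copying results back to the host
}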
@@ -14,6 +14,10 @@ if(WITH_XBYAK)
list(APPEND BACKENDS_DEPS xbyak)
endif()
if(NOT APPLE AND NOT WIN32)
list(APPEND BACKENDS_SRCS device_code.cc)
endif()
if(WITH_GPU OR WITH_ROCM)
list(APPEND BACKENDS_SRCS gpu/gpu_context.cc gpu/gpu_info.cc
gpu/gpu_resources.cc)
......
@@ -12,20 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/device_code.h"
#include "paddle/phi/backends/device_code.h"
#include <glog/logging.h>
#include <sys/stat.h>
#include <algorithm>
#include <set>
#include <utility>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/flags.h"
PHI_DECLARE_string(cuda_dir);
namespace paddle {
namespace platform {
namespace phi {
DeviceCodePool* DeviceCodePool::pool = nullptr;
@@ -35,7 +37,7 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
auto iter = device_codes_.find(place);
if (iter == device_codes_.end()) {
PADDLE_THROW(platform::errors::NotFound(
PADDLE_THROW(phi::errors::NotFound(
"Place %s is not supported for runtime compiling.", place));
}
@@ -43,18 +45,18 @@ void DeviceCodePool::Set(std::unique_ptr<DeviceCode>&& code) {
codes_map.emplace(name, std::move(code));
}
platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
DeviceCode* DeviceCodePool::Get(const phi::Place& place,
const std::string& name) {
auto iter = device_codes_.find(place);
if (iter == device_codes_.end()) {
PADDLE_THROW(platform::errors::NotFound(
PADDLE_THROW(phi::errors::NotFound(
"Place %s is not supported for runtime compiling.", place));
}
auto& codes_map = iter->second;
auto code_iter = codes_map.find(name);
if (code_iter == codes_map.end()) {
PADDLE_THROW(platform::errors::NotFound(
PADDLE_THROW(phi::errors::NotFound(
"Device code named %s for place %s does not exist.",
name.c_str(),
place));
@@ -63,7 +65,7 @@ platform::DeviceCode* DeviceCodePool::Get(const platform::Place& place,
return code_iter->second.get();
}
DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
DeviceCodePool::DeviceCodePool(const std::vector<phi::Place>& places) {
PADDLE_ENFORCE_GT(places.size(),
0,
errors::InvalidArgument(
@@ -75,11 +77,11 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
set.insert(p);
}
for (auto& p : set) {
if (is_gpu_place(p)) {
if (p.GetType() == phi::AllocationType::GPU) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
device_codes_.emplace(p, DeviceCodeMap());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
PADDLE_THROW(phi::errors::PreconditionNotMet(
"CUDAPlace or HIPPlace is not supported, please re-compile with "
"WITH_GPU=ON or WITH_ROCM=ON."));
#endif
@@ -87,7 +89,7 @@ DeviceCodePool::DeviceCodePool(const std::vector<platform::Place>& places) {
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
CUDADeviceCode::CheckAvailableStatus();
GPUDeviceCode::CheckAvailableStatus();
#endif
}
@@ -114,8 +116,8 @@ static bool CheckCUDADriverResult(CUresult result,
return true;
}
bool CUDADeviceCode::available_ = false;
void CUDADeviceCode::CheckAvailableStatus() {
bool GPUDeviceCode::available_ = false;
void GPUDeviceCode::CheckAvailableStatus() {
available_ = false;
if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
LOG_FIRST_N(WARNING, 1)
@@ -215,12 +217,12 @@ static std::string FindCUDAIncludePath() {
return "";
}
CUDADeviceCode::CUDADeviceCode(const Place& place,
GPUDeviceCode::GPUDeviceCode(const Place& place,
const std::string& name,
const std::string& kernel) {
if (!is_gpu_place(place)) {
PADDLE_THROW(platform::errors::PermissionDenied(
"CUDADeviceCode can only launch on GPU place."));
if (place.GetType() != phi::AllocationType::GPU) {
PADDLE_THROW(phi::errors::PermissionDenied(
"GPUDeviceCode can only launch on GPU place."));
}
place_ = place;
@@ -232,7 +234,7 @@ CUDADeviceCode::CUDADeviceCode(const Place& place,
#endif
}
bool CUDADeviceCode::Compile(bool include_path) {
bool GPUDeviceCode::Compile(bool include_path) {
is_compiled_ = false;
if (!dynload::HasNVRTC() || !dynload::HasCUDADriver()) {
LOG_FIRST_N(WARNING, 1)
@@ -403,7 +405,7 @@ bool CUDADeviceCode::Compile(bool include_path) {
return true;
}
void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
void GPUDeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
PADDLE_ENFORCE_EQ(
is_compiled_,
true,
@@ -454,7 +456,7 @@ void CUDADeviceCode::Launch(const size_t n, std::vector<void*>* args) const {
}
#ifdef PADDLE_WITH_HIP
bool CUDADeviceCode::CheckNVRTCResult(hiprtcResult result,
bool GPUDeviceCode::CheckNVRTCResult(hiprtcResult result,
std::string function) {
if (result != HIPRTC_SUCCESS) {
LOG_FIRST_N(WARNING, 1)
@@ -463,8 +465,7 @@ bool CUDADeviceCode::CheckNVRTCResult(hiprtcResult result,
return false;
}
#else
bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
std::string function) {
bool GPUDeviceCode::CheckNVRTCResult(nvrtcResult result, std::string function) {
if (result != NVRTC_SUCCESS) {
LOG_FIRST_N(WARNING, 1)
<< "Call " << function << " for < " << name_
@@ -476,5 +477,4 @@ bool CUDADeviceCode::CheckNVRTCResult(nvrtcResult result,
}
#endif
} // namespace platform
} // namespace paddle
} // namespace phi
@@ -20,18 +20,18 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cuda_driver.h"
#include "paddle/fluid/platform/dynload/nvrtc.h"
#include "paddle/phi/backends/dynload/cuda_driver.h"
#include "paddle/phi/backends/dynload/nvrtc.h"
#endif
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/dynload/hiprtc.h"
#include "paddle/fluid/platform/dynload/rocm_driver.h"
#include "paddle/phi/backends/dynload/hiprtc.h"
#include "paddle/phi/backends/dynload/rocm_driver.h"
#endif
namespace paddle {
namespace platform {
namespace phi {
class DeviceCode {
public:
@@ -49,9 +49,9 @@ class DeviceCode {
};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
class CUDADeviceCode : public DeviceCode {
class GPUDeviceCode : public DeviceCode {
public:
explicit CUDADeviceCode(const Place& place,
explicit GPUDeviceCode(const Place& place,
const std::string& name,
const std::string& kernel);
bool Compile(bool include_path = false) override;
@@ -94,7 +94,7 @@ class DeviceCodePool {
using DeviceCodeMap =
std::unordered_map<std::string, std::unique_ptr<DeviceCode>>;
explicit DeviceCodePool(const std::vector<platform::Place>& places);
explicit DeviceCodePool(const std::vector<Place>& places);
static DeviceCodePool& Instance() {
PADDLE_ENFORCE_NOT_NULL(
@@ -104,7 +104,7 @@ class DeviceCodePool {
return *pool;
}
static DeviceCodePool& Init(const std::vector<platform::Place>& places) {
static DeviceCodePool& Init(const std::vector<Place>& places) {
if (pool == nullptr) {
pool = new DeviceCodePool(places);
}
@@ -113,10 +113,9 @@ class DeviceCodePool {
void Set(std::unique_ptr<DeviceCode>&& code);
platform::DeviceCode* Get(const platform::Place& place,
const std::string& name);
DeviceCode* Get(const Place& place, const std::string& name);
size_t size(const platform::Place& place) const {
size_t size(const Place& place) const {
auto iter = device_codes_.find(place);
if (iter == device_codes_.end()) {
return 0;
@@ -130,5 +129,4 @@ class DeviceCodePool {
DISABLE_COPY_AND_ASSIGN(DeviceCodePool);
};
} // namespace platform
} // namespace paddle
} // namespace phi
@@ -153,6 +153,11 @@ if(WITH_CUTLASS)
list(APPEND kernel_cu ${cutlass_cu})
endif()
if(APPLE OR WIN32)
list(REMOVE_ITEM kernel_cu
"${CMAKE_CURRENT_SOURCE_DIR}/fusion/gpu/fusion_group_kernel.cu")
endif()
if(WITH_MKLDNN)
file(
GLOB
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
namespace fusion {
template <typename DeviceContext>
static void MutableMultiTypeData(std::vector<phi::DenseTensor*>* var,
const std::vector<int>& data_type,
const DeviceContext& dev_ctx) {
for (size_t i = 0; i < var->size(); i++) {
if (data_type[i] == phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
dev_ctx.template Alloc<float>((*var)[i],
(*var)[i]->numel() * sizeof(float));
} else if (data_type[i] ==
phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
dev_ctx.template Alloc<phi::dtype::float16>(
(*var)[i], (*var)[i]->numel() * sizeof(phi::dtype::float16));
} else if (data_type[i] ==
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
dev_ctx.template Alloc<double>((*var)[i],
(*var)[i]->numel() * sizeof(double));
}
}
}
template <typename T, typename Context>
void FusionGroupKernel(const Context& dev_ctx,
const std::vector<const DenseTensor*>& ins,
const std::vector<int>& outs_dtype,
const std::vector<int>& inputs_dtype,
const std::string& func_name,
int type,
std::vector<DenseTensor*> outs) {
size_t num_ins = ins.size();
size_t num_outs = outs.size();
MutableMultiTypeData(&outs, outs_dtype, dev_ctx);
phi::DeviceCode* dev_code =
phi::DeviceCodePool::Instance().Get(dev_ctx.GetPlace(), func_name);
VLOG(3) << "func_name: " << func_name;
if (type == 0) {
size_t n = ins[0]->numel();
std::vector<void*> args;
args.push_back(&n);
std::vector<const void*> ptrs(num_ins + num_outs);
for (size_t i = 0; i < num_ins; ++i) {
if (inputs_dtype[i] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
ptrs[i] = ins[i]->data<phi::dtype::float16>();
} else if (inputs_dtype[i] ==
phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
ptrs[i] = ins[i]->data<float>();
} else if (inputs_dtype[i] ==
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
ptrs[i] = ins[i]->data<double>();
}
args.push_back(&ptrs[i]);
}
for (size_t j = 0; j < num_outs; ++j) {
if (outs_dtype[j] == phi::TransToProtoVarType(phi::DataType::FLOAT16)) {
ptrs[num_ins + j] = outs[j]->data<phi::dtype::float16>();
} else if (outs_dtype[j] ==
phi::TransToProtoVarType(phi::DataType::FLOAT32)) {
ptrs[num_ins + j] = outs[j]->data<float>();
} else if (outs_dtype[j] ==
phi::TransToProtoVarType(phi::DataType::FLOAT64)) {
ptrs[num_ins + j] = outs[j]->data<double>();
}
args.push_back(&ptrs[num_ins + j]);
}
dev_code->Launch(n, &args);
}
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fusion_group,
GPU,
ALL_LAYOUT,
phi::fusion::FusionGroupKernel,
float,
double,
phi::dtype::float16) {}
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,16 +12,18 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fusion_group_op.h"
#include "paddle/phi/core/compat/op_utils.h"
#include "paddle/fluid/platform/float16.h"
namespace phi {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(fusion_group,
GPU,
ALL_LAYOUT,
ops::FusionGroupKernel,
float,
double,
plat::float16) {}
KernelSignature FusionGroupOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fusion_group",
{"Inputs"},
{"outs_dtype", "inputs_dtype", "func_name", "type"},
{"Outs"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fusion_group, phi::FusionGroupOpArgumentMapping);
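For readers new to the phi compatibility layer: the mapping above is what routes the legacy fluid "fusion_group" operator to the phi kernel registered in the new fusion_group_kernel.cu, translating the op's named slots into the kernel's positional parameters. Roughly, as an illustration only (the actual dispatch happens inside the framework, not in user code):

// "Inputs"                                           -> ins
// "outs_dtype", "inputs_dtype", "func_name", "type"  -> kernel attributes
// "Outs"                                             -> outs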
@@ -17,8 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/platform/device_code.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/phi/backends/device_code.h"
namespace paddle {
namespace operators {
@@ -93,11 +93,10 @@ framework::OpDesc* CreateFusionGroupOp(
void PrepareDeviceCode(platform::Place place,
std::string func_name,
std::string cuda_kernel_str) {
paddle::platform::DeviceCodePool& pool =
paddle::platform::DeviceCodePool::Init({place});
phi::DeviceCodePool& pool = phi::DeviceCodePool::Init({place});
std::unique_ptr<paddle::platform::DeviceCode> code(
new paddle::platform::CUDADeviceCode(place, func_name, cuda_kernel_str));
std::unique_ptr<phi::DeviceCode> code(
new phi::GPUDeviceCode(place, func_name, cuda_kernel_str));
code->Compile();
pool.Set(std::move(code));
}
@@ -183,7 +182,7 @@ void TestMain(const std::vector<std::string>& input_names,
}
TEST(FusionGroupOp, elementwise) {
if (!platform::dynload::HasNVRTC() || !platform::dynload::HasCUDADriver()) {
if (!phi::dynload::HasNVRTC() || !phi::dynload::HasCUDADriver()) {
return;
}
......