Unverified commit 4f74656d authored by huangjiyi, committed by GitHub

fix Kunlun-KP-Build (#52188)

* fix kp compile

* test

* Revert "test"

This reverts commit 3a1cbfaa0f23e6e06d3dcd8d0b0c28aa63a98e70.

* update copyright

* update cmake

* update cmake

* update cmake

* update cmake
Parent 2a8cfb55
@@ -219,6 +219,16 @@ function(op_library TARGET)
list(APPEND mlu_cc_srcs ${src})
elseif(${src} MATCHES ".*\\.cc$")
list(APPEND cc_srcs ${src})
+    elseif((WITH_ROCM OR WITH_GPU) AND ${src} MATCHES ".*\\.kps$")
+      string(REPLACE ".kps" ".cu" src_cu ${src})
+      file(COPY ${src} DESTINATION ${CMAKE_CURRENT_BINARY_DIR})
+      file(RENAME ${CMAKE_CURRENT_BINARY_DIR}/${src}
+           ${CMAKE_CURRENT_BINARY_DIR}/${src_cu})
+      if(WITH_ROCM)
+        list(APPEND hip_srcs ${CMAKE_CURRENT_BINARY_DIR}/${src_cu})
+      else()
+        list(APPEND cu_srcs ${CMAKE_CURRENT_BINARY_DIR}/${src_cu})
+      endif()
else()
message(
FATAL_ERROR
@@ -97,7 +97,7 @@ endif()
set(OP_HEADER_DEPS ${OP_HEADER_DEPS} phi phi_utils backward_infermeta sparse_backward_infermeta static_prim_api get_expected_kernel_func)
register_operators(EXCLUDES py_func_op warpctc_op dgc_op generated_op1 generated_op2 generated_op3 generated_op4 load_combine_op lstm_op run_program_op eye_op quantize_linear_op
-    recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
+    recurrent_op save_combine_op sparse_attention_op sync_batch_norm_op activation_op ${OP_MKL_DEPS} DEPS ${OP_HEADER_DEPS})
op_library(generated_op UNITY SRCS generated_op1.cc generated_op2.cc generated_op3.cc generated_op4.cc DEPS ${OP_HEADER_DEPS})
op_library(run_program_op SRCS run_program_op.cc run_program_op.cu.cc run_program_op_npu.cc DEPS executor_cache ${OP_HEADER_DEPS})
@@ -106,6 +106,14 @@ op_library(quantize_linear_op DEPS phi)
op_library(save_combine_op DEPS string_array phi)
op_library(load_combine_op DEPS string_array)
+if (WITH_GPU OR WITH_ROCM)
+  op_library(activation_op SRCS activation_op.cc activation_op.kps soft_relu_op.cu DEPS ${OP_HEADER_DEPS})
+elseif (WITH_XPU_KP)
+  op_library(activation_op SRCS activation_op.cc activation_op.kps DEPS ${OP_HEADER_DEPS})
+else()
+  op_library(activation_op SRCS activation_op.cc DEPS ${OP_HEADER_DEPS})
+endif()
if (WITH_GPU OR WITH_ROCM)
if(WITH_ROCM)
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale SRCS warpctc_op.cc)
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
namespace paddle {
namespace operators {
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
// threshold should not be negative
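  // Clamping x to [-threshold, threshold] also bounds exp(), so the
  // computation below cannot overflow for large |x|.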
__device__ __forceinline__ T operator()(const T arg_x) const {
MPType x = static_cast<MPType>(arg_x);
MPType t = static_cast<MPType>(threshold);
MPType temp_min = x < t ? x : t;
MPType temp_max = temp_min > -t ? temp_min : -t;
return static_cast<T>(log(one + exp(temp_max)));
}
};
template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
// dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
// threshold should not be negative
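  // Derivation: in the unclamped region out = log(1 + exp(x)), so
  // exp(-out) = 1 / (1 + exp(x)) and d(out)/dx = exp(x) / (1 + exp(x))
  // = 1 - exp(-out); the gradient is zeroed where the clamp was active,
  // approximated here by checking out against (-threshold, threshold).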
__device__ __forceinline__ T operator()(const T arg_dout,
const T arg_out) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType out = static_cast<MPType>(arg_out);
MPType t = static_cast<MPType>(threshold);
return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
: static_cast<T>(0.0f);
}
static constexpr ActBwdOpFwdDeps FwdDeps() {
return ActBwdOpFwdDeps::kDepOut;
}
};
template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const phi::DenseTensor* x = nullptr;
phi::DenseTensor* out = nullptr;
ExtractActivationTensor(ctx, &x, &out);
out->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
std::vector<const phi::DenseTensor*> ins = {x};
std::vector<phi::DenseTensor*> outs = {out};
auto functor = Functor();
auto attrs = functor.GetAttrs();
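    // Copy the op's float attributes (e.g. "threshold") into the functor
    // before launching the elementwise kernel.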
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
dev_ctx, ins, &outs, functor);
}
};
template <typename DeviceContext, typename Functor>
class ActivationGradCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const phi::DenseTensor *x, *out, *d_out;
phi::DenseTensor* d_x = nullptr;
x = out = d_out = nullptr;
ExtractActivationGradTensor<Functor::FwdDeps()>(
ctx, &x, &out, &d_out, &d_x);
d_x->mutable_data<T>(ctx.GetPlace());
auto& dev_ctx = ctx.template device_context<DeviceContext>();
auto functor = Functor();
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
std::vector<const phi::DenseTensor*> ins = {d_out};
std::vector<phi::DenseTensor*> outs = {d_x};
if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
// Only need forward output Out
ins.push_back(out);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
dev_ctx, ins, &outs, functor);
} else if (static_cast<int>(Functor::FwdDeps()) ==
static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
// Only need forward input X
ins.push_back(x);
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
dev_ctx, ins, &outs, functor);
} else {
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
dev_ctx, ins, &outs, functor);
}
}
};
} // namespace operators
} // namespace paddle
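As a sanity check on the two functors above, here is a minimal host-side sketch of the same math. This is not part of the commit; soft_relu_ref and soft_relu_grad_ref are names made up for the example:

#include <algorithm>
#include <cmath>
#include <cstdio>

// Reference forward: soft_relu(x) = log(1 + exp(clamp(x, -t, t))).
float soft_relu_ref(float x, float t) {
  float clamped = std::max(std::min(x, t), -t);
  return std::log1p(std::exp(clamped));  // log1p(e) == log(1 + e)
}

// Reference gradient, matching CudaSoftReluGradFunctor:
// dx = (out in (-t, t)) ? dout * (1 - exp(-out)) : 0.
float soft_relu_grad_ref(float dout, float out, float t) {
  return (out > -t && out < t) ? dout * (1.0f - std::exp(-out)) : 0.0f;
}

int main() {
  const float t = 40.0f;
  for (float x : {-50.0f, -1.0f, 0.0f, 1.0f, 50.0f}) {
    float out = soft_relu_ref(x, t);
    // For x = 50 the clamp is active, so the gradient reference returns 0.
    std::printf("x=%6.1f out=%10.6f dx=%10.6f\n",
                x, out, soft_relu_grad_ref(1.0f, out, t));
  }
}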
@@ -9,127 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/fluid/operators/activation_op.cu.h"
namespace paddle {
namespace operators {
-template <typename T>
-struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
-  using MPType = typename details::MPTypeTrait<T>::Type;
-  MPType one = static_cast<MPType>(1.0f);
-  float threshold;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"threshold", &threshold}};
-  }
-  // soft_relu(x) = log(1 + exp(max(min(x, threshold), -threshold)))
-  // threshold should not be negative
-  __device__ __forceinline__ T operator()(const T arg_x) const {
-    MPType x = static_cast<MPType>(arg_x);
-    MPType t = static_cast<MPType>(threshold);
-    MPType temp_min = x < t ? x : t;
-    MPType temp_max = temp_min > -t ? temp_min : -t;
-    return static_cast<T>(log(one + exp(temp_max)));
-  }
-};
-template <typename T>
-struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
-  using MPType = typename details::MPTypeTrait<T>::Type;
-  MPType one = static_cast<MPType>(1.0f);
-  float threshold;
-  typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
-    return {{"threshold", &threshold}};
-  }
-  // dx = (out > -threshold && out < threshold) ? dout * (1 - exp(-out)) : 0
-  // threshold should not be negative
-  __device__ __forceinline__ T operator()(const T arg_dout,
-                                          const T arg_out) const {
-    MPType dout = static_cast<MPType>(arg_dout);
-    MPType out = static_cast<MPType>(arg_out);
-    MPType t = static_cast<MPType>(threshold);
-    return (out > -t && out < t) ? static_cast<T>(dout * (one - exp(-out)))
-                                 : static_cast<T>(0.0f);
-  }
-  static constexpr ActBwdOpFwdDeps FwdDeps() {
-    return ActBwdOpFwdDeps::kDepOut;
-  }
-};
-template <typename DeviceContext, typename Functor>
-class ActivationCudaKernel
-    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
-  using T = typename Functor::ELEMENT_TYPE;
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const phi::DenseTensor* x = nullptr;
-    phi::DenseTensor* out = nullptr;
-    ExtractActivationTensor(ctx, &x, &out);
-    out->mutable_data<T>(ctx.GetPlace());
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    std::vector<const phi::DenseTensor*> ins = {x};
-    std::vector<phi::DenseTensor*> outs = {out};
-    auto functor = Functor();
-    auto attrs = functor.GetAttrs();
-    for (auto& attr : attrs) {
-      *attr.second = ctx.Attr<float>(attr.first);
-    }
-    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-        dev_ctx, ins, &outs, functor);
-  }
-};
-template <typename DeviceContext, typename Functor>
-class ActivationGradCudaKernel
-    : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
- public:
-  using T = typename Functor::ELEMENT_TYPE;
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    const phi::DenseTensor *x, *out, *d_out;
-    phi::DenseTensor* d_x = nullptr;
-    x = out = d_out = nullptr;
-    ExtractActivationGradTensor<Functor::FwdDeps()>(
-        ctx, &x, &out, &d_out, &d_x);
-    d_x->mutable_data<T>(ctx.GetPlace());
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    auto functor = Functor();
-    auto attrs = functor.GetAttrs();
-    for (auto& attr : attrs) {
-      *attr.second = ctx.Attr<float>(attr.first);
-    }
-    std::vector<const phi::DenseTensor*> ins = {d_out};
-    std::vector<phi::DenseTensor*> outs = {d_x};
-    if (static_cast<int>(Functor::FwdDeps()) ==
-        static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
-      // Only need forward output Out
-      ins.push_back(out);
-      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-          dev_ctx, ins, &outs, functor);
-    } else if (static_cast<int>(Functor::FwdDeps()) ==
-               static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
-      // Only need forward input X
-      ins.push_back(x);
-      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-          dev_ctx, ins, &outs, functor);
-    } else {
-      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-          dev_ctx, ins, &outs, functor);
-    }
-  }
-};
template <typename T>
using CudaBReluFunctor = phi::funcs::CudaHardTanhFunctor<T>;
template <typename T>
@@ -192,42 +76,12 @@ template <typename T>
using CudaELUGradNegativeAlphaFunctor =
phi::funcs::CudaELUGradNegativeAlphaFunctor<T>;
-#define DEFINE_ACTIVATION_CUDA_KERNEL(op_name, functor, grad_functor)       \
-  template <typename T, typename DeviceContext>                             \
-  class op_name##CudaKernel                                                 \
-      : public ActivationCudaKernel<DeviceContext, functor<T>> {};          \
-                                                                            \
-  template <typename T, typename DeviceContext>                             \
-  class op_name##GradCudaKernel                                             \
-      : public ActivationGradCudaKernel<DeviceContext, grad_functor<T>> {};
-DEFINE_ACTIVATION_CUDA_KERNEL(SoftRelu,
-                              CudaSoftReluFunctor,
-                              CudaSoftReluGradFunctor)
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
-PD_REGISTER_STRUCT_KERNEL(soft_relu,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::SoftReluCudaKernel,
-                          float,
-                          double,
-                          plat::float16,
-                          plat::bfloat16) {}
-PD_REGISTER_STRUCT_KERNEL(soft_relu_grad,
-                          GPU,
-                          ALL_LAYOUT,
-                          ops::SoftReluGradCudaKernel,
-                          float,
-                          double,
-                          plat::float16,
-                          plat::bfloat16) {}
#ifdef PADDLE_WITH_XPU_KP
REGISTER_OP_KERNEL(
brelu,
/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.cu.h"
namespace paddle {
namespace operators {
#define DEFINE_ACTIVATION_CUDA_KERNEL(op_name, functor, grad_functor) \
template <typename T, typename DeviceContext> \
class op_name##CudaKernel \
: public ActivationCudaKernel<DeviceContext, functor<T>> {}; \
\
template <typename T, typename DeviceContext> \
class op_name##GradCudaKernel \
: public ActivationGradCudaKernel<DeviceContext, grad_functor<T>> {};
DEFINE_ACTIVATION_CUDA_KERNEL(SoftRelu,
CudaSoftReluFunctor,
CudaSoftReluGradFunctor)
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
PD_REGISTER_STRUCT_KERNEL(soft_relu,
GPU,
ALL_LAYOUT,
ops::SoftReluCudaKernel,
float,
double,
plat::float16,
plat::bfloat16) {}
PD_REGISTER_STRUCT_KERNEL(soft_relu_grad,
GPU,
ALL_LAYOUT,
ops::SoftReluGradCudaKernel,
float,
double,
plat::float16,
plat::bfloat16) {}
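The DEFINE_ACTIVATION_CUDA_KERNEL macro above is the extension point this file demonstrates: a new activation only needs its functor pair, one macro invocation, and the matching registrations. A sketch of that pattern, where MyAct and its functors are placeholder names rather than anything in this commit:

// Hypothetical example only: MyAct, CudaMyActFunctor and
// CudaMyActGradFunctor are placeholder names, not part of this commit.
DEFINE_ACTIVATION_CUDA_KERNEL(MyAct, CudaMyActFunctor, CudaMyActGradFunctor)

PD_REGISTER_STRUCT_KERNEL(my_act,
                          GPU,
                          ALL_LAYOUT,
                          ops::MyActCudaKernel,
                          float,
                          double) {}
PD_REGISTER_STRUCT_KERNEL(my_act_grad,
                          GPU,
                          ALL_LAYOUT,
                          ops::MyActGradCudaKernel,
                          float,
                          double) {}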