Unverified commit 424700ff authored by niuliling123, committed by GitHub

Replace clip, bce_loss, full and full_like with elementwise (#39197)

* Replace clip, bce_loss, full and full_like with elementwise
Parent 23d559dd
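Every file touched below follows the same recipe: the hand-written __global__ kernel (or host-side Transform call) is deleted and the per-element math is expressed as a small device functor that is handed to LaunchSameDimsElementwiseCudaKernel. As a rough, self-contained sketch of that pattern (this is not the Paddle launch machinery itself; AxpyFunctor, ElementwiseBinaryKernel and the raw CUDA launch are illustrative stand-ins):

#include <cstdio>
#include <cuda_runtime.h>

// Illustrative stand-in for a Paddle elementwise functor such as
// BCELossGradFunctor: all per-element state and logic live in operator().
template <typename T>
struct AxpyFunctor {
  T alpha;
  explicit AxpyFunctor(T a) : alpha(a) {}
  __device__ __forceinline__ T operator()(const T& x, const T& y) const {
    return alpha * x + y;
  }
};

// Simplified stand-in for LaunchSameDimsElementwiseCudaKernel: one generic
// kernel applies any binary functor element by element.
template <typename T, typename Functor>
__global__ void ElementwiseBinaryKernel(const T* x, const T* y, T* out, int n,
                                        Functor func) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += blockDim.x * gridDim.x) {
    out[i] = func(x[i], y[i]);
  }
}

int main() {
  const int n = 1 << 10;
  float *x, *y, *out;
  cudaMallocManaged(&x, n * sizeof(float));
  cudaMallocManaged(&y, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (int i = 0; i < n; ++i) {
    x[i] = 1.0f;
    y[i] = 2.0f;
  }
  ElementwiseBinaryKernel<<<(n + 255) / 256, 256>>>(x, y, out, n,
                                                    AxpyFunctor<float>(3.0f));
  cudaDeviceSynchronize();
  std::printf("out[0] = %f\n", out[0]);  // expect 5.0
  cudaFree(x);
  cudaFree(y);
  cudaFree(out);
  return 0;
}

Because the functor owns all per-element state, the same generic launch path (vectorization, launch configuration, multi-output support) can be reused by bce_loss, clip, full, full_like and scale alike, which is what the hunks below do.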
......@@ -21,40 +21,45 @@ limitations under the License. */
namespace paddle {
namespace operators {
-using Tensor = framework::Tensor;
-
-template <typename T>
-struct BCELossGradFunctor {
-  T one = static_cast<T>(1.0f);
-  T eps = static_cast<T>(1e-12);
-  __device__ __forceinline__ T operator()(const T x, const T label,
-                                          const T dout) const {
-    T term1 = max((one - x) * x, eps);
-    return (dout * (x - label) / term1);
-  }
-};
-
-template <typename T>
-__global__ void GPUBCELossForward(const T* x_data, const T* label_data,
-                                  T* out_data, const int in_numel) {
-  CUDA_KERNEL_LOOP(i, in_numel) {
-    T x = x_data[i];
-    T label = label_data[i];
-    T one = static_cast<T>(1.);
-    T neg_100 = static_cast<T>(-100.);
-    PADDLE_ENFORCE(
-        (x >= static_cast<T>(0)) && (x <= one),
-        "Input is expected to be within the interval [0, 1], but received %f.",
-        x);
-    T term1 = max(real_log(x), neg_100);
-    T term2 = max(real_log(one - x), neg_100);
-    out_data[i] = ((label - one) * term2) - (label * term1);
-  }
-}
+template <typename T>
+struct BCELossFunctor {
+  T one;
+  T neg_100;
+
+  HOSTDEVICE inline BCELossFunctor() {
+    one = static_cast<T>(1.0f);
+    neg_100 = static_cast<T>(-100.);
+  }
+
+  HOSTDEVICE inline T operator()(const T& x, const T& label) const {
+    PADDLE_ENFORCE(
+        (x >= static_cast<T>(0)) && (x <= one),
+        "Input is expected to be within the interval [0, 1], but received %f.",
+        x);
+    T term1 = max(real_log(x), neg_100);
+    T term2 = max(real_log(one - x), neg_100);
+    return (((label - one) * term2) - (label * term1));
+  }
+};
+
+template <typename T>
+struct BCELossGradFunctor {
+  T one;
+  T eps;
+
+  HOSTDEVICE inline BCELossGradFunctor() {
+    one = static_cast<T>(1.0f);
+    eps = static_cast<T>(1e-12);
+  }
+
+  HOSTDEVICE inline T operator()(const T& x, const T& label,
+                                 const T& dout) const {
+    T term1 = max((one - x) * x, eps);
+    return (dout * (x - label) / term1);
+  }
+};
+
+using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class BCELossCUDAKernel : public framework::OpKernel<T> {
......@@ -63,18 +68,13 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* out = ctx.Output<Tensor>("Out");
-    const auto* x_data = x->data<T>();
-    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
-    auto x_numel = x->numel();
-    auto& dev_ctx = ctx.cuda_device_context();
-    platform::GpuLaunchConfig config =
-        platform::GetGpuLaunchConfig1D(dev_ctx, x_numel);
-    GPUBCELossForward<T><<<config.block_per_grid, config.thread_per_block, 0,
-                           dev_ctx.stream()>>>(x_data, labels->data<T>(),
-                                               out_data, x_numel);
+    out->mutable_data<T>(ctx.GetPlace());
+    std::vector<const framework::Tensor*> ins = {x, labels};
+    std::vector<framework::Tensor*> outs = {out};
+    auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
+    auto functor = BCELossFunctor<T>();
+    paddle::operators::LaunchSameDimsElementwiseCudaKernel<
+        ElementwiseType::kBinary, T, T>(dev_ctx, ins, &outs, functor);
}
};
......
......@@ -172,6 +172,15 @@ class ClipGradKernel : public framework::OpKernel<T> {
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* x = context.Input<framework::LoDTensor>("X");
#if defined(__NVCC__) || defined(__HIPCC__)
std::vector<const framework::Tensor*> ins = {d_out, x};
std::vector<framework::Tensor*> outs = {d_x};
auto functor = ClipGradFunctor<T>(min, max);
d_x->mutable_data<T>(context.GetPlace());
LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
context.template device_context<platform::CUDADeviceContext>(), ins,
&outs, functor);
#else
int64_t numel = d_out->numel();
auto* d_x_data = d_x->mutable_data<T>(context.GetPlace());
const T* d_out_data = d_out->data<T>();
......@@ -179,6 +188,7 @@ class ClipGradKernel : public framework::OpKernel<T> {
Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), d_out_data,
d_out_data + numel, x_data, d_x_data, ClipGradFunctor<T>(min, max));
#endif
}
}
};
......
......@@ -18,6 +18,18 @@ limitations under the License. */
namespace paddle {
namespace platform {
template <int Arity, typename... Args>
struct IsPointerArgs {
static_assert(Arity == sizeof...(Args), "Arity and Args not match!");
static const bool value = false;
};
template <typename... Args>
struct IsPointerArgs<1, Args...> {
static_assert(1 == sizeof...(Args), "Arity and Args not match!");
static const bool value = std::is_pointer<
typename std::tuple_element<0, std::tuple<Args...>>::type>::value;
};
// Declare a template class with a single template parameter.
template <typename>
......@@ -41,10 +53,7 @@ struct FunctionTraits<ReturnType (ClassType::*)(Args...)>
template <typename ReturnType, typename... Args>
struct FunctionTraits<ReturnType(Args...)> {
static const size_t arity = sizeof...(Args);
-  static const bool has_pointer_args =
-      (arity == 1) &&
-      (std::is_pointer<
-          typename std::tuple_element<0, std::tuple<Args...>>::type>::value);
+  static const bool has_pointer_args = IsPointerArgs<arity, Args...>::value;
};
} // namespace platform
......
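The IsPointerArgs helper introduced above exists so that has_pointer_args can still be computed when a functor takes no arguments: the old expression always named std::tuple_element<0, std::tuple<Args...>>, which does not compile for an empty pack, while the partial specialization only touches it when Arity == 1. A small, self-contained illustration; the traits type below is a simplified stand-in that takes the argument types directly instead of being specialized on a function type:

#include <cstddef>
#include <tuple>
#include <type_traits>

// Reproduced from the hunk above so the example stands alone.
template <int Arity, typename... Args>
struct IsPointerArgs {
  static_assert(Arity == sizeof...(Args), "Arity and Args not match!");
  static const bool value = false;
};

template <typename... Args>
struct IsPointerArgs<1, Args...> {
  static_assert(1 == sizeof...(Args), "Arity and Args not match!");
  static const bool value = std::is_pointer<
      typename std::tuple_element<0, std::tuple<Args...>>::type>::value;
};

// Simplified stand-in for FunctionTraits.
template <typename ReturnType, typename... Args>
struct SimpleTraits {
  static const std::size_t arity = sizeof...(Args);
  static const bool has_pointer_args = IsPointerArgs<arity, Args...>::value;
};

static_assert(SimpleTraits<float, float, float>::arity == 2, "");
static_assert(!SimpleTraits<float, float, float>::has_pointer_args, "");
static_assert(SimpleTraits<void, float*>::has_pointer_args, "");
// The case the refactor unlocks: a zero-argument functor no longer
// instantiates std::tuple_element<0, std::tuple<>>.
static_assert(SimpleTraits<float>::arity == 0, "");
static_assert(!SimpleTraits<float>::has_pointer_args, "");

int main() { return 0; }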
......@@ -31,6 +31,8 @@ namespace kps = pten::kps;
#endif
#define BASE_SIZE 1 // To avoid running errors when Arity == 0 in args[Arity]
namespace pten {
enum ElementwiseType { kUnary = 1, kBinary = 2, kTernary = 3, kAny = -1 };
......@@ -475,6 +477,15 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, Arity, true> {
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 0, false> {
__device__ inline void operator()(Functor func,
InT (*args)[VecSize],
OutT *result) {
kps::ElementwiseFillConst<InT, OutT, VecSize, 1, 1, Functor>(result, func);
}
};
template <typename InT, typename OutT, int VecSize, typename Functor>
struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 1, false> {
__device__ inline void operator()(Functor func,
......@@ -548,12 +559,14 @@ template <typename InT,
int VecSize,
bool IsBoundary>
__device__ void VectorizedElementwiseKernelImpl(
-    const pten::framework::Array<const _ptr_ InT *__restrict__, Arity> &in,
+    const pten::framework::Array<const _ptr_ InT *__restrict__,
+                                 Arity + BASE_SIZE> &in,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int num,
int data_offset,
Functor func) {
-  InT args[Arity][VecSize];
+  InT args[Arity + BASE_SIZE][VecSize];
ConditionalT<OutT, NumOuts> result[VecSize];
#pragma unroll
......@@ -583,7 +596,8 @@ template <typename InT,
int NumOuts,
int VecSize>
__global__ void VectorizedElementwiseKernel(
-    pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins,
+    pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
+        ins,
pten::framework::Array<_ptr_ OutT *, NumOuts> outs,
int size,
int main_offset,
......@@ -623,8 +637,9 @@ void ElementwiseCudaKernel(const KPDevice &ctx,
const std::vector<const DenseTensor *> &ins,
std::vector<DenseTensor *> *outs,
Functor func) {
-  auto numel = ins[0]->numel();
-  pten::framework::Array<const _ptr_ InT *__restrict__, Arity> ins_data;
+  auto numel = (*outs)[0]->numel();
+  pten::framework::Array<const _ptr_ InT *__restrict__, Arity + BASE_SIZE>
+      ins_data;
pten::framework::Array<_ptr_ OutT *, NumOuts> outs_data;
for (int i = 0; i < Arity; ++i) {
......
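BASE_SIZE pads every Arity-sized buffer and pointer array by one element so that the new zero-input path (used by the full/full_like functors below) never declares a zero-length array such as InT args[0][VecSize]. A hedged, standalone sketch of the same workaround follows; the names are illustrative, it relies on if constexpr (nvcc with -std=c++17), and the real code instead dispatches through the ElementwisePrimitiveCaller specializations:

#include <cstdio>
#include <cuda_runtime.h>

#define BASE_SIZE 1  // pad extents so the Arity == 0 case stays well-formed

// Illustrative functors: one unary, one nullary (in the spirit of
// ScaleFunctor and FullFuctor).
struct TimesTwo {
  __device__ float operator()(float x) const { return 2.0f * x; }
};
struct FillSeven {
  __device__ float operator()() const { return 7.0f; }
};

// One kernel body serves both arities. "float args[Arity][VecSize]" would be
// a zero-length array when Arity == 0, hence the "+ BASE_SIZE" padding.
template <int Arity, int VecSize, typename Functor>
__global__ void Apply(const float* in, float* out, Functor func) {
  float args[Arity + BASE_SIZE][VecSize];
  int base = (blockIdx.x * blockDim.x + threadIdx.x) * VecSize;
  if constexpr (Arity == 1) {
    for (int i = 0; i < VecSize; ++i) args[0][i] = in[base + i];
    for (int i = 0; i < VecSize; ++i) out[base + i] = func(args[0][i]);
  } else {  // Arity == 0: nothing to load, just evaluate the functor
    (void)args;
    for (int i = 0; i < VecSize; ++i) out[base + i] = func();
  }
}

int main() {
  float* buf;
  cudaMallocManaged(&buf, 8 * sizeof(float));
  Apply<0, 4><<<1, 2>>>(nullptr, buf, FillSeven{});  // full-style fill path
  Apply<1, 4><<<1, 2>>>(buf, buf, TimesTwo{});       // ordinary unary op
  cudaDeviceSynchronize();
  std::printf("%f\n", buf[0]);  // expect 14.0
  cudaFree(buf);
  return 0;
}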
......@@ -16,7 +16,89 @@ limitations under the License. */
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/kernel_registry.h"
-#include "paddle/pten/kernels/impl/full_kernel_impl.h"
+#include "paddle/pten/kernels/funcs/elementwise_base.h"
namespace pten {
template <typename InT, typename OutT = InT>
struct FullFuctor {
OutT value;
template <typename VType>
explicit inline FullFuctor(VType val) {
value = static_cast<OutT>(val);
}
__device__ __forceinline__ OutT operator()() const {
return static_cast<OutT>(value);
}
};
template <typename T, typename ContextT>
void FullKernel(const ContextT& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
int numel = out->numel();
out->mutable_data<T>(dev_ctx.GetPlace());
if (numel > 0) {
// In the transformer model, the numel of the output can be zero.
std::vector<const DenseTensor*> inputs = {};
std::vector<DenseTensor*> outputs = {out};
// This kernel has no input, so inputs.size() == 0. kUnary is used, but no
// data is actually loaded inside the kernel because the functor takes zero
// arguments.
pten::funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
T,
T>(
dev_ctx, inputs, &outputs, FullFuctor<T>(val.to<T>()));
}
}
template <typename T, typename ContextT>
void FullLikeKernel(const ContextT& dev_ctx,
const Scalar& val,
DenseTensor* out) {
auto value = val.to<float>();
using CommonType = typename std::common_type<
float,
typename std::conditional<
std::is_same<T, paddle::platform::float16>::value,
float,
T>::type>::type;
auto common_type_value = static_cast<CommonType>(value);
PADDLE_ENFORCE_EQ(
(common_type_value >=
static_cast<CommonType>(std::numeric_limits<T>::lowest())) &&
(common_type_value <=
static_cast<CommonType>(std::numeric_limits<T>::max())),
true,
paddle::platform::errors::InvalidArgument(
"The filled value is out of range for target type, "
"current kernel type is %s, the range should between %f "
"and %f, but now value is %f.",
typeid(T).name(),
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()),
static_cast<float>(value)));
std::vector<const DenseTensor*> inputs = {};
std::vector<DenseTensor*> outputs = {out};
out->mutable_data<T>(dev_ctx.GetPlace());
// This kernel has no input, so inputs.size() == 0. kUnary is used, but no
// data is actually loaded inside the kernel because the functor takes zero
// arguments.
int numel = out->numel();
if (numel > 0) {
pten::funcs::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary,
T,
T>(
dev_ctx, inputs, &outputs, FullFuctor<T>(value));
}
}
} // namespace pten
PT_REGISTER_KERNEL(full,
GPU,
......
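Before launching the fill, FullLikeKernel widens the requested value to a common type (float for float16, otherwise T promoted against float) and rejects values that T cannot represent. A minimal host-side sketch of that range check, using an assumed helper name FillValueInRange, int8 instead of float16 for brevity, and omitting the float16-to-float special case:

#include <cstdio>
#include <limits>
#include <type_traits>

// Assumed helper; mirrors the check in FullLikeKernel above.
template <typename T>
bool FillValueInRange(float value) {
  using CommonType = typename std::common_type<float, T>::type;
  auto v = static_cast<CommonType>(value);
  return v >= static_cast<CommonType>(std::numeric_limits<T>::lowest()) &&
         v <= static_cast<CommonType>(std::numeric_limits<T>::max());
}

int main() {
  std::printf("%d\n", FillValueInRange<signed char>(300.0f));  // 0: rejected
  std::printf("%d\n", FillValueInRange<signed char>(100.0f));  // 1: accepted
  return 0;
}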
......@@ -28,11 +28,10 @@ struct ScaleFunctor {
InT scale;
bool bias_after_scale;
-  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle) {
-    scale = scale_data;
-    bias = bias_data;
-    bias_after_scale = is_bias_after_sacle;
-  }
+  ScaleFunctor(InT scale_data, InT bias_data, bool is_bias_after_sacle)
+      : bias(bias_data),
+        scale(scale_data),
+        bias_after_scale(is_bias_after_sacle) {}
__device__ __forceinline__ InT operator()(const InT x) const {
if (bias_after_scale) {
......
......@@ -414,5 +414,19 @@ __device__ __forceinline__ void Reduce(T* out,
}
}
template <typename InT,
typename OutT,
int NX,
int NY,
int BlockSize,
class OpFunc>
__device__ __forceinline__ void ElementwiseFillConst(OutT* out,
OpFunc compute) {
#pragma unroll
for (int idx = 0; idx < NX * NY; idx++) {
out[idx] = static_cast<OutT>(compute());
}
}
} // namespace kps
} // namespace pten