未验证 提交 29d75c14 编写于 作者: L limingshu 提交者: GitHub

Add bfloat16 type support for abs op (#48205)

* first commit

* 2nd commit
上级 edf46919
......@@ -110,53 +110,6 @@ struct AbsFunctor<T, NoComplex<T, dtype::Real<T>>> {
int64_t numel_;
};
template <typename T>
struct AbsGradCUDAFunctor {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline T operator()(const T x, const T dout) const {
T output;
if (x == T(0)) {
output = T(0);
} else {
output = T(dout) * (x / T(std::abs(x)));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::complex<float> operator()(
const phi::dtype::complex<float> x, const float dout) const {
phi::dtype::complex<float> output;
if (x == phi::dtype::complex<float>(0)) {
output = phi::dtype::complex<float>(0);
} else {
output = phi::dtype::complex<float>(dout) *
(x / phi::dtype::complex<float>(abs(x)));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::complex<double> operator()(
const phi::dtype::complex<double> x, const double dout) const {
phi::dtype::complex<double> output;
if (x == phi::dtype::complex<double>(0)) {
output = phi::dtype::complex<double>(0);
} else {
output = phi::dtype::complex<double>(dout) *
(x / phi::dtype::complex<double>(abs(x)));
}
return output;
}
};
template <typename T>
struct AbsGradFunctor {
AbsGradFunctor(const dtype::Real<T>* dout,
......@@ -179,6 +132,28 @@ struct AbsGradFunctor {
int64_t numel_;
};
template <>
struct AbsGradFunctor<phi::dtype::bfloat16> {
AbsGradFunctor(const dtype::Real<phi::dtype::bfloat16>* dout,
const phi::dtype::bfloat16* x,
phi::dtype::bfloat16* output,
int64_t numel)
: dout_(dout), x_(x), output_(output), numel_(numel) {}
HOSTDEVICE void operator()(int64_t idx) const {
if (x_[idx] == static_cast<phi::dtype::bfloat16>(0)) {
output_[idx] = static_cast<phi::dtype::bfloat16>(0);
} else {
output_[idx] = dout_[idx] * (x_[idx] / (abs(x_[idx])));
}
}
const dtype::Real<phi::dtype::bfloat16>* dout_;
const phi::dtype::bfloat16* x_;
phi::dtype::bfloat16* output_;
int64_t numel_;
};
template <>
struct AbsGradFunctor<phi::dtype::complex<float>> {
AbsGradFunctor(const float* dout,
......
......@@ -31,6 +31,7 @@ PD_REGISTER_KERNEL(abs_grad,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
complex<float>,
complex<double>) {
kernel->InputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
......
......@@ -16,8 +16,8 @@
#include <algorithm>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
......@@ -36,7 +36,18 @@ struct CudaAbsFunctor<T, phi::funcs::Complex<T, phi::dtype::Real<T>>> {
};
template <typename T>
struct CudaAbsFunctor<T, phi::funcs::NoComplex<T, phi::dtype::Real<T>>> {
struct CudaAbsFunctor<
T,
std::enable_if_t<std::is_same<T, phi::dtype::Real<T>>::value &&
std::is_same<T, phi::dtype::bfloat16>::value>> {
__device__ __forceinline__ T operator()(const T x) const { return abs(x); }
};
template <typename T>
struct CudaAbsFunctor<
T,
std::enable_if_t<std::is_same<T, phi::dtype::Real<T>>::value &&
!std::is_same<T, phi::dtype::bfloat16>::value>> {
__device__ __forceinline__ T operator()(const T x) const {
return std::abs(x);
}
......@@ -63,5 +74,6 @@ PD_REGISTER_KERNEL(abs,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
......@@ -14,6 +14,7 @@
#pragma once
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/kernels/abs_grad_kernel.h"
#include "paddle/phi/kernels/funcs/complex_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
......@@ -22,6 +23,70 @@
namespace phi {
#if defined(__NVCC__)
template <typename T>
struct AbsGradCUDAFunctor {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline T operator()(const T x, const T dout) const {
T output;
if (x == T(0)) {
output = T(0);
} else {
output = T(dout) * (x / T(std::abs(x)));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::bfloat16> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::bfloat16 operator()(
const phi::dtype::bfloat16 x, const phi::dtype::bfloat16 dout) const {
phi::dtype::bfloat16 output;
if (x == phi::dtype::bfloat16(0)) {
output = static_cast<phi::dtype::bfloat16>(0);
} else {
output = (dout) * (x / abs(x));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<float>> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::complex<float> operator()(
const phi::dtype::complex<float> x, const float dout) const {
phi::dtype::complex<float> output;
if (x == phi::dtype::complex<float>(0)) {
output = phi::dtype::complex<float>(0);
} else {
output = phi::dtype::complex<float>(dout) *
(x / phi::dtype::complex<float>(abs(x)));
}
return output;
}
};
template <>
struct AbsGradCUDAFunctor<phi::dtype::complex<double>> {
HOSTDEVICE inline AbsGradCUDAFunctor() {}
HOSTDEVICE inline phi::dtype::complex<double> operator()(
const phi::dtype::complex<double> x, const double dout) const {
phi::dtype::complex<double> output;
if (x == phi::dtype::complex<double>(0)) {
output = phi::dtype::complex<double>(0);
} else {
output = phi::dtype::complex<double>(dout) *
(x / phi::dtype::complex<double>(abs(x)));
}
return output;
}
};
template <typename T>
void AbsGradKernelImpl(const GPUContext& dev_ctx,
const DenseTensor& x,
......@@ -30,9 +95,10 @@ void AbsGradKernelImpl(const GPUContext& dev_ctx,
std::vector<const DenseTensor*> ins = {&x, &dout};
std::vector<DenseTensor*> outs = {dx};
dev_ctx.Alloc<T>(dx);
phi::funcs::AbsGradCUDAFunctor<T> abs_grad_cuda_functor;
AbsGradCUDAFunctor<T> abs_grad_cuda_functor;
phi::funcs::ElementwiseKernel<T>(dev_ctx, ins, &outs, abs_grad_cuda_functor);
}
template <typename T, typename Context>
void AbsGradKernel(const Context& dev_ctx,
const DenseTensor& x,
......
......@@ -3699,6 +3699,7 @@ def create_test_act_bf16_class(
create_test_act_bf16_class(TestRelu)
create_test_act_bf16_class(TestAbs)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册