未验证 提交 ea590ef6 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

[dtype] add fp16 support for dist_kernel (#56184)

* [dtype] add fp16 support for dist_kernel

* fix typo

* fix CE

* fix CE

* fix CE

* fix CE

* fix CE

* refactor

* fix CE

* fix CE

* fix varname

* add bf16

* add ut for bf16

* fix CE
上级 ac44d798
......@@ -98,6 +98,12 @@ PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
PD_REGISTER_KERNEL(dist_grad,
GPU,
ALL_LAYOUT,
phi::DistGradKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
#endif
......@@ -12,9 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/dist_kernel.h"
#include <algorithm>
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
......@@ -24,47 +27,53 @@ namespace phi {
#define FULL_MASK 0xffffffff
template <typename T>
template <typename Tx, typename Ty = Tx>
struct ZeroOrderFunctor {
public:
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>((x - y) != 0);
HOSTDEVICE explicit inline ZeroOrderFunctor() {}
HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(x != y);
}
};
template <typename T>
template <typename Tx, typename Ty = Tx>
struct OtherOrderFunctor {
explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>(pow(abs(x - y), p_order_));
HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& p_order)
: p_order_(p_order) {}
HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(
pow(abs(static_cast<Ty>(x) - static_cast<Ty>(y)), p_order_));
}
private:
T p_order_;
Ty p_order_;
};
template <typename T>
template <typename Tx, typename Ty = Tx>
struct PowFunctor {
explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(pow(x, p_order_));
HOSTDEVICE explicit inline PowFunctor(const Ty& p_order)
: p_order_(p_order) {}
HOSTDEVICE inline Tx operator()(const Tx x) const {
return static_cast<Tx>(pow(static_cast<Ty>(x), p_order_));
}
T p_order_;
Ty p_order_;
};
template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
T sum_val = 0;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT sum_val(0.0);
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
sum_val += func(x[i], y[i]);
}
__syncthreads();
sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = sum_val;
out[blockIdx.x] = static_cast<T>(sum_val);
}
}
......@@ -73,16 +82,17 @@ __global__ void ReduceMaxWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T max_val = -1e10f;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT max_val = std::numeric_limits<MT>::min();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
max_val = max(max_val, abs(x[i] - y[i]));
max_val = max(max_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
}
__syncthreads();
max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
max_val = phi::funcs::BlockReduceMax<MT>(max_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = max_val;
out[blockIdx.x] = static_cast<T>(max_val);
}
}
......@@ -91,16 +101,17 @@ __global__ void ReduceMinWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T min_val = 1e10f;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT min_val = std::numeric_limits<MT>::max();
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
min_val = min(min_val, abs(x[i] - y[i]));
min_val = min(min_val, abs(static_cast<MT>(x[i]) - static_cast<MT>(y[i])));
}
__syncthreads();
min_val = phi::funcs::BlockReduceMin(min_val, FULL_MASK);
min_val = phi::funcs::BlockReduceMin<MT>(min_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = min_val;
out[blockIdx.x] = static_cast<T>(min_val);
}
}
......@@ -110,6 +121,7 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
......@@ -130,10 +142,9 @@ void DistKernel(const Context& dev_ctx,
if (p == 0) {
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T, MT>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
} else if (p == INFINITY) {
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
......@@ -150,19 +161,19 @@ void DistKernel(const Context& dev_ctx,
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
} else {
T p_order = static_cast<T>(p);
MT p_order = static_cast<MT>(p);
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out};
T p_order_ = static_cast<T>(1. / p_order);
MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
}
} else {
......@@ -173,4 +184,11 @@ void DistKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
PD_REGISTER_KERNEL(dist,
GPU,
ALL_LAYOUT,
phi::DistKernel,
float,
double,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -670,8 +670,8 @@ def dist(x, y, p=2, name=None):
||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}
Args:
x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
x (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is bfloat16, float16, float32 or float64.
p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -701,8 +701,12 @@ def dist(x, y, p=2, name=None):
if in_dynamic_mode():
return _C_ops.dist(x, y, p)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(
x, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
)
check_variable_and_dtype(
y, 'dtype', ['bfloat16', 'float16', 'float32', 'float64'], 'dist'
)
check_type(p, 'p', (float, int), 'dist')
helper = LayerHelper("dist", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
......@@ -158,6 +158,86 @@ class TestDistOpCase5(TestDistOp):
self.p = 1.5
class TestDistBF16Op(OpTest):
def init_data_type(self):
self.data_type = 'bfloat16'
class TestDistBF16OpCase1(TestDistBF16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0
class TestDistBF16OpCase2(TestDistBF16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0
class TestDistBF16OpCase3(TestDistBF16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")
class TestDistBF16OpCase4(TestDistBF16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")
class TestDistBF16OpCase5(TestDistBF16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5
class TestDistFP16Op(OpTest):
def init_data_type(self):
self.data_type = 'float16'
class TestDistFP16OpCase1(TestDistFP16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0
class TestDistFP16OpCase2(TestDistFP16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0
class TestDistFP16OpCase3(TestDistFP16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")
class TestDistFP16OpCase4(TestDistFP16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")
class TestDistFP16OpCase5(TestDistFP16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5
class TestDistAPI(unittest.TestCase):
def init_data_type(self):
self.data_type = (
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册