Unverified commit 6310419b authored by Hui Zhang, committed by GitHub

floor div support int8/int16/int32/int64/uint8/float32/float64/bfloat16/float16 (#53854)

* floor div support float/double/bfloat16/float16

* add ut

* fix bug

* fix fft.ifftshift for floor_divide upgrade

* fix comment

* fix bugs

* fix bug
Parent d7ad0e42
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <type_traits>
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/hostdevice.h"
namespace phi {
// Returns false since we cannot have x < 0 if x is unsigned.
template <typename T>
static inline constexpr bool is_negative(const T& x, std::true_type) {
return false;
}
// Returns true if a signed variable x < 0
template <typename T>
static inline constexpr bool is_negative(const T& x, std::false_type) {
return x < T(0);
}
// Returns true if x < 0
template <typename T>
inline constexpr bool is_negative(const T& x) {
return is_negative(x, std::is_unsigned<T>());
}
// Note: Explicit implementations of copysign for float16 and bfloat16
// are needed to work around a g++-7/8 crash on aarch64; they also make
// copysign faster for the half-precision types
template <typename T, typename U>
inline HOSTDEVICE auto copysign(const T& a, const U& b) {
return std::copysign(a, b);
}
// Implement copysign for half precision floats using bit ops
// Sign is the most significant bit for both float16 and bfloat16 types
inline HOSTDEVICE phi::dtype::float16 copysign(phi::dtype::float16 a,
phi::dtype::float16 b) {
return phi::dtype::raw_uint16_to_float16((a.x & 0x7fff) | (b.x & 0x8000));
}
inline HOSTDEVICE phi::dtype::bfloat16 copysign(phi::dtype::bfloat16 a,
phi::dtype::bfloat16 b) {
return phi::dtype::raw_uint16_to_bfloat16((a.x & 0x7fff) | (b.x & 0x8000));
}
} // namespace phi
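The copysign specializations above work because, in both float16 and bfloat16, the sign occupies the most significant of the 16 bits. A quick numpy sketch (an illustration for review, not part of the patch) checking the same masking on float16 bit patterns:

```python
import numpy as np

def copysign_f16(a, b):
    # (bits(a) & 0x7fff) keeps the magnitude, (bits(b) & 0x8000) the sign.
    ua = np.float16(a).view(np.uint16)
    ub = np.float16(b).view(np.uint16)
    return np.uint16((ua & 0x7FFF) | (ub & 0x8000)).view(np.float16)

assert copysign_f16(2.5, -1.0) == np.float16(-2.5)
assert copysign_f16(-0.0, 3.0) == np.float16(0.0)
```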
@@ -118,8 +118,19 @@ PD_REGISTER_KERNEL(remainder,
double,
int,
int64_t) {}
PD_REGISTER_KERNEL(
floor_divide, CPU, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(floor_divide,
CPU,
ALL_LAYOUT,
phi::FloorDivideKernel,
uint8_t,
int8_t,
int16_t,
int32_t,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(elementwise_pow,
CPU,
ALL_LAYOUT,
@@ -129,7 +140,6 @@ PD_REGISTER_KERNEL(elementwise_pow,
int,
int64_t,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(heaviside,
CPU,
ALL_LAYOUT,
......
@@ -26,6 +26,7 @@ limitations under the License. */
#include "xpu/kernel/math_xpu2.h" // pow()
#endif
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/type_safe_sign_math.h"
namespace phi {
namespace funcs {
@@ -604,26 +605,233 @@ struct ElementwiseHeavisideFunctor {
}
};
template <typename T>
template <typename T, typename Enable = void>
struct FloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
#ifndef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE(b != 0, DIV_ERROR_INFO);
#endif
if (phi::is_negative(a) != phi::is_negative(b)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
// the division is nonzero
const auto quot = a / b;
const auto rem = a % b;
auto ret = rem ? quot - 1 : quot;
return static_cast<T>(ret);
}
return static_cast<T>(a / b);
}
};
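The integer functor builds floor division out of C++'s truncating `/` and `%`: when the operands' signs differ and the remainder is nonzero, truncation rounded toward zero, one higher than the floor, so the quotient is decremented. A hedged Python mirror of that logic (a sketch, not the kernel):

```python
def floor_div_int(a: int, b: int) -> int:
    # Truncating quotient and remainder, matching C++ `a / b` and `a % b`.
    q = abs(a) // abs(b)
    if (a < 0) != (b < 0):
        q = -q
    r = a - q * b
    # Signs differ and remainder is nonzero: truncation overshot by one.
    if (a < 0) != (b < 0) and r != 0:
        q -= 1
    return q

assert floor_div_int(-7, 3) == -3   # truncation gives -2, floor is -3
assert floor_div_int(7, 3) == 2     # same signs: trunc == floor
assert floor_div_int(-6, 3) == -2   # exact division: no adjustment
```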
template <typename T>
struct FloorDivideFunctor<
T,
typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (UNLIKELY(b == 0)) {
// Divide by zero: return standard IEEE result
return static_cast<T>(a / b);
}
auto mod = std::fmod(a, b);
auto div = (a - mod) / b;
if ((mod != 0) && (b < 0) != (mod < 0)) {
div -= T(1);
}
T floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > T(0.5)) {
floordiv += T(1.0);
}
} else {
floordiv = phi::copysign(T(0), a / b);
}
return floordiv;
}
};
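The float path computes `div = (a - fmod(a, b)) / b`, shifts it down by one when the remainder's sign disagrees with `b` (so the remainder convention matches floor), then snaps to an integer, with a `copysign` on zero so that a zero quotient keeps the sign of the true ratio. A numpy sketch of the same algorithm (a reference for reading the kernel, not the kernel itself):

```python
import numpy as np

def floor_div_float(a, b):
    a, b = np.float64(a), np.float64(b)
    if b == 0.0:
        with np.errstate(divide='ignore', invalid='ignore'):
            return a / b                    # IEEE inf / -inf / nan
    mod = np.fmod(a, b)
    div = (a - mod) / b
    if mod != 0.0 and (b < 0.0) != (mod < 0.0):
        div -= 1.0
    if div != 0.0:
        out = np.floor(div)
        if div - out > 0.5:                 # absorb fmod rounding error
            out += 1.0
        return out
    return np.copysign(0.0, a / b)          # signed-zero result

assert floor_div_float(-7.0, 2.0) == -4.0
assert str(floor_div_float(0.0, -4.0)) == '-0.0'
```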
template <>
struct FloorDivideFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float b_float = static_cast<float>(b);
float a_float = static_cast<float>(a);
if (UNLIKELY(b_float == 0)) {
// Divide by zero: return standard IEEE result
return static_cast<dtype::float16>(a_float / b_float);
}
auto mod = std::fmod(a_float, b_float);
auto div = (a_float - mod) / b_float;
if ((mod != 0) && (b_float < 0) != (mod < 0)) {
div -= static_cast<float>(1);
}
float floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > static_cast<float>(0.5)) {
floordiv += static_cast<float>(1.0);
}
} else {
floordiv = phi::copysign(static_cast<float>(0), a_float / b_float);
}
return static_cast<dtype::float16>(floordiv);
}
};
template <>
struct FloorDivideFunctor<dtype::bfloat16> {
inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
const dtype::bfloat16 b) const {
float b_float = static_cast<float>(b);
float a_float = static_cast<float>(a);
if (UNLIKELY(b_float == 0)) {
// Divide by zero: return standard IEEE result
return static_cast<dtype::bfloat16>(a_float / b_float);
}
auto mod = std::fmod(a_float, b_float);
auto div = (a_float - mod) / b_float;
if ((mod != 0) && (b_float < 0) != (mod < 0)) {
div -= static_cast<float>(1);
}
float floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > static_cast<float>(0.5)) {
floordiv += static_cast<float>(1.0);
}
} else {
floordiv = phi::copysign(static_cast<float>(0), a_float / b_float);
}
return static_cast<dtype::bfloat16>(floordiv);
}
};
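Both 16-bit specializations follow the same pattern: widen to float32, run the float algorithm, and cast back, since fmod/floor directly in half precision would compound rounding error. A simplified, self-contained sketch of the widen-then-narrow pattern (it uses a plain `np.floor` over the widened quotient rather than the kernel's fmod-based algorithm):

```python
import numpy as np

def floor_div_f16(a, b):
    # Widen to float32, compute in higher precision, narrow back to 16 bits.
    a32, b32 = np.float32(a), np.float32(b)
    with np.errstate(divide='ignore', invalid='ignore'):
        out = np.floor(np.float64(a32) / np.float64(b32))
    return np.float16(out)

assert floor_div_f16(np.float16(-7), np.float16(2)) == np.float16(-4)
```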
template <typename T, typename Enable = void>
struct InverseFloorDivideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
#ifndef PADDLE_WITH_XPU_KP
PADDLE_ENFORCE(a != 0, DIV_ERROR_INFO);
#endif
if (phi::is_negative(a) != phi::is_negative(b)) {
// Subtracts one from the results of truncation division if the
// divisor and dividend have different sign(bit)s and the remainder of
// the division is nonzero
const auto quot = b / a;
const auto rem = b % a;
auto ret = rem ? quot - 1 : quot;
return static_cast<T>(ret);
}
return static_cast<T>(b / a);
}
};
template <typename T>
struct InverseFloorDivideFunctor<
T,
typename std::enable_if_t<std::is_floating_point<T>::value>> {
inline HOSTDEVICE T operator()(const T a, const T b) const {
if (UNLIKELY(a == 0)) {
// Divide by zero: return standard IEEE result
return static_cast<T>(b / a);
}
auto mod = std::fmod(b, a);
auto div = (b - mod) / a;
if ((mod != 0) && (a < 0) != (mod < 0)) {
div -= T(1);
}
T floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > T(0.5)) {
floordiv += T(1.0);
}
} else {
floordiv = phi::copysign(T(0), b / a);
}
return floordiv;
}
};
template <>
struct InverseFloorDivideFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
    // Note: operands are intentionally swapped so that the shared body
    // below computes floor(b / a).
    float b_float = static_cast<float>(a);
    float a_float = static_cast<float>(b);
if (UNLIKELY(b_float == 0)) {
// Divide by zero: return standard IEEE result
return static_cast<dtype::float16>(a_float / b_float);
}
auto mod = std::fmod(a_float, b_float);
auto div = (a_float - mod) / b_float;
if ((mod != 0) && (b_float < 0) != (mod < 0)) {
div -= static_cast<float>(1);
}
float floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > static_cast<float>(0.5)) {
floordiv += static_cast<float>(1.0);
}
} else {
floordiv = phi::copysign(static_cast<float>(0), a_float / b_float);
}
return static_cast<dtype::float16>(floordiv);
}
};
template <>
struct InverseFloorDivideFunctor<dtype::bfloat16> {
inline HOSTDEVICE dtype::bfloat16 operator()(const dtype::bfloat16 a,
const dtype::bfloat16 b) const {
    // Note: operands are intentionally swapped so that the shared body
    // below computes floor(b / a).
    float b_float = static_cast<float>(a);
    float a_float = static_cast<float>(b);
if (UNLIKELY(b_float == 0)) {
// Divide by zero: return standard IEEE result
return static_cast<dtype::bfloat16>(a_float / b_float);
}
auto mod = std::fmod(a_float, b_float);
auto div = (a_float - mod) / b_float;
if ((mod != 0) && (b_float < 0) != (mod < 0)) {
div -= static_cast<float>(1);
}
float floordiv;
if (div != 0) {
floordiv = std::floor(div);
if (div - floordiv > static_cast<float>(0.5)) {
floordiv += static_cast<float>(1.0);
}
} else {
floordiv = phi::copysign(static_cast<float>(0), a_float / b_float);
}
return static_cast<dtype::bfloat16>(floordiv);
}
};
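`InverseFloorDivideFunctor` is the same computation with the operands swapped, i.e. it yields floor(b / a), which is why it validates `a != 0` rather than `b != 0`. The "standard IEEE result" that the zero-divisor branches fall back to looks like this (a numpy illustration, not Paddle code):

```python
import numpy as np

with np.errstate(divide='ignore', invalid='ignore'):
    print(np.float32(2.0) / np.float32(0.0))    # inf
    print(np.float32(-2.0) / np.float32(0.0))   # -inf
    print(np.float32(0.0) / np.float32(0.0))    # nan
```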
#if defined(__CUDA_ARCH__) || defined(__HIPCC__)
template <typename T, typename MPType>
inline HOSTDEVICE typename std::enable_if<std::is_integral<T>::value, T>::type
......
@@ -119,8 +119,19 @@ PD_REGISTER_KERNEL(remainder,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(
floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(floor_divide,
KPS,
ALL_LAYOUT,
phi::FloorDivideKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(elementwise_pow,
KPS,
ALL_LAYOUT,
......
@@ -134,8 +134,15 @@ PD_REGISTER_KERNEL(floor_divide_raw,
CPU,
ALL_LAYOUT,
phi::FloorDivideRawKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
PD_REGISTER_KERNEL(elementwise_pow_raw,
CPU,
ALL_LAYOUT,
......
@@ -163,8 +163,15 @@ PD_REGISTER_KERNEL(floor_divide_raw,
KPS,
ALL_LAYOUT,
phi::FloorDivideRawKernel,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
int64_t,
float,
double,
float16,
bfloat16) {}
PD_REGISTER_KERNEL(elementwise_pow_raw,
KPS,
ALL_LAYOUT,
......
@@ -1412,11 +1412,11 @@ def ifftshift(x, axes=None, name=None):
# shift all axes
rank = len(x.shape)
axes = list(range(0, rank))
shifts = -shape // 2
shifts = (shape + 1) // 2
elif isinstance(axes, int):
shifts = -shape[axes] // 2
shifts = (shape[axes] + 1) // 2
else:
shifts = paddle.concat([-shape[ax : ax + 1] // 2 for ax in axes])
shifts = paddle.concat([(shape[ax : ax + 1] + 1) // 2 for ax in axes])
return paddle.roll(x, shifts, axes, name=name)
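The ifftshift change is needed because `-shape // 2` only equaled `-(shape // 2)` while `//` truncated toward zero; with true floor division, `-n // 2` changes for odd `n`. Rolling by `(n + 1) // 2` is congruent to rolling by `-(n // 2)` modulo `n`, so the new expression reproduces the old shifts without relying on truncation. A quick plain-numpy check (illustration only):

```python
import numpy as np

for n in range(1, 9):
    x = np.arange(n)
    old = np.roll(x, -(n // 2))       # what the trunc-division era computed
    new = np.roll(x, (n + 1) // 2)    # the post-upgrade expression
    assert (old == new).all()
```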
......
@@ -830,8 +830,8 @@ def floor_divide(x, y, name=None):
Also note that the name ``floor_divide`` can be misleading, as the quotients are actually rounded toward zero, not toward negative infinity.
Args:
x (Tensor): the input tensor, its data type should be int32, int64.
y (Tensor): the input tensor, its data type should be int32, int64.
x (Tensor): the input tensor, its data type should be uint8, int8, int16, int32, int64, float32, float64, float16, bfloat16.
y (Tensor): the input tensor, its data type should be uint8, int8, int16, int32, int64, float32, float64, float16, bfloat16.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
......
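A minimal usage sketch of the widened API (assumes a Paddle build that includes this patch; the printed values follow the floor-division semantics shown above):

```python
import paddle

x = paddle.to_tensor([2.0, -7.0], dtype='float32')
y = paddle.to_tensor([1.5, 2.0], dtype='float32')
# Floating-point inputs are accepted now; expected result: [1., -4.]
print(paddle.floor_divide(x, y))
```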
@@ -14,6 +14,7 @@
import random
import unittest
from contextlib import contextmanager
import numpy as np
from eager_op_test import OpTest, paddle_static_guard
@@ -95,6 +96,13 @@ class TestElementwiseModOpInverse(TestElementwiseModOp):
self.out = np.floor_divide(self.x, self.y)
@contextmanager
def device_guard(device=None):
    old = paddle.get_device()
    paddle.set_device(device)
    try:
        yield
    finally:
        paddle.set_device(old)
class TestFloorDivideOp(unittest.TestCase):
def test_name(self):
with paddle_static_guard():
@@ -106,17 +114,92 @@ class TestFloorDivideOp(unittest.TestCase):
self.assertEqual(('div_res' in y_1.name), True)
def test_dygraph(self):
with fluid.dygraph.guard():
np_x = np.array([2, 3, 8, 7]).astype('int64')
np_y = np.array([1, 5, 3, 3]).astype('int64')
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
paddle.disable_static()
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
for dtype in (
'uint8',
'int8',
'int16',
'int32',
'int64',
'float16',
'float32',
'float64',
):
np_x = np.array([2, 3, 8, 7]).astype(dtype)
np_y = np.array([1, 5, 3, 3]).astype(dtype)
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = paddle.floor_divide(x, y)
np_z = z.numpy()
z_expected = np.floor_divide(np_x, np_y)
self.assertEqual((np_z == z_expected).all(), True)
np_x = np.array([2, 3, 8, 7])
np_y = np.array([1, 5, 3, 3])
x = paddle.to_tensor(np_x, dtype='bfloat16')
y = paddle.to_tensor(np_y, dtype="bfloat16")
z = paddle.floor_divide(x, y)
np_z = z.numpy()
z_expected = np.array([2, 0, 2, 2])
z_expected = np.array([16384, 0, 16384, 16384], dtype='uint16')
self.assertEqual((np_z == z_expected).all(), True)
with fluid.dygraph.guard(fluid.CPUPlace()):
for dtype in (
'int8',
'int16',
'int32',
'int64',
'float16',
'float32',
'float64',
):
np_x = -np.array([2, 3, 8, 7]).astype(dtype)
np_y = np.array([1, 5, 3, 3]).astype(dtype)
x = paddle.to_tensor(np_x)
y = paddle.to_tensor(np_y)
z = paddle.floor_divide(x, y)
np_z = z.numpy()
z_expected = np.floor_divide(np_x, np_y)
self.assertEqual((np_z == z_expected).all(), True)
np_x = -np.array([2, 3, 8, 7])
np_y = np.array([1, 5, 3, 3])
x = paddle.to_tensor(np_x, dtype='bfloat16')
y = paddle.to_tensor(np_y, dtype="bfloat16")
z = paddle.floor_divide(x, y)
np_z = z.numpy()
z_expected = np.array([49152, 49024, 49216, 49216], dtype='uint16')
self.assertEqual((np_z == z_expected).all(), True)
for dtype in ('float32', 'float64', 'float16'):
try:
# divide by zero
np_x = np.array([2])
np_y = np.array([0, 0, 0])
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.to_tensor(np_y, dtype=dtype)
z = paddle.floor_divide(x, y)
np_z = z.numpy()
# [np.inf, np.inf, np.inf]
z_expected = np.floor_divide(np_x, np_y)
self.assertEqual((np_z == z_expected).all(), True)
except Exception as e:
pass
# divide by zero
np_x = np.array([2])
np_y = np.array([0, 0, 0])
x = paddle.to_tensor(np_x, dtype='bfloat16')
y = paddle.to_tensor(np_y, dtype="bfloat16")
z = paddle.floor_divide(x, y)
np_z = z.numpy()
z_expected = np.array([32640, 32640, 32640], dtype='uint16')
self.assertEqual((np_z == z_expected).all(), True)
with device_guard('cpu'):
# divide by zero
np_x = np.array([2, 3, 4])
np_y = np.array([0])
@@ -125,17 +208,20 @@ class TestFloorDivideOp(unittest.TestCase):
try:
z = x // y
except Exception as e:
print("Error: Divide by zero encounter in floor_divide\n")
pass
# divide by zero
np_x = np.array([2])
np_y = np.array([0, 0, 0])
x = paddle.to_tensor(np_x, dtype="int32")
y = paddle.to_tensor(np_y, dtype="int32")
try:
z = x // y
except Exception as e:
print("Error: Divide by zero encounter in floor_divide\n")
for dtype in ("uint8", 'int8', 'int16', 'int32', 'int64'):
np_x = np.array([2])
np_y = np.array([0, 0, 0])
x = paddle.to_tensor(np_x, dtype=dtype)
y = paddle.to_tensor(np_y, dtype=dtype)
try:
z = x // y
except Exception as e:
pass
paddle.enable_static()
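The uint16 constants asserted in the bfloat16 cases above are raw bfloat16 bit patterns (bfloat16 tensors surface as uint16 through numpy). Since bfloat16 keeps the top half of the float32 encoding, each constant can be derived as below (a reading aid, not part of the test):

```python
import numpy as np

def bf16_bits(v):
    # bfloat16 keeps the upper 16 bits of the float32 encoding.
    return int(np.float32(v).view(np.uint32) >> 16)

assert bf16_bits(2.0) == 16384            # 0x4000
assert bf16_bits(-2.0) == 49152           # 0xC000
assert bf16_bits(-1.0) == 49024           # 0xBF80
assert bf16_bits(-3.0) == 49216           # 0xC040
assert bf16_bits(float('inf')) == 32640   # 0x7F80
```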
if __name__ == '__main__':
......