未验证 提交 d8407c51 编写于 作者: L LoneRanger 提交者: GitHub

add fp16 and bf16 for trunc (#53876)

上级 a862debf
......@@ -52,4 +52,6 @@ PD_REGISTER_KERNEL(trunc_grad,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -17,6 +17,7 @@
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
......@@ -27,7 +28,10 @@ template <typename T>
class TruncFunctor {
public:
__device__ TruncFunctor(const T x) : x_(x) {}
__device__ T operator()() { return trunc(x_); }
__device__ T operator()() {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
return static_cast<T>(trunc(static_cast<MPType>(x_)));
}
public:
const T x_;
......@@ -78,5 +82,13 @@ void TruncKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
trunc, GPU, ALL_LAYOUT, phi::TruncKernel, float, double, int, int64_t) {}
// Register phi::TruncKernel as the `trunc` kernel for the GPU backend with
// ALL_LAYOUT, for each element type listed below (float, double, int,
// int64_t, phi::dtype::float16 and phi::dtype::bfloat16).
PD_REGISTER_KERNEL(trunc,
GPU,
ALL_LAYOUT,
phi::TruncKernel,
float,
double,
int,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,9 +15,10 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
paddle.enable_static()
......@@ -90,5 +91,35 @@ class TestTruncAPI(unittest.TestCase):
self.assertRaises(TypeError, paddle.trunc, x)
class TestTruncFP16OP(TestTruncOp):
    """Variant of ``TestTruncOp`` whose dtype hook selects ``np.float16``."""

    def init_dtype_type(self):
        """Set ``self.dtype`` to half precision.

        NOTE(review): presumably the base class calls this hook while
        building its inputs; confirm against ``TestTruncOp``.
        """
        self.dtype = np.float16
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or does not support bfloat16",
)
class TestTruncBF16OP(OpTest):
    """Check the ``trunc`` operator on bfloat16 data on ``CUDAPlace(0)``.

    The input and the expected output are float32 arrays passed through
    ``convert_float_to_uint16``, and ``self.dtype`` is ``np.uint16``.
    The whole case is skipped when the build lacks CUDA or the device
    does not report bfloat16 support.
    """

    def setUp(self):
        """Build a seeded 20x20 input and its ``np.trunc`` reference."""
        self.python_api = paddle.trunc
        self.op_type = "trunc"
        self.dtype = np.uint16
        # A fixed seed keeps the generated input reproducible across runs.
        np.random.seed(2021)
        x = np.random.random((20, 20)).astype("float32")
        out = np.trunc(x)
        self.inputs = {'X': convert_float_to_uint16(x)}
        self.outputs = {'Out': convert_float_to_uint16(out)}

    def test_check_output(self):
        """Run the output check on ``CUDAPlace(0)``."""
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        """Run the gradient check of ``Out`` w.r.t. ``X`` on ``CUDAPlace(0)``."""
        place = core.CUDAPlace(0)
        self.check_grad_with_place(
            place, ['X'], 'Out', numeric_grad_delta=1e-5
        )
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册