未验证 提交 f3448977 编写于 作者: Y yangjianfengo1 提交者: GitHub

[AMP OP&Test] arange op support fp16/bf16 (#51106)

* AMP arange & Test

* fix arange bfloat16 dtype

* update for review

* update for review2

* fix tile

* update

* fix ci

* r

* f

* fix windows ci

* update bfloat data

* fix bloat16 input

* add print

* Update test_where_op.py

* update kernel

* del repeat

* update review
上级 2727dddb
......@@ -15,6 +15,9 @@
#include "paddle/phi/kernels/arange_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
......@@ -23,9 +26,11 @@
namespace phi {
template <typename T>
__global__ void Range(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
template <typename T, typename OUT_TYPE>
__global__ void Range(T start, T step, int64_t size, OUT_TYPE* out) {
CUDA_KERNEL_LOOP(index, size) {
out[index] = static_cast<OUT_TYPE>(start + step * index);
}
}
template <typename T, typename Context>
......@@ -34,9 +39,11 @@ void ArangeKernel(const Context& dev_ctx,
const DenseTensor& end,
const DenseTensor& step,
DenseTensor* out) {
T start_value = GetValue<T, Context>(dev_ctx, start);
T end_value = GetValue<T, Context>(dev_ctx, end);
T step_value = GetValue<T, Context>(dev_ctx, step);
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType start_value =
static_cast<MPType>(GetValue<T, Context>(dev_ctx, start));
MPType end_value = static_cast<MPType>(GetValue<T, Context>(dev_ctx, end));
MPType step_value = static_cast<MPType>(GetValue<T, Context>(dev_ctx, step));
int64_t size = 0;
phi::funcs::GetSize(start_value, end_value, step_value, &size);
......@@ -49,7 +56,8 @@ void ArangeKernel(const Context& dev_ctx,
return;
}
int64_t grid = (size + block - 1) / block;
Range<T><<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
Range<MPType, T>
<<<grid, block, 0, stream>>>(start_value, step_value, size, out_data);
}
template <typename T, typename Context>
......@@ -78,8 +86,16 @@ template decltype(ArangeNullaryKernel<int, phi::GPUContext>)
ArangeNullaryKernel;
} // namespace phi
PD_REGISTER_KERNEL(
arange, GPU, ALL_LAYOUT, phi::ArangeKernel, float, double, int64_t, int) {
PD_REGISTER_KERNEL(arange,
GPU,
ALL_LAYOUT,
phi::ArangeKernel,
float,
double,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle.fluid import core
......@@ -58,6 +58,50 @@ class TestFloatArangeOp(TestArangeOp):
self.case = (0, 5, 1)
class TestFloa16ArangeOp(TestArangeOp):
def init_config(self):
self.dtype = np.float16
self.python_api = paddle.arange
self.case = (0, 5, 1)
def test_check_output(self):
self.check_output()
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestBFloat16ArangeOp(OpTest):
def setUp(self):
self.op_type = "range"
self.init_config()
self.inputs = {
'Start': convert_float_to_uint16(self.start),
'End': convert_float_to_uint16(self.end),
'Step': convert_float_to_uint16(self.step),
}
self.outputs = {
'Out': convert_float_to_uint16(
np.arange(self.start, self.end, self.step)
)
}
def init_config(self):
self.dtype = np.uint16
self.python_api = arange_wrapper
self.case = (0, 5, 1)
self.start = np.array([self.case[0]]).astype(np.float32)
self.end = np.array([self.case[1]]).astype(np.float32)
self.step = np.array([self.case[2]]).astype(np.float32)
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
class TestInt32ArangeOp(TestArangeOp):
def init_config(self):
self.dtype = np.int32
......
......@@ -1233,7 +1233,7 @@ def arange(start=0, end=None, step=1, dtype=None, name=None):
check_dtype(
dtype,
'dtype',
['float32', 'float64', 'int32', 'int64'],
['float32', 'float64', 'int32', 'int64', 'float16', 'uint16'],
'range/arange',
)
helper = LayerHelper('range', **locals())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册