Unverified commit 5b09dd56, authored by Thomas Young, committed by GitHub

[AMP OP&Test] add bf16 fp16 type support for expand_v2_op and top_k_v2_op (#51263)

Parent a6ae1e35
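For context, a hedged sketch of what this change enables at the Python API level: paddle.expand and paddle.topk can run directly on float16 tensors (and on bfloat16 where the GPU supports it) instead of requiring a cast to float32 first. The snippet below is illustrative only and is not part of the commit.

```python
import paddle

# Sketch only: exercises the fp16 path added by this PR; bf16 follows the same
# route via x.astype("bfloat16") on GPUs with bfloat16 support.
if paddle.is_compiled_with_cuda():
    paddle.set_device("gpu")
    x = paddle.rand([10, 20], dtype="float32").astype("float16")
    expanded = paddle.expand(x, shape=[2, 10, 20])         # fp16 in, fp16 out
    values, indices = paddle.topk(x, k=5, axis=-1, largest=True)
    print(expanded.dtype, values.dtype, indices.dtype)     # paddle.float16, paddle.float16, paddle.int64
```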
@@ -26,6 +26,7 @@ limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h" #include "paddle/phi/common/float16.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h" #include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
...@@ -49,6 +50,10 @@ namespace detail { ...@@ -49,6 +50,10 @@ namespace detail {
template <> template <>
struct radix_key_codec_base<phi::dtype::float16> struct radix_key_codec_base<phi::dtype::float16>
: radix_key_codec_integral<phi::dtype::float16, uint16_t> {}; : radix_key_codec_integral<phi::dtype::float16, uint16_t> {};
template <>
struct radix_key_codec_base<phi::dtype::bfloat16>
: radix_key_codec_integral<phi::dtype::bfloat16, uint16_t> {};
}  // namespace detail
}  // namespace rocprim
namespace cub = hipcub;
@@ -58,6 +63,12 @@ namespace cub {
template <>
struct NumericTraits<phi::dtype::float16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::float16> {};
template <>
struct NumericTraits<phi::dtype::bfloat16>
    : BaseTraits<FLOATING_POINT, true, false, uint16_t, phi::dtype::bfloat16> {
};
}  // namespace cub
#endif
@@ -586,6 +597,24 @@ struct RadixTypeConfig<phi::dtype::float16> {
  }
};
template <>
struct RadixTypeConfig<phi::dtype::bfloat16> {
  typedef uint32_t RadixType;

  static inline __device__ RadixType Convert(phi::dtype::bfloat16 v) {
    RadixType x = v.x;
    RadixType mask = (x & 0x00008000) ? 0x0000ffff : 0x00008000;
    return (v == v) ? (x ^ mask) : 0xffff;
  }

  static inline __device__ phi::dtype::bfloat16 Deconvert(RadixType v) {
    RadixType mask = (v & 0x00008000) ? 0x00008000 : 0x0000ffff;

    phi::dtype::bfloat16 r;
    r.x = (v ^ mask);
    return r;
  }
};
/*---------------------------Helper Functions------------------*/
__device__ __forceinline__ int GetLaneId() {
  int lane_id;
......
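The RadixTypeConfig<phi::dtype::bfloat16> specialization above uses the standard order-preserving trick for radix-selecting floating-point keys: flip every bit of negative values, flip only the sign bit of non-negative values, and pin NaN to the largest key, so that unsigned comparison of the converted keys matches floating-point ordering. Below is a small Python sketch (illustrative only, not Paddle code) of the same Convert/Deconvert round trip.

```python
import numpy as np

def bf16_bits(f):
    # Take the upper 16 bits of the float32 pattern as the bf16 bits
    # (these sample values are exactly representable, so rounding is moot).
    return int(np.float32(f).view(np.uint32) >> 16)

def convert(bits, is_nan):
    # Negative: flip all 16 bits; non-negative: flip the sign bit; NaN: max key.
    if is_nan:
        return 0xFFFF
    mask = 0xFFFF if (bits & 0x8000) else 0x8000
    return bits ^ mask

def deconvert(key):
    # Inverse mapping: bit 15 set means the original value was non-negative.
    mask = 0x8000 if (key & 0x8000) else 0xFFFF
    return key ^ mask

vals = [-3.0, -0.5, 0.0, 0.5, 3.0]
keys = [convert(bf16_bits(v), False) for v in vals]
assert keys == sorted(keys)                                   # key order == value order
assert all(deconvert(k) == bf16_bits(v) for k, v in zip(keys, vals))
```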
@@ -13,8 +13,8 @@
// limitations under the License.
#include "paddle/phi/kernels/top_k_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h" #include "paddle/phi/kernels/funcs/top_k_function_cuda.h"
...@@ -89,4 +89,5 @@ PD_REGISTER_KERNEL(topk_grad, ...@@ -89,4 +89,5 @@ PD_REGISTER_KERNEL(topk_grad,
double, double,
int, int,
int64_t, int64_t,
phi::dtype::float16) {} phi::dtype::float16,
phi::dtype::bfloat16) {}
@@ -15,11 +15,13 @@
#include "paddle/phi/kernels/top_k_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/top_k_function_cuda.h"

namespace phi {
#define FIXED_BLOCK_DIM_BASE(dim, ...) \
@@ -348,6 +350,7 @@ PD_REGISTER_KERNEL(topk,
                   double,
                   int,
                   int64_t,
                   phi::dtype::float16,
                   phi::dtype::bfloat16) {
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
@@ -17,7 +17,7 @@ import unittest
import gradient_checker
import numpy as np
from decorator_helper import prog_scope
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle import fluid
@@ -202,6 +202,56 @@ class TestExpandV2OpInt64_t(OpTest):
        self.check_output()
# Situation 7: input x is Float16
class TestExpandV2FP16Op(OpTest):
    def setUp(self):
        self.op_type = "expand_v2"
        self.prim_op_type = "prim"
        self.dtype = np.float16
        self.python_api = paddle.expand
        self.public_python_api = paddle.expand
        self.inputs = {
            'X': np.random.randint(10, size=(8, 8, 5)).astype(self.dtype)
        }
        self.attrs = {'shape': [8, 8, 5]}
        output = np.tile(self.inputs['X'], (1, 1, 1))
        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_prim=True)


# Situation 8: input x is BF16
@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestExpandV2BF16Op(OpTest):
    def setUp(self):
        self.op_type = "expand_v2"
        self.prim_op_type = "prim"
        self.dtype = np.uint16
        self.python_api = paddle.expand
        self.public_python_api = paddle.expand
        x = np.random.randint(10, size=(8, 8, 5)).astype(np.float32)
        self.inputs = {'X': convert_float_to_uint16(x)}
        self.attrs = {'shape': [8, 8, 5]}
        output = np.tile(x, (1, 1, 1)).astype(np.float32)
        self.outputs = {'Out': convert_float_to_uint16(output)}

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
class TestExpandV2Error(unittest.TestCase):
    def test_errors(self):
        with program_guard(Program(), Program()):
@@ -338,7 +388,7 @@ class TestExpandTripleGradCheck(unittest.TestCase):
            self.func(p)


# Situation 9: comp case, shape is a list(without tensor)
class TestExpandV2CompOpRank1(OpTest):
    def setUp(self):
        self.op_type = "expand_v2"
@@ -392,7 +442,7 @@ class TestExpandV2CompOpRank4(TestExpandV2CompOpRank1):
        self.expand_times = (1, 1, 1, 1)


# Situation 10: comp case, input x is Integer
class TestExpandV2CompOpInteger(OpTest):
    def setUp(self):
        self.op_type = "expand_v2"
@@ -410,7 +460,7 @@ class TestExpandV2CompOpInteger(OpTest):
        self.check_output(check_prim=True)


# Situation 11: comp case, input x is Bool
class TestExpandV2CompOpBoolean(OpTest):
    def setUp(self):
        self.op_type = "expand_v2"
@@ -426,7 +476,7 @@ class TestExpandV2CompOpBoolean(OpTest):
        self.check_output(check_prim=True)


# Situation 12: comp case, input x is Integer
class TestExpandV2CompOpInt64_t(OpTest):
    def setUp(self):
        self.op_type = "expand_v2"
......
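A note on the BF16 test above: OpTest has no native bfloat16 dtype, so TestExpandV2BF16Op carries its data as np.uint16 arrays produced by convert_float_to_uint16. The sketch below shows only the general idea of that representation (an assumption about the helper: it may round to nearest even rather than truncate; the point is that each uint16 holds the upper 16 bits of a float32 bit pattern).

```python
import numpy as np

def float32_to_bf16_bits(x):
    # Keep the upper 16 bits of each float32 bit pattern; the real
    # convert_float_to_uint16 helper may additionally round to nearest even.
    return (np.asarray(x, dtype=np.float32).view(np.uint32) >> 16).astype(np.uint16)

x = np.random.randint(10, size=(8, 8, 5)).astype(np.float32)
x_bf16 = float32_to_bf16_bits(x)    # analogous to what the test feeds as inputs['X']
print(x_bf16.dtype, x_bf16.shape)   # uint16 (8, 8, 5)
```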
@@ -15,7 +15,7 @@
import unittest

import numpy as np
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core
@@ -189,6 +189,56 @@ class TestTopkOp7(TestTopkOp):
        self.outputs = {'Out': output, 'Indices': indices}
class TestTopkFP16Op(TestTopkOp):
    def setUp(self):
        self.op_type = "top_k_v2"
        self.python_api = paddle.topk
        self.public_python_api = paddle.topk
        self.dtype = np.float16
        self.prim_op_type = "prim"
        self.input_data = np.random.rand(10, 20).astype(self.dtype)
        self.init_args()
        self.inputs = {'X': self.input_data}
        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
        output, indices = numpy_topk(
            self.input_data, axis=self.axis, k=self.k, largest=self.largest
        )
        self.outputs = {'Out': output, 'Indices': indices}


@unittest.skipIf(
    not core.is_compiled_with_cuda()
    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
    "core is not compiled with CUDA or not support the bfloat16",
)
class TestTopkBF16Op(TestTopkOp):
    def setUp(self):
        self.op_type = "top_k_v2"
        self.python_api = paddle.topk
        self.public_python_api = paddle.topk
        self.dtype = np.uint16
        self.prim_op_type = "prim"
        self.input_data = np.random.rand(10, 20).astype(np.float32)
        self.init_args()
        self.inputs = {'X': convert_float_to_uint16(self.input_data)}
        self.attrs = {'k': self.k, 'axis': self.axis, 'largest': self.largest}
        output, indices = numpy_topk(
            self.input_data, axis=self.axis, k=self.k, largest=self.largest
        )
        self.outputs = {
            'Out': convert_float_to_uint16(output),
            'Indices': indices,
        }

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place, check_eager=True)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, {'X'}, 'Out', check_eager=True)


class TestTopKAPI(unittest.TestCase):
    def setUp(self):
        np.random.seed(123)
......
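The FP16/BF16 tests above build their expected outputs with the test file's numpy_topk helper, which is defined earlier in the file and lies outside this diff. For readability only, here is a hedged reference sketch of how such a helper typically behaves; the real implementation may differ in details such as tie-breaking.

```python
import numpy as np

def numpy_topk(x, k=1, axis=-1, largest=True):
    # Sort indices along `axis` (descending when largest=True), keep the first k,
    # then gather the corresponding values.
    order = np.argsort(-x if largest else x, axis=axis, kind="stable")
    indices = np.take(order, np.arange(k), axis=axis)
    values = np.take_along_axis(x, indices, axis=axis)
    return values, indices

vals, idx = numpy_topk(np.random.rand(10, 20).astype(np.float32), k=5, axis=-1)
print(vals.shape, idx.shape)  # (10, 5) (10, 5)
```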
@@ -3418,7 +3418,15 @@ def expand(x, shape, name=None):
        check_variable_and_dtype(
            x,
            'x',
            [
                'bool',
                'float16',
                'float32',
                'float64',
                'int32',
                'int64',
                'uint16',
            ],
            'expand',
        )
        check_type(shape, 'shape', (list, tuple, Variable), 'expand')
......
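On why 'uint16' is added to the expand dtype whitelist above: in Paddle's static-graph dtype checks, bfloat16 variables report their dtype string as 'uint16', so without that entry check_variable_and_dtype would reject bf16 inputs to paddle.expand. The sketch below mimics only the effect of the whitelist; the function name and message are hypothetical, not Paddle internals.

```python
# Illustrative only: emulates the whitelist behaviour extended in the diff above.
ALLOWED_EXPAND_DTYPES = [
    'bool', 'float16', 'float32', 'float64', 'int32', 'int64', 'uint16'
]

def check_expand_dtype(dtype_str):
    # Raise the same kind of error a rejected dtype would trigger in the API.
    if dtype_str not in ALLOWED_EXPAND_DTYPES:
        raise TypeError(
            f"The data type of 'x' in expand must be one of "
            f"{ALLOWED_EXPAND_DTYPES}, but received {dtype_str}."
        )

check_expand_dtype('uint16')   # bfloat16 inputs pass only with the new entry
```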