Unverified commit 93016331 authored by zhangbo9674, committed by GitHub

[bf16] add bf16 kernel: elementwise_max (#39461)

* add elementwise_max & unittest

* refine cuda register and unittest

* refine unittest

* refine unittest for bf16

* refine optest

* refine code

* refine unittest

* refine unittest
Parent c984cd85
@@ -124,13 +124,17 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMaxKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_CPU_KERNEL(
elementwise_max_grad,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMaxGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>);
REGISTER_OP_VERSION(elementwise_max)
.AddCheckpoint(
......
@@ -69,6 +69,8 @@ REGISTER_OP_CUDA_KERNEL(
elementwise_max,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMaxKernel<paddle::platform::CUDADeviceContext, int>,
@@ -77,6 +79,8 @@ REGISTER_OP_CUDA_KERNEL(
elementwise_max_grad,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::bfloat16>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMaxGradKernel<paddle::platform::CUDADeviceContext, int>,
......
@@ -168,8 +168,10 @@ def get_numeric_gradient(place,
elif tensor_to_check._dtype() == core.VarDesc.VarType.BF16:
numpy_tensor = np.array(tensor).astype(np.uint16)
numpy_tensor = numpy_tensor.flatten()
return struct.unpack('<f', struct.pack('<I', numpy_tensor[i]
<< 16))[0]
return struct.unpack('<f',
struct.pack('<I',
np.uint32(numpy_tensor[i])
<< np.uint32(16)))[0]
elif tensor_to_check_dtype == np.float32:
return tensor._get_float_element(i)
elif tensor_to_check_dtype == np.float64:
@@ -272,7 +274,7 @@ def convert_float_to_uint16(float_list, data_format="NCHW"):
def convert_uint16_to_float(in_list):
in_list = np.asarray(in_list)
out = np.vectorize(
lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
lambda x: struct.unpack('<f', struct.pack('<I', np.uint32(x) << np.uint32(16)))[0],
otypes=[np.float32])(in_list.flat)
return np.reshape(out, in_list.shape)
......
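A note on the op_test.py changes above: a bfloat16 value is stored as the upper 16 bits of an IEEE-754 float32, so the test harness rebuilds a float by left-shifting the stored uint16 bits by 16. The explicit np.uint32 casts added by this patch keep that shift in 32-bit arithmetic rather than relying on NumPy's promotion rules for a 16-bit operand. The sketch below is illustrative only (the helper names are not from the patch); it mirrors the round trip used by convert_uint16_to_float and get_numeric_gradient, with a simplified truncating encoder (Paddle's convert_float_to_uint16 may round rather than truncate).

```python
import struct

import numpy as np


def bf16_bits_to_float32(bits):
    # bfloat16 is the upper half of an IEEE-754 float32 bit pattern, so
    # padding 16 zero bits on the right reconstructs the float. Casting to
    # uint32 first guarantees the shift happens in 32-bit arithmetic.
    return struct.unpack('<f',
                         struct.pack('<I',
                                     np.uint32(bits) << np.uint32(16)))[0]


def float32_to_bf16_bits(value):
    # Truncating encoder: keep only the upper 16 bits of the float32 pattern.
    (u32, ) = struct.unpack('<I', struct.pack('<f', np.float32(value)))
    return np.uint16(u32 >> 16)


# Round trip on a value that bfloat16 represents exactly.
bits = float32_to_bf16_bits(1.5)
assert bf16_bits_to_float32(bits) == 1.5
```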
@@ -16,7 +16,10 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from op_test import OpTest, skip_check_grad_ci, convert_float_to_uint16
import os
import re
import paddle.fluid.core as core
class TestElementwiseOp(OpTest):
@@ -46,6 +49,38 @@ class TestElementwiseOp(OpTest):
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
@unittest.skipIf(
core.is_compiled_with_cuda() and core.cudnn_version() < 8100,
"run test when gpu is availble and the minimum cudnn version is 8.1.0.")
class TestElementwiseBF16Op(OpTest):
def setUp(self):
self.op_type = "elementwise_max"
self.dtype = np.uint16
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
sgn = np.random.choice([-1, 1], [13, 17]).astype(np.float32)
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
self.inputs = {
'X': convert_float_to_uint16(x),
'Y': convert_float_to_uint16(y)
}
self.outputs = {'Out': convert_float_to_uint16(np.maximum(x, y))}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X', 'Y'], 'Out')
def test_check_grad_ingore_x(self):
self.check_grad(['Y'], 'Out', no_grad_set=set("X"))
def test_check_grad_ingore_y(self):
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
@skip_check_grad_ci(
reason="[skip shape check] Use y_shape(1) to test broadcast.")
class TestElementwiseMaxOp_scalar(TestElementwiseOp):
......
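As the comment in TestElementwiseBF16Op.setUp explains, max(x, y) is not differentiable where x == y, so the test adds a sign-randomized offset of at least 0.1 to x when building y, guaranteeing the inputs never tie. The plain-NumPy sketch below (illustrative only, not part of the patch) reproduces that construction and checks the analytic gradients the numeric gradient checker should agree with: the upstream gradient flows entirely to whichever input wins the max.

```python
import numpy as np

# Same construction as TestElementwiseBF16Op.setUp: |x - y| >= 0.1 everywhere,
# so np.maximum(x, y) is differentiable at every sampled point.
x = np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
sgn = np.random.choice([-1, 1], [13, 17]).astype(np.float32)
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype(np.float32)
assert np.abs(x - y).min() > 0.09  # inputs never tie (up to float32 rounding)

# Analytic gradients of out = maximum(x, y): with no ties, each element of
# the upstream gradient goes to exactly one of the two inputs.
dout = np.ones_like(x)
dx = dout * (x > y).astype(np.float32)
dy = dout * (y > x).astype(np.float32)
assert np.allclose(dx + dy, dout)
```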