未验证 提交 c3055d23 编写于 作者: C chenxujun 提交者: GitHub

【Hackathon No.60】prelu, clip_by_norm, multi_dot 算子FP16/BF16单测完善 (#52666)

* Add prelu, clip_by_norm, multi_dot tests

* Fix code

* Fix code
上级 534efcb6
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/prelu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
......@@ -135,14 +136,17 @@ void PreluScalarDirectCUDAFunctor<T>::operator()(gpuStream_t stream,
template class PreluChannelWiseDirectCUDAFunctor<float>;
template class PreluChannelWiseDirectCUDAFunctor<platform::float16>;
template class PreluChannelWiseDirectCUDAFunctor<platform::bfloat16>;
template class PreluChannelWiseDirectCUDAFunctor<double>;
template class PreluElementWiseDirectCUDAFunctor<float>;
template class PreluElementWiseDirectCUDAFunctor<platform::float16>;
template class PreluElementWiseDirectCUDAFunctor<platform::bfloat16>;
template class PreluElementWiseDirectCUDAFunctor<double>;
template class PreluScalarDirectCUDAFunctor<float>;
template class PreluScalarDirectCUDAFunctor<platform::float16>;
template class PreluScalarDirectCUDAFunctor<platform::bfloat16>;
template class PreluScalarDirectCUDAFunctor<double>;
} // namespace math
......
......@@ -17,7 +17,7 @@
#include <typeinfo>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
......@@ -34,7 +34,7 @@ void ClipByNormKernel(const Context& dev_ctx,
return ClipByNormFunctor<float, Context>(dev_ctx, in, max_norm, output);
}
auto input = &in;
dev_ctx.template Alloc<dtype::float16>(output);
dev_ctx.template Alloc<T>(output);
PADDLE_ENFORCE_NOT_NULL(input,
phi::errors::InvalidArgument(
......@@ -49,20 +49,14 @@ void ClipByNormKernel(const Context& dev_ctx,
auto* tmp = &tmp_tensor;
tmp->Resize({1});
dev_ctx.template Alloc<float>(tmp);
phi::funcs::ReduceKernel<dtype::float16,
float,
kps::AddFunctor,
kps::SquareFunctor<dtype::float16, float>>(
dev_ctx,
*input,
tmp,
kps::SquareFunctor<dtype::float16, float>(),
reduce_dims);
phi::funcs::
ReduceKernel<T, float, kps::AddFunctor, kps::SquareFunctor<T, float>>(
dev_ctx, *input, tmp, kps::SquareFunctor<T, float>(), reduce_dims);
auto tmp_eigen = phi::EigenVector<float>::Flatten(*tmp);
auto x_norm = tmp_eigen.sqrt();
auto x = phi::EigenVector<dtype::float16>::Flatten(*input);
auto out = phi::EigenVector<dtype::float16>::Flatten(*output);
auto x = phi::EigenVector<T>::Flatten(*input);
auto out = phi::EigenVector<T>::Flatten(*output);
auto* place = dev_ctx.eigen_device();
auto temp = (x_norm <= max_norm).template cast<float>();
......@@ -72,7 +66,7 @@ void ClipByNormKernel(const Context& dev_ctx,
auto scaling =
(temp + (static_cast<float>(1) - temp) * max_norm / (x_norm + epsilon))
.template cast<dtype::float16>();
.template cast<T>();
Eigen::array<int, 1> one_dim{{1}};
Eigen::DSizes<int, 1> m_dsize(input->numel());
......@@ -86,4 +80,5 @@ PD_REGISTER_KERNEL(clip_by_norm,
ALL_LAYOUT,
phi::ClipByNormKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,16 +15,15 @@ limitations under the License. */
#include "paddle/phi/kernels/multi_dot_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h"
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL(multi_dot_grad,
GPU,
ALL_LAYOUT,
phi::MultiDotGradKernel,
float,
double,
float16) {}
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -15,11 +15,15 @@ limitations under the License. */
#include "paddle/phi/kernels/multi_dot_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/multi_dot_kernel_impl.h"
using float16 = phi::dtype::float16;
PD_REGISTER_KERNEL(
multi_dot, GPU, ALL_LAYOUT, phi::MultiDotKernel, float, double, float16) {}
PD_REGISTER_KERNEL(multi_dot,
GPU,
ALL_LAYOUT,
phi::MultiDotKernel,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
......@@ -189,4 +189,5 @@ PD_REGISTER_KERNEL(prelu_grad,
phi::PReluGradKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
double) {}
......@@ -79,4 +79,5 @@ PD_REGISTER_KERNEL(prelu,
phi::PReluKernel,
float,
phi::dtype::float16,
phi::dtype::bfloat16,
double) {}
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from op import Operator
import paddle
......@@ -102,6 +102,48 @@ class TestClipByNormOpFp16Case3(TestClipByNormOpFp16):
self.max_norm = 1.0
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestClipByNormBF16Op(OpTest):
def setUp(self):
self.max_relative_error = 0.006
self.python_api = clip.clip_by_norm
self.init_dtype()
self.initTestCase()
input = np.random.random(self.shape).astype(self.np_dtype)
input[np.abs(input) < self.max_relative_error] = 0.5
self.op_type = "clip_by_norm"
self.inputs = {
'X': input,
}
self.attrs = {}
self.attrs['max_norm'] = self.max_norm
norm = np.sqrt(np.sum(np.square(input)))
if norm > self.max_norm:
output = self.max_norm * input / norm
else:
output = input
self.outputs = {'Out': output}
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
self.place = core.CUDAPlace(0)
def test_check_output(self):
self.check_output_with_place(self.place)
def initTestCase(self):
self.shape = (100,)
self.max_norm = 1.0
def init_dtype(self):
self.dtype = np.uint16
self.np_dtype = np.float32
class TestClipByNormOpWithSelectedRows(unittest.TestCase):
def check_with_place(self, place):
self.config_test_case()
......
......@@ -15,10 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from numpy.linalg import multi_dot
import paddle
from paddle.fluid import core
paddle.enable_static()
......@@ -49,6 +50,53 @@ class TestMultiDotOp(OpTest):
self.check_grad(['x1'], 'Out')
class TestMultiDotFP16Op(TestMultiDotOp):
def get_dtype(self):
return "float16"
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not compiled with CUDA or not support bfloat16",
)
class TestMultiDotBF16Op(OpTest):
def setUp(self):
self.op_type = "multi_dot"
self.python_api = paddle.linalg.multi_dot
self.dtype = self.get_dtype()
self.get_inputs_and_outputs()
self.place = core.CUDAPlace(0)
def get_dtype(self):
self.np_dtype = "float32"
return np.uint16
def get_inputs_and_outputs(self):
self.A = np.random.random((2, 8)).astype(self.np_dtype)
self.B = np.random.random((8, 4)).astype(self.np_dtype)
self.inputs = {
'X': [
('x0', convert_float_to_uint16(self.A)),
('x1', convert_float_to_uint16(self.B)),
]
}
self.outputs = {
'Out': convert_float_to_uint16(multi_dot([self.A, self.B]))
}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['x0'], 'Out', numeric_grad_delta=0.01
)
self.check_grad_with_place(
self.place, ['x1'], 'Out', numeric_grad_delta=0.01
)
# (A*B)*C
class TestMultiDotOp3Mat(TestMultiDotOp):
def get_inputs_and_outputs(self):
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, skip_check_grad_ci
from eager_op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
import paddle
import paddle.nn.functional as F
......@@ -174,7 +174,11 @@ class PReluTest(OpTest):
self.op_type = "prelu"
self.python_api = prelu_api_wrapper
x_np = np.random.uniform(-1, 1, self.x_shape).astype(self.dtype)
if self.dtype == np.uint16:
as_type = self.np_dtype
else:
as_type = self.dtype
x_np = np.random.uniform(-1, 1, self.x_shape).astype(as_type)
# Since zero point in prelu is not differentiable, avoid randomize
# zero.
x_np[np.abs(x_np) < 0.005] = 0.02
......@@ -190,7 +194,7 @@ class PReluTest(OpTest):
alpha_np = np.random.uniform(-1, -0.5, [1, 1, 1, self.x_shape[-1]])
else:
alpha_np = np.random.uniform(-1, -0.5, [1] + self.x_shape[1:])
alpha_np = alpha_np.astype(self.dtype)
alpha_np = alpha_np.astype(as_type)
self.inputs = {'X': x_np, 'Alpha': alpha_np}
......@@ -393,18 +397,48 @@ def create_test_fp16_class(
def test_check_grad(self):
place = core.CUDAPlace(0)
if core.is_float16_supported(place) and check_grad:
self.check_grad_with_place(
place,
['X', 'Alpha'],
'Out',
max_relative_error=max_relative_error,
)
# Use the default max_relative_error, not use max_relative_error
self.check_grad_with_place(place, ['X', 'Alpha'], 'Out')
cls_name = "{}_{}".format(parent.__name__, "Fp16Op")
TestPReluFp16Case.__name__ = cls_name
globals()[cls_name] = TestPReluFp16Case
def create_test_bf16_class(
parent, check_grad=True, atol=1e-3, max_relative_error=0.05
):
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestPReluBF16Op(parent):
def setUp(self):
super().setUp()
self.inputs['X'] = convert_float_to_uint16(self.inputs['X'])
self.inputs['Alpha'] = convert_float_to_uint16(self.inputs['Alpha'])
self.outputs['Out'] = convert_float_to_uint16(self.outputs['Out'])
def init_dtype(self):
self.dtype = np.uint16
self.np_dtype = np.float32
def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place, atol=atol)
def test_check_grad(self):
place = core.CUDAPlace(0)
if check_grad:
# Use the default max_relative_error, not use max_relative_error
self.check_grad_with_place(place, ['X', 'Alpha'], 'Out')
cls_name = "{}_{}".format(parent.__name__, "BF16Op")
TestPReluBF16Op.__name__ = cls_name
globals()[cls_name] = TestPReluBF16Op
create_test_fp16_class(TestModeElt)
create_test_fp16_class(TestModeAllRank3)
create_test_fp16_class(TestModeAllRank6)
......@@ -420,6 +454,21 @@ create_test_fp16_class(TestModeChannelRank6NHWC)
create_test_fp16_class(TestModeElementRank3NHWC)
create_test_fp16_class(TestModeElementRank6NHWC)
create_test_bf16_class(TestModeElt)
create_test_bf16_class(TestModeAllRank3)
create_test_bf16_class(TestModeAllRank6)
create_test_bf16_class(TestModeChannelRank3)
create_test_bf16_class(TestModeChannelRank6)
create_test_bf16_class(TestModeElementRank3)
create_test_bf16_class(TestModeElementRank6)
create_test_bf16_class(TestModeEltNHWC)
create_test_bf16_class(TestModeAllRank3NHWC)
create_test_bf16_class(TestModeAllRank6NHWC)
create_test_bf16_class(TestModeChannelRank3NHWC)
create_test_bf16_class(TestModeChannelRank6NHWC)
create_test_bf16_class(TestModeElementRank3NHWC)
create_test_bf16_class(TestModeElementRank6NHWC)
def prelu_t(x, mode, param_attr=None, name=None, data_format='NCHW'):
helper = fluid.layer_helper.LayerHelper('prelu', **locals())
......
......@@ -63,7 +63,9 @@ def clip_by_norm(x, max_norm, name=None):
return _legacy_C_ops.clip_by_norm(x, 'max_norm', max_norm)
helper = LayerHelper("clip_by_norm", **locals())
check_variable_and_dtype(x, 'X', ['float32', 'float16'], 'clip_by_norm')
check_variable_and_dtype(
x, 'X', ['float16', 'float32', 'uint16'], 'clip_by_norm'
)
check_type(max_norm, 'max_norm', (float), 'clip_by_norm')
if name is None:
......
......@@ -538,10 +538,13 @@ def prelu(x, weight, data_format="NCHW", name=None):
return _C_ops.prelu(x, weight, data_format, mode)
else:
check_variable_and_dtype(
x, 'x', ['float16', 'float32', 'float64'], 'prelu'
x, 'x', ['float16', 'float32', 'float64', 'uint16'], 'prelu'
)
check_variable_and_dtype(
weight, 'weight', ['float16', 'float32', 'float64'], 'prelu'
weight,
'weight',
['float16', 'float32', 'float64', 'uint16'],
'prelu',
)
helper = LayerHelper('prelu', **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
......@@ -2489,7 +2489,7 @@ def multi_dot(x, name=None):
check_variable_and_dtype(
item,
'x[' + str(id) + ']',
['float16', 'float32', 'float64'],
['float16', 'float32', 'float64', 'uint16'],
'multi_dot',
)
if item.dtype != x[0].dtype:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册