未验证 提交 1f445bf3 编写于 作者: S sneaxiy 提交者: GitHub

Support FP16 for more ops (#38123)

* support FP16 for more ops

* add amp list tests

* refine reduce_mean_grad

* fix OP benchmark ci

* fix fp16 reduce_mean

* updat ut, but still have some problems

* remove mean/reduce_mean fp16 kernel
上级 f8955602
......@@ -112,4 +112,4 @@ TEST(Analyzer_Resnet50_ipu, compare_results_2_batch) {
}
} // namespace inference
} // namespace paddle
\ No newline at end of file
} // namespace paddle
......@@ -41,12 +41,16 @@ namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
elementwise_min,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseMinKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
elementwise_min_grad,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseMinGradKernel<paddle::platform::CUDADeviceContext, int>,
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/eigen_ext.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
......@@ -67,6 +68,28 @@ struct MinGradDy {
}
};
#ifdef PADDLE_CUDA_FP16
template <>
struct MinGradDx<platform::float16> {
HOSTDEVICE platform::float16 operator()(platform::float16 x,
platform::float16 y,
platform::float16 out,
platform::float16 dout) const {
return x < y ? dout : static_cast<platform::float16>(0);
}
};
template <>
struct MinGradDy<platform::float16> {
HOSTDEVICE platform::float16 operator()(platform::float16 x,
platform::float16 y,
platform::float16 out,
platform::float16 dout) const {
return x >= y ? dout : static_cast<platform::float16>(0);
}
};
#endif
template <typename DeviceContext, typename T>
class ElementwiseMinGradKernel : public ElemwiseGradKernel<T> {
public:
......
......@@ -17,6 +17,9 @@ from ... import core
__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_fp16_list = {'lookup_table', 'lookup_table_v2'}
class AutoMixedPrecisionLists(object):
"""
......@@ -60,6 +63,8 @@ class AutoMixedPrecisionLists(object):
elif op_name in self.gray_list:
self.gray_list.remove(op_name)
self.white_list.add(op_name)
if op_name in _extra_unsupported_fp16_list:
self.unsupported_list.remove(op_name)
if self._custom_black_list:
for op_name in self._custom_black_list:
if op_name in self.white_list:
......@@ -170,7 +175,6 @@ else:
_, _, _sys_unsupported_fp16_list = core.op_supported_infos(
'GPU', core.VarDesc.VarType.FP16)
unsupported_fp16_list = {'lookup_table',
'lookup_table_v2'} | _sys_unsupported_fp16_list
unsupported_fp16_list = _extra_unsupported_fp16_list | _sys_unsupported_fp16_list
CustomOpLists = AutoMixedPrecisionLists
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import unittest
from paddle.fluid.contrib.mixed_precision.fp16_lists import AutoMixedPrecisionLists
class TestAMPList(unittest.TestCase):
def test_main(self):
custom_white_list = [
'lookup_table',
'lookup_table_v2',
]
amp_list = AutoMixedPrecisionLists(custom_white_list=custom_white_list)
for op in custom_white_list:
self.assertTrue(op in amp_list.white_list)
self.assertTrue(op not in amp_list.black_list)
self.assertTrue(op not in amp_list.unsupported_list)
if __name__ == "__main__":
unittest.main()
......@@ -17,6 +17,11 @@ from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
paddle.enable_static()
class TestElementwiseOp(OpTest):
......@@ -142,5 +147,54 @@ class TestElementwiseMinOp_broadcast_4(TestElementwiseOp):
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
class TestElementwiseMinOpFP16(unittest.TestCase):
def get_out_and_grad(self, x_np, y_np, axis, place, use_fp32=False):
assert x_np.dtype == np.float16
assert y_np.dtype == np.float16
if use_fp32:
x_np = x_np.astype(np.float32)
y_np = y_np.astype(np.float32)
dtype = np.float16
with fluid.dygraph.guard(place):
x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)
x.stop_gradient = False
y.stop_gradient = False
z = fluid.layers.elementwise_min(x, y, axis)
x_g, y_g = paddle.grad([z], [x, y])
return z.numpy().astype(dtype), x_g.numpy().astype(
dtype), y_g.numpy().astype(dtype)
def check_main(self, x_shape, y_shape, axis=-1):
if not paddle.is_compiled_with_cuda():
return
place = paddle.CUDAPlace(0)
if not core.is_float16_supported(place):
return
x_np = np.random.random(size=x_shape).astype(np.float16)
y_np = np.random.random(size=y_shape).astype(np.float16)
z_1, x_g_1, y_g_1 = self.get_out_and_grad(x_np, y_np, axis, place,
False)
z_2, x_g_2, y_g_2 = self.get_out_and_grad(x_np, y_np, axis, place, True)
self.assertTrue(np.array_equal(z_1, z_2), "{} vs {}".format(z_1, z_2))
self.assertTrue(
np.array_equal(x_g_1, x_g_2), "{} vs {}".format(x_g_1, x_g_2))
self.assertTrue(
np.array_equal(y_g_1, y_g_2), "{} vs {}".format(y_g_1, y_g_2))
def test_main(self):
self.check_main((13, 17), (13, 17))
self.check_main((10, 3, 4), (1, ))
self.check_main((100, ), (100, ))
self.check_main((100, 3, 2), (100, ), 0)
self.check_main((2, 100, 3), (100, ), 1)
self.check_main((2, 3, 100), (100, ))
self.check_main((2, 25, 4, 1), (25, 4), 1)
self.check_main((2, 10, 2, 5), (2, 10, 1, 5))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册