Unverified commit 0c43ce22 authored by arlesniak, committed by GitHub

Update BF16 amp list (#39304)

* amp list updated

* tests updated

* gray list updated

* amp list updated

* test updated
Parent ebd14743
@@ -83,15 +83,18 @@ class AutoMixedPrecisionListsBF16(object):
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 
 # always bf16
-bf16_list = {'elementwise_add', 'mul'}
+bf16_list = {
+    'conv2d',
+    'matmul',
+    'matmul_v2',
+    'mul',
+}
 
 # depends on the prev_op type
 gray_list = {
-    'cast',
-    'fill_constant',
-    'reduce_mean',
-    'reshape2',
-    'scale',
+    'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div',
+    'relu', 'layer_norm', 'slice', 'concat', 'uniform_random', 'reshape2',
+    'transpose2', 'pool2d', 'sigmoid', 'cast', 'scale', 'fill_constant', 'split'
 }
 
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
...
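The lists above are module-level defaults consumed by AutoMixedPrecisionListsBF16, which copies them onto the instance so callers can inspect or override them. A minimal sketch of checking the updated defaults (assuming the paddle.static.amp.bf16 layout this commit touches; the attribute names match those used in the tests below):

import paddle.static.amp as amp

# After this change, conv2d/matmul/matmul_v2/mul always run in bf16, while
# gray-list ops such as elementwise_add follow the dtype of their inputs.
lists = amp.bf16.AutoMixedPrecisionListsBF16()
print(sorted(lists.bf16_list))
print(sorted(lists.gray_list))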
@@ -57,20 +57,20 @@ class AMPTest(unittest.TestCase):
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})
 
     def test_amp_lists_4(self):
-        # 4. w=None, b={'elementwise_add'}
-        self.bf16_list.remove('elementwise_add')
-        self.fp32_list.add('elementwise_add')
+        # 4. w=None, b={'matmul_v2'}
+        self.bf16_list.remove('matmul_v2')
+        self.fp32_list.add('matmul_v2')
 
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})
 
     def test_amp_lists_5(self):
-        # 5. w=None, b={'elementwise_add'}
-        self.fp32_list.add('elementwise_add')
-        self.bf16_list.remove('elementwise_add')
+        # 5. w=None, b={'matmul_v2'}
+        self.fp32_list.add('matmul_v2')
+        self.bf16_list.remove('matmul_v2')
 
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})
 
     def test_amp_lists_6(self):
         # 6. w=None, b={'lstm'}
...
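The reworked tests rely on the override behaviour of custom_fp32_list: naming an op there removes it from the bf16 set and adds it to the fp32 set. A short sketch of what test_amp_lists_4 and test_amp_lists_5 assert, under the same assumptions as above:

import paddle.static.amp as amp

# Forcing matmul_v2 to fp32 moves it out of the default bf16 list.
lists = amp.bf16.AutoMixedPrecisionListsBF16(custom_fp32_list={'matmul_v2'})
assert 'matmul_v2' not in lists.bf16_list
assert 'matmul_v2' in lists.fp32_list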
@@ -19,6 +19,7 @@ import paddle.fluid as fluid
 import contextlib
 import unittest
 import numpy as np
+import struct
 import paddle.fluid.layers as layers
 import paddle.static.amp as amp
 from paddle.fluid import core
@@ -26,6 +27,20 @@ from paddle.fluid import core
 
 paddle.enable_static()
 
+
+def convert_uint16_to_float(in_list):
+    if in_list.dtype == np.uint16:
+        in_list = np.asarray(in_list)
+        out = np.vectorize(
+            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
+            otypes=[np.float32])(in_list.flat)
+        return np.reshape(out, in_list.shape)
+    else:
+        return in_list
+
+
+cutf = convert_uint16_to_float
+
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestModelCastBF16(unittest.TestCase):
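convert_uint16_to_float widens bf16 results, which the executor fetches as raw uint16 bit patterns, back to float32: bf16 is the upper half of an IEEE-754 float32, so shifting the bits left by 16 and reinterpreting them recovers the (truncated) value. A standalone sketch of the same bit trick:

import struct

def bf16_bits_to_float(x):
    # bf16 keeps only the upper 16 bits of a float32, so restoring the value
    # is a left shift followed by a bit-level reinterpretation.
    return struct.unpack('<f', struct.pack('<I', x << 16))[0]

print(bf16_bits_to_float(0x3f80))  # 1.0
print(bf16_bits_to_float(0x4049))  # 3.140625 (pi truncated to bf16)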
@@ -111,10 +126,13 @@ class TestModelCastBF16(unittest.TestCase):
                     'tt_bf16': nn_bf16,
                 },
                 fetch_list=[ret_bf16, ret, ret_fp32bf16],
-                amp_fun=lambda prog: amp.bf16.rewrite_program_bf16(prog))
+                amp_fun=_amp_fun,
+                startup_prog=startup_prog)
 
-        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
-        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(static_ret), 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(ret_fp32bf16), 1e-2))
 
         with self.static_graph():
             t = layers.data(name='t', shape=[size, size], dtype='float32')
@@ -141,6 +159,7 @@ class TestModelCastBF16(unittest.TestCase):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))
@@ -149,6 +168,7 @@ class TestModelCastBF16(unittest.TestCase):
             prog,
             startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
         ), startup_prog=fluid.default_startup_program())
...
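Both call sites now combine the knobs the API exposes: custom_bf16_list forces elementwise_add (a gray-list op after this commit) into bf16, while custom_fp32_varnames or custom_fp32_list pins a particular variable or op to fp32. A minimal sketch of the rewrite path, assuming the fluid.layers API used elsewhere in this test file and Paddle's default naming for the elementwise_add output variable:

import paddle
import paddle.fluid as fluid
import paddle.static.amp as amp

paddle.enable_static()

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.layers.data(name='x', shape=[3, 3], dtype='float32')
    y = fluid.layers.data(name='y', shape=[3, 3], dtype='float32')
    out = fluid.layers.elementwise_add(x, y)

# Run elementwise_add in bf16, but keep its first output variable in fp32.
amp.bf16.rewrite_program_bf16(
    main_prog,
    amp.bf16.AutoMixedPrecisionListsBF16(
        custom_bf16_list={'elementwise_add'},
        custom_fp32_varnames={'elementwise_add_0.tmp_0'}))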