Unverified commit 0c43ce22, authored by arlesniak, committed by GitHub

Update BF16 amp list (#39304)

* amp list updated

* tests updated

* gray list updated

* amp list updated

* test updated
Parent ebd14743
@@ -83,15 +83,18 @@ class AutoMixedPrecisionListsBF16(object):
     bf16_initializer_list = {'fill_constant', 'uniform_random'}
     # always bf16
-    bf16_list = {'elementwise_add', 'mul'}
+    bf16_list = {
+        'conv2d',
+        'matmul',
+        'matmul_v2',
+        'mul',
+    }
     # depends on the prev_op type
     gray_list = {
-        'cast',
-        'fill_constant',
-        'reduce_mean',
-        'reshape2',
-        'scale',
+        'elementwise_add', 'elementwise_sub', 'elementwise_mul', 'elementwise_div',
+        'relu', 'layer_norm', 'slice', 'concat', 'uniform_random', 'reshape2',
+        'transpose2', 'pool2d', 'sigmoid', 'cast', 'scale', 'fill_constant', 'split'
     }
     _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
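The three lists drive op placement: ops in `bf16_list` always run in bfloat16, ops in `gray_list` follow the precision of the op that produced their inputs ("depends on the prev_op type"), and everything else stays in fp32. The constructor also accepts overrides, which the updated tests below exercise. A minimal sketch of such an override, assuming only the keyword arguments that appear in this diff:

```python
import paddle.static.amp as amp

# Promote 'elementwise_add' (gray by default after this change) to
# always-bf16, and force 'matmul_v2' out of bf16 into fp32.
amp_lists = amp.bf16.AutoMixedPrecisionListsBF16(
    custom_bf16_list={'elementwise_add'},
    custom_fp32_list={'matmul_v2'})
```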
@@ -57,20 +57,20 @@ class AMPTest(unittest.TestCase):
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'lstm'})

     def test_amp_lists_4(self):
-        # 4. w=None, b={'elementwise_add'}
-        self.bf16_list.remove('elementwise_add')
-        self.fp32_list.add('elementwise_add')
+        # 4. w=None, b={'matmul_v2'}
+        self.bf16_list.remove('matmul_v2')
+        self.fp32_list.add('matmul_v2')
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})

     def test_amp_lists_5(self):
-        # 5. w=None, b={'elementwise_add'}
-        self.fp32_list.add('elementwise_add')
-        self.bf16_list.remove('elementwise_add')
+        # 5. w=None, b={'matmul_v2'}
+        self.fp32_list.add('matmul_v2')
+        self.bf16_list.remove('matmul_v2')
         self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16(
-            custom_fp32_list={'elementwise_add'})
+            custom_fp32_list={'matmul_v2'})

     def test_amp_lists_6(self):
         # 6. w=None, b={'lstm'}
@@ -19,6 +19,7 @@ import paddle.fluid as fluid
 import contextlib
 import unittest
 import numpy as np
+import struct
 import paddle.fluid.layers as layers
 import paddle.static.amp as amp
 from paddle.fluid import core
@@ -26,6 +27,20 @@ from paddle.fluid import core

 paddle.enable_static()

+
+def convert_uint16_to_float(in_list):
+    if in_list.dtype == np.uint16:
+        in_list = np.asarray(in_list)
+        out = np.vectorize(
+            lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
+            otypes=[np.float32])(in_list.flat)
+        return np.reshape(out, in_list.shape)
+    else:
+        return in_list
+
+
+cutf = convert_uint16_to_float
+
+
 @unittest.skipIf(not core.supports_bfloat16(),
                  "place does not support BF16 evaluation")
 class TestModelCastBF16(unittest.TestCase):
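How the new helper works: bfloat16 is the upper 16 bits of an IEEE-754 float32, and fetched bf16 tensors surface as `np.uint16` arrays of raw bit patterns (hence the dtype check above). Left-shifting each pattern by 16 and reinterpreting the 32-bit word as a float recovers the value exactly. A standalone check of the bit trick (the constant is an illustrative value, not taken from the test):

```python
import struct

bits = 0x3FC0  # bf16 bit pattern for 1.5
# Shifted into the high half it becomes 0x3FC00000, which is the
# float32 encoding of 1.5.
value = struct.unpack('<f', struct.pack('<I', bits << 16))[0]
assert value == 1.5
```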
@@ -111,10 +126,13 @@ class TestModelCastBF16(unittest.TestCase):
                 'tt_bf16': nn_bf16,
             },
             fetch_list=[ret_bf16, ret, ret_fp32bf16],
-            amp_fun=lambda prog: amp.bf16.rewrite_program_bf16(prog))
+            amp_fun=_amp_fun,
+            startup_prog=startup_prog)

-        self.assertTrue(np.allclose(static_ret_bf16, static_ret, 1e-2))
-        self.assertTrue(np.allclose(static_ret_bf16, ret_fp32bf16, 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(static_ret), 1e-2))
+        self.assertTrue(
+            np.allclose(cutf(static_ret_bf16), cutf(ret_fp32bf16), 1e-2))

         with self.static_graph():
             t = layers.data(name='t', shape=[size, size], dtype='float32')
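The `cutf` wrapping above is what makes the comparison meaningful: without decoding, `np.allclose` would compare raw uint16 bit patterns against float32 values. A small illustration using the `cutf` alias defined earlier in this file, with assumed bit patterns:

```python
import numpy as np

raw = np.array([16256, 16320], dtype=np.uint16)  # bf16 bits for 1.0 and 1.5
assert np.allclose(cutf(raw), [1.0, 1.5], 1e-2)
```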
@@ -141,6 +159,7 @@ class TestModelCastBF16(unittest.TestCase):
         self._graph_common(lambda prog: amp.bf16.rewrite_program_bf16(
             prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_varnames={'elementwise_add_0.tmp_0'})
         ))
@@ -149,6 +168,7 @@ class TestModelCastBF16(unittest.TestCase):
             prog,
             startup_prog,
             amp.bf16.AutoMixedPrecisionListsBF16(
+                custom_bf16_list={'elementwise_add'},
                 custom_fp32_list={'elementwise_mul'}),
             use_bf16_guard=True
         ), startup_prog=fluid.default_startup_program())
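Taken together, the last two hunks exercise both override granularities: op-level (`custom_bf16_list`, `custom_fp32_list`) and variable-level (`custom_fp32_varnames`), the latter apparently keeping a single named output in fp32 even when its op runs in bf16. A sketch combining them, using only names visible in this diff:

```python
import paddle.fluid as fluid
import paddle.static.amp as amp

prog = fluid.default_main_program()
# Run 'elementwise_add' in bf16 in general, but keep the specific
# output variable 'elementwise_add_0.tmp_0' in fp32.
amp.bf16.rewrite_program_bf16(
    prog,
    amp.bf16.AutoMixedPrecisionListsBF16(
        custom_bf16_list={'elementwise_add'},
        custom_fp32_varnames={'elementwise_add_0.tmp_0'}))
```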