Unverified · Commit 9dbc8f02 authored by niuliling123, committed by GitHub

Set bf16 black_list and white_list (#55713)

Parent: cd450c0a
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-FP16_WHITE_LIST = {
+# The set of ops that support fp16 and bf16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16 or bf16.
+WHITE_LIST = {
     'conv2d',
     'einsum',
     'matmul',
     'matmul_v2',
     'max_pool2d_with_index',
     'mul',
+    'fused_gemm_epilogue',
+}
+# The set of ops that support fp16, and bf16 was unsupported.
+ONLY_FP16_WHITE_LIST = {
     'fake_quantize_dequantize_abs_max',
     'fake_quantize_dequantize_moving_average_abs_max',
-    'fused_gemm_epilogue',
     'fused_attention',
     'fused_feedforward',
 }
+FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 FP16_BLACK_LIST = {
@@ -90,8 +96,8 @@ EXTRA_BLACK_LIST = {
     'scatter',
 }
-BF16_WHITE_LIST = {'conv2d', 'einsum', 'matmul_v2'}
-BF16_BLACK_LIST = set()
+BF16_WHITE_LIST = WHITE_LIST
+BF16_BLACK_LIST = FP16_BLACK_LIST
 # At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
......
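A minimal standalone sketch (not part of the commit) of how the new list definitions relate to each other; the op names are copied verbatim from the hunks above.

# Shared ops that are numerically safe and performance-critical in both fp16 and bf16.
WHITE_LIST = {
    'conv2d', 'einsum', 'matmul', 'matmul_v2',
    'max_pool2d_with_index', 'mul', 'fused_gemm_epilogue',
}
# Ops that run in fp16 but have no bf16 support.
ONLY_FP16_WHITE_LIST = {
    'fake_quantize_dequantize_abs_max',
    'fake_quantize_dequantize_moving_average_abs_max',
    'fused_attention',
    'fused_feedforward',
}
FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST  # fp16 keeps every op it had before
BF16_WHITE_LIST = WHITE_LIST                         # bf16 now reuses the shared list
assert BF16_WHITE_LIST <= FP16_WHITE_LIST            # bf16 white list is a subset of fp16's
assert 'fused_attention' not in BF16_WHITE_LIST      # fp16-only ops stay out of bf16

The net effect in this file: the bf16 white list grows from the previous three hard-coded ops to the full shared WHITE_LIST, and BF16_BLACK_LIST changes from an empty set to an alias of FP16_BLACK_LIST.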
@@ -14,6 +14,7 @@
 import copy
+from paddle.amp.amp_lists import BF16_WHITE_LIST
 from paddle.fluid import core
 from ..fp16_lists import black_list as black_list_fp16
@@ -86,33 +87,10 @@ class AutoMixedPrecisionListsBF16:
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 # always bf16
-bf16_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
+bf16_list = BF16_WHITE_LIST
 # depends on the prev_op type
-gray_list = {
-    'elementwise_add',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_div',
-    'relu',
-    'layer_norm',
-    'slice',
-    'concat',
-    'uniform_random',
-    'reshape2',
-    'transpose2',
-    'pool2d',
-    'sigmoid',
-    'cast',
-    'scale',
-    'fill_constant',
-    'split',
-}
+gray_list = gray_list_fp16
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
     'CPU', core.VarDesc.VarType.BF16
......
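A small hedged check of what the refactor above guarantees, assuming a Paddle build that includes this commit; the import path is the one added in the first hunk of this file.

from paddle.amp.amp_lists import BF16_WHITE_LIST, WHITE_LIST

# The static bf16 helper now derives its defaults from the shared module,
# so the dynamic and static white lists can no longer drift apart.
assert BF16_WHITE_LIST == WHITE_LIST
# Ops such as 'einsum' and 'max_pool2d_with_index', absent from the old
# hard-coded four-op bf16_list, are now picked up automatically.
assert 'einsum' in BF16_WHITE_LIST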
@@ -46,10 +46,10 @@ class AMPTest(unittest.TestCase):
     def test_amp_lists_2(self):
         # 2. w={'tanh'}, b=None
-        self.fp32_list.remove('tanh')
-        self.bf16_list.add('tanh')
-        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})
+        self.fp32_list.remove('tan')
+        self.bf16_list.add('tan')
+        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
     def test_amp_lists_3(self):
         # 3. w={'lstm'}, b=None
......
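For context, a hedged sketch of the behaviour this adjusted test relies on. The import path and the bf16_list/fp32_list attribute names are assumed from the test code above; 'tan' replaces 'tanh' presumably because 'tanh' now falls into the shared gray list.

import paddle.static.amp as amp

# Passing a custom white list is expected to move the op out of the default
# fp32 set and into the bf16 set.
amp_lists = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
assert 'tan' in amp_lists.bf16_list
assert 'tan' not in amp_lists.fp32_list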