Unverified · Commit 9dbc8f02 authored by niuliling123, committed by GitHub

Set bf16 black_list and white_list (#55713)

Parent cd450c0a
@@ -12,22 +12,28 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-# The set of ops that support fp16 calculation and are considered numerically-
-# safe and performance-critical. These ops are always converted to fp16.
-FP16_WHITE_LIST = {
+# The set of ops that support fp16 and bf16 calculation and are considered numerically-
+# safe and performance-critical. These ops are always converted to fp16 or bf16.
+WHITE_LIST = {
     'conv2d',
     'einsum',
     'matmul',
     'matmul_v2',
     'max_pool2d_with_index',
     'mul',
-    'fused_gemm_epilogue',
 }
+
+# The set of ops that support fp16 but not bf16.
+ONLY_FP16_WHITE_LIST = {
+    'fake_quantize_dequantize_abs_max',
+    'fake_quantize_dequantize_moving_average_abs_max',
+    'fused_gemm_epilogue',
+    'fused_attention',
+    'fused_feedforward',
+}
+
+FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST
 # The set of ops that support fp16 calculation and are considered numerically-
 # dangerous and whose effects may also be observed in downstream ops.
 FP16_BLACK_LIST = {
@@ -90,8 +96,8 @@ EXTRA_BLACK_LIST = {
     'scatter',
 }
-BF16_WHITE_LIST = {'conv2d', 'einsum', 'matmul_v2'}
-BF16_BLACK_LIST = set()
+BF16_WHITE_LIST = WHITE_LIST
+BF16_BLACK_LIST = FP16_BLACK_LIST
 # At OD level, ops in WHITE_LIST will use FP16/BF16 and the others will use FP32.
......
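Taken together, the two hunks above reduce the list plumbing to set algebra: `FP16_WHITE_LIST` becomes the union of the shared `WHITE_LIST` with the fp16-only ops, and the bf16 lists become aliases of the shared sets instead of hand-maintained copies. A minimal standalone sketch of that semantics, using plain Python sets copied from the diff (not an import of the Paddle module itself):

```python
# Plain Python sets mirroring the diff above; a semantics sketch only.
WHITE_LIST = {
    'conv2d',
    'einsum',
    'matmul',
    'matmul_v2',
    'max_pool2d_with_index',
    'mul',
}
ONLY_FP16_WHITE_LIST = {
    'fake_quantize_dequantize_abs_max',
    'fake_quantize_dequantize_moving_average_abs_max',
    'fused_gemm_epilogue',
    'fused_attention',
    'fused_feedforward',
}

FP16_WHITE_LIST = WHITE_LIST | ONLY_FP16_WHITE_LIST  # set union, as in the diff
BF16_WHITE_LIST = WHITE_LIST  # bf16 keeps only the shared ops

# fused_gemm_epilogue used to sit in FP16_WHITE_LIST directly; it is still
# auto-cast under fp16 but is no longer eligible under bf16.
assert 'fused_gemm_epilogue' in FP16_WHITE_LIST
assert 'fused_gemm_epilogue' not in BF16_WHITE_LIST
assert BF16_WHITE_LIST <= FP16_WHITE_LIST  # bf16 white list is a subset
```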
@@ -14,6 +14,7 @@
 import copy
 
+from paddle.amp.amp_lists import BF16_WHITE_LIST
 from paddle.fluid import core
 
 from ..fp16_lists import black_list as black_list_fp16
@@ -86,33 +87,10 @@ class AutoMixedPrecisionListsBF16:
 bf16_initializer_list = {'fill_constant', 'uniform_random'}
 
 # always bf16
-bf16_list = {
-    'conv2d',
-    'matmul',
-    'matmul_v2',
-    'mul',
-}
+bf16_list = BF16_WHITE_LIST
 
 # depends on the prev_op type
-gray_list = {
-    'elementwise_add',
-    'elementwise_sub',
-    'elementwise_mul',
-    'elementwise_div',
-    'relu',
-    'layer_norm',
-    'slice',
-    'concat',
-    'uniform_random',
-    'reshape2',
-    'transpose2',
-    'pool2d',
-    'sigmoid',
-    'cast',
-    'scale',
-    'fill_constant',
-    'split',
-}
+gray_list = gray_list_fp16
 
 _, _, _sys_unsupported_bf16_list = core.op_supported_infos(
     'CPU', core.VarDesc.VarType.BF16
......
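Downstream, the static-graph bf16 lists are now built from these shared sets, and a custom white list still works the same way: an op named in it is moved out of `fp32_list` and into `bf16_list`, as the updated unit test below exercises. A rough usage sketch; the import path and the attribute names (`bf16_list`, `fp32_list`) follow the test file and may differ across Paddle versions:

```python
from paddle.static import amp

# Promote 'tan' into the bf16 white list. AutoMixedPrecisionListsBF16 moves
# a custom-white-listed op out of fp32_list and into bf16_list, which is
# what test_amp_lists_2 asserts after this change.
amp_lists = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})

assert 'tan' in amp_lists.bf16_list
assert 'tan' not in amp_lists.fp32_list
```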
@@ -46,10 +46,10 @@ class AMPTest(unittest.TestCase):
     def test_amp_lists_2(self):
         # 2. w={'tanh'}, b=None
-        self.fp32_list.remove('tanh')
-        self.bf16_list.add('tanh')
+        self.fp32_list.remove('tan')
+        self.bf16_list.add('tan')
 
-        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tanh'})
+        self.amp_lists_ = amp.bf16.AutoMixedPrecisionListsBF16({'tan'})
 
     def test_amp_lists_3(self):
         # 3. w={'lstm'}, b=None
......
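For completeness, a hedged dygraph sketch of how the unified lists surface to users: `paddle.amp.auto_cast` accepts `dtype='bfloat16'`, and with this change its default bf16 white list matches the fp16 one minus the fp16-only ops. Hardware without bf16 support will simply keep fp32, so the final dtype check is indicative rather than guaranteed:

```python
import paddle

model = paddle.nn.Linear(4, 4)
x = paddle.randn([2, 4])

# Under level='O1', only white-list ops (e.g. the matmul behind Linear)
# run in reduced precision; black-list ops stay in fp32.
with paddle.amp.auto_cast(dtype='bfloat16', level='O1'):
    y = model(x)

print(y.dtype)  # paddle.bfloat16 on bf16-capable hardware, else paddle.float32
```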