未验证 提交 4f735dbd 编写于 作者: Z Zhang Ting 提交者: GitHub

matmul use fp32 compute_type (#8733)

上级 bdfa1d2f
......@@ -152,9 +152,10 @@ def main(config, device, logger, vdl_writer):
AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
if paddle.is_compiled_with_cuda():
AMP_RELATED_FLAGS_SETTING.update({
- 'FLAGS_cudnn_batchnorm_spatial_persistent': 1
+ 'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
+ 'FLAGS_gemm_use_half_precision_compute_type': 0,
})
- paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
+ paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
scale_loss = config["Global"].get("scale_loss", 1.0)
use_dynamic_loss_scaling = config["Global"].get(
"use_dynamic_loss_scaling", False)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册