未验证 提交 3a9d38b5 编写于 作者: S shangliang Xu 提交者: GitHub

[ce test] fix amp bug in ce test (#4196)

上级 2d8d2430
...@@ -30,8 +30,8 @@ def add_coord(x, data_format): ...@@ -30,8 +30,8 @@ def add_coord(x, data_format):
else: else:
h, w = x.shape[1], x.shape[2] h, w = x.shape[1], x.shape[2]
gx = paddle.arange(w, dtype=x.dtype) / ((w - 1.) * 2.0) - 1. gx = paddle.cast(paddle.arange(w) / ((w - 1.) * 2.0) - 1., x.dtype)
gy = paddle.arange(h, dtype=x.dtype) / ((h - 1.) * 2.0) - 1. gy = paddle.cast(paddle.arange(h) / ((h - 1.) * 2.0) - 1., x.dtype)
if data_format == 'NCHW': if data_format == 'NCHW':
gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w]) gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
......
...@@ -283,6 +283,9 @@ else ...@@ -283,6 +283,9 @@ else
run_train=${pact_trainer} run_train=${pact_trainer}
run_export=${pact_export} run_export=${pact_export}
flag_quant=True flag_quant=True
if [ ${autocast} = "amp" ]; then
continue
fi
elif [ ${trainer} = "${fpgm_key}" ]; then elif [ ${trainer} = "${fpgm_key}" ]; then
run_train=${fpgm_trainer} run_train=${fpgm_trainer}
run_export=${fpgm_export} run_export=${fpgm_export}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册