未验证 提交 3a9d38b5 编写于 作者: S shangliang Xu 提交者: GitHub

[ce test] fix amp bug in ce test (#4196)

上级 2d8d2430
......@@ -30,8 +30,8 @@ def add_coord(x, data_format):
else:
h, w = x.shape[1], x.shape[2]
gx = paddle.arange(w, dtype=x.dtype) / ((w - 1.) * 2.0) - 1.
gy = paddle.arange(h, dtype=x.dtype) / ((h - 1.) * 2.0) - 1.
gx = paddle.cast(paddle.arange(w) / ((w - 1.) * 2.0) - 1., x.dtype)
gy = paddle.cast(paddle.arange(h) / ((h - 1.) * 2.0) - 1., x.dtype)
if data_format == 'NCHW':
gx = gx.reshape([1, 1, 1, w]).expand([b, 1, h, w])
......
......@@ -283,6 +283,9 @@ else
run_train=${pact_trainer}
run_export=${pact_export}
flag_quant=True
if [ ${autocast} = "amp" ]; then
continue
fi
elif [ ${trainer} = "${fpgm_key}" ]; then
run_train=${fpgm_trainer}
run_export=${fpgm_export}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册