提交 e7bef51f 编写于 作者: Z zhangting2020 提交者: Tingquan Gao

fix data dtype for amp training

上级 731006f1
......@@ -242,11 +242,14 @@ def build(config,
mode = "Train" if is_train else "Eval"
use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"]
data_dtype = "float32"
if 'AMP' in config and config["AMP"]["level"] == 'O2':
data_dtype = "float16"
feeds = create_feeds(
config["Global"]["image_shape"],
use_mix,
class_num=class_num,
dtype="float32")
dtype=data_dtype)
# build model
# data_format should be assigned in arch-dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册