提交 e7bef51f 编写于 作者: Z zhangting2020 提交者: Tingquan Gao

fix data dtype for amp training

上级 731006f1
...@@ -242,11 +242,14 @@ def build(config, ...@@ -242,11 +242,14 @@ def build(config,
mode = "Train" if is_train else "Eval" mode = "Train" if is_train else "Eval"
use_mix = "batch_transform_ops" in config["DataLoader"][mode][ use_mix = "batch_transform_ops" in config["DataLoader"][mode][
"dataset"] "dataset"]
data_dtype = "float32"
if 'AMP' in config and config["AMP"]["level"] == 'O2':
data_dtype = "float16"
feeds = create_feeds( feeds = create_feeds(
config["Global"]["image_shape"], config["Global"]["image_shape"],
use_mix, use_mix,
class_num=class_num, class_num=class_num,
dtype="float32") dtype=data_dtype)
# build model # build model
# data_format should be assigned in arch-dict # data_format should be assigned in arch-dict
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册