diff --git a/ppcls/static/program.py b/ppcls/static/program.py index 188393d1700f9b7600afe02970a66679598f0223..505a6765ad896e20272381c54f634d9dcb5bf08a 100644 --- a/ppcls/static/program.py +++ b/ppcls/static/program.py @@ -242,11 +242,14 @@ def build(config, mode = "Train" if is_train else "Eval" use_mix = "batch_transform_ops" in config["DataLoader"][mode][ "dataset"] + data_dtype = "float32" + if 'AMP' in config and config["AMP"]["level"] == 'O2': + data_dtype = "float16" feeds = create_feeds( config["Global"]["image_shape"], use_mix, class_num=class_num, - dtype="float32") + dtype=data_dtype) # build model # data_format should be assigned in arch-dict