diff --git a/official/projects/qat/vision/configs/image_classification.py b/official/projects/qat/vision/configs/image_classification.py index 08e01cb65fbbaa800935d8028c4b31f0619f7a8b..b3f0356970c99551771084a5e60657919e8c77fa 100644 --- a/official/projects/qat/vision/configs/image_classification.py +++ b/official/projects/qat/vision/configs/image_classification.py @@ -35,6 +35,8 @@ def image_classification_imagenet() -> cfg.ExperimentConfig: task = ImageClassificationTask.from_args( quantization=common.Quantization(), **config.task.as_dict()) config.task = task + runtime = cfg.RuntimeConfig(enable_xla=False) + config.runtime = runtime return config