diff --git a/paddlex/cv/models/base.py b/paddlex/cv/models/base.py index a61209c15f2a95ab0a47aaeb5119e29b8c4cc475..234193e73d530c734dd3b71bdf74e7536d455a9a 100644 --- a/paddlex/cv/models/base.py +++ b/paddlex/cv/models/base.py @@ -255,7 +255,10 @@ class BaseAPI: if osp.exists(save_dir): os.remove(save_dir) os.makedirs(save_dir) - fluid.save(self.train_prog, osp.join(save_dir, 'model')) + if self.train_prog is not None: + fluid.save(self.train_prog, osp.join(save_dir, 'model')) + else: + fluid.save(self.test_prog, osp.join(save_dir, 'model')) model_info = self.get_model_info() model_info['status'] = self.status with open(