提交 d6137854 编写于 作者: B Bobholamovic 提交者: cuicheng01

Accommodate UAPI

上级 653e2cd8
......@@ -30,6 +30,7 @@ AMP:
Arch:
name: MobileNetV3_small_x1_0
class_num: 1000
pretrained: True
# loss function config for training/eval process
Loss:
......
......@@ -31,6 +31,7 @@ AMP:
Arch:
name: PPHGNet_small
class_num: 1000
pretrained: True
# loss function config for training/eval process
Loss:
......
......@@ -31,6 +31,7 @@ AMP:
Arch:
name: PPHGNet_tiny
class_num: 1000
pretrained: True
# loss function config for training/eval process
Loss:
......
......@@ -29,6 +29,7 @@ AMP:
Arch:
name: PPLCNet_x1_0
class_num: 1000
pretrained: True
# loss function config for training/eval process
Loss:
......
......@@ -32,6 +32,7 @@ AMP:
Arch:
name: ResNet50
class_num: 1000
pretrained: True
# loss function config for training/eval process
Loss:
......
......@@ -32,6 +32,7 @@ AMP:
Arch:
name: SwinTransformer_base_patch4_window7_224
class_num: 1000
pretrained: True
# loss function config for training/eval process
Loss:
......
# for quantization or prune model
Slim:
## for prune
quant:
name: pact
\ No newline at end of file
......@@ -15,6 +15,7 @@ from __future__ import division
from __future__ import print_function
import os
import shutil
import platform
import paddle
import paddle.distributed as dist
......@@ -72,8 +73,7 @@ class Engine(object):
# init logger
self.output_dir = self.config['Global']['output_dir']
log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
f"{mode}.log")
log_file = os.path.join(self.output_dir, f"{mode}.log")
init_logger(log_file=log_file)
print_config(config)
......@@ -519,6 +519,10 @@ class Engine(object):
save_path + "_int8")
else:
paddle.jit.save(model, save_path)
if self.config["Global"].get("export_for_fd", False):
src_path = self.config["Global"]["infer_config_path"]
dst_path = os.path.join(self.config["Global"]["save_inference_dir"], 'inference.yml')
shutil.copy(src_path, dst_path)
logger.info(
f"Export succeeded! The inference model exported has been saved in \"{self.config['Global']['save_inference_dir']}\"."
)
......
......@@ -163,7 +163,11 @@ def save_model(net,
"""
if paddle.distributed.get_rank() != 0:
return
model_path = os.path.join(model_path, model_name)
if prefix == 'best_model':
uapi_best_model_path = os.path.join(model_path, 'best_model')
_mkdir_if_not_exist(uapi_best_model_path)
_mkdir_if_not_exist(model_path)
model_path = os.path.join(model_path, prefix)
......@@ -182,6 +186,11 @@ def save_model(net,
paddle.save(s_params, model_path + "_student.pdparams")
paddle.save(params_state_dict, model_path + ".pdparams")
if prefix == 'best_model':
uapi_best_model_path = os.path.join(uapi_best_model_path, 'model')
paddle.save(params_state_dict, uapi_best_model_path + ".pdparams")
if ema is not None:
paddle.save(ema.state_dict(), model_path + ".ema.pdparams")
paddle.save([opt.state_dict() for opt in optimizer], model_path + ".pdopt")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册