提交 d5637367 编写于 作者: W weishengyu

add export method

上级 932e0eac
......@@ -57,18 +57,6 @@ def build_gallery_layer(configs, feature_extractor):
return gallery_layer
def export_fuse_model(model, config):
model.eval()
model.quanter.save_quantized_model(
model.base_model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + config["Global"]["image_shape"],
dtype='float32')
])
class FuseModel(paddle.nn.Layer):
def __init__(self, configs):
super().__init__()
......@@ -85,12 +73,25 @@ class FuseModel(paddle.nn.Layer):
return x
def export_fuse_model(configs):
fuse_model = FuseModel(configs)
fuse_model.eval()
save_path = configs["Global"]["save_inference_dir"]
fuse_model.quanter.save_quantized_model(
fuse_model,
save_path,
input_spec=[
paddle.static.InputSpec(
shape=[None] + configs["Global"]["image_shape"],
dtype='float32')
])
def main():
args = parse_args()
configs = parse_config(args.config)
init_logger(name='gallery2fc')
fuse_model = FuseModel(configs)
# save_fuse_model(fuse_model)
export_fuse_model(configs)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册