未验证 提交 283c8916 编写于 作者: L LielinJiang 提交者: GitHub

add export model doc (#472)

上级 ee9fae9d
# PaddleGAN模型导出教程
## 一、模型导出
本章节介绍如何使用`tools/export_model.py`脚本导出模型。
### 1、启动参数说明
| FLAG | 用途 | 默认值 | 备注 |
|:--------------:|:--------------:|:------------:|:-----------------------------------------:|
| -c | 指定配置文件 | None | |
| --load | 指定加载的模型参数路径 | None | |
| -s|--inputs_size | 指定模型输入形状 | None | |
| --output_dir | 模型保存路径 | `./inference_model` | |
### 2、使用示例
使用训练得到的模型进行试用,这里使用CycleGAN模型为例,脚本如下
```bash
# 下载预训练好的CycleGAN_horse2zebra模型
wget https://paddlegan.bj.bcebos.com/models/CycleGAN_horse2zebra.pdparams
# 导出Cylclegan模型
python -u tools/export_model.py -c configs/cyclegan_horse2zebra.yaml --load CycleGAN_horse2zebra.pdparams --inputs_size="-1,3,-1,-1;-1,3,-1,-1"
```
### 3、config配置说明
```python
export_model:
- {name: 'netG_A', inputs_num: 1}
- {name: 'netG_B', inputs_num: 1}
```
以上为```configs/cyclegan_horse2zebra.yaml```中的配置, 由于```CycleGAN_horse2zebra.pdparams```是个字典,需要制定其中用于导出模型的权重键值。```inputs_num```
为该网络的输入个数。
预测模型会导出到`inference_model/`目录下,分别为`cycleganmodel_netG_A.pdiparams`, `cycleganmodel_netG_A.pdiparams.info`, `cycleganmodel_netG_A.pdmodel`, `cycleganmodel_netG_B.pdiparams`, `cycleganmodel_netG_B.pdiparams.info`, `cycleganmodel_netG_B.pdmodel`,。
......@@ -186,12 +186,18 @@ class BaseModel(ABC):
def export_model(self, export_model, output_dir=None, inputs_size=[]):
inputs_num = 0
for net in export_model:
input_spec = [paddle.static.InputSpec(
shape=inputs_size[inputs_num + i], dtype="float32") for i in range(net["inputs_num"])]
input_spec = [
paddle.static.InputSpec(shape=inputs_size[inputs_num + i],
dtype="float32")
for i in range(net["inputs_num"])
]
inputs_num = inputs_num + net["inputs_num"]
static_model = paddle.jit.to_static(self.nets[net["name"]],
input_spec=input_spec)
if output_dir is None:
output_dir = 'export_model'
paddle.jit.save(static_model, os.path.join(
output_dir, '{}_{}'.format(self.__class__.__name__.lower(), net["name"])))
output_dir = 'inference_model'
paddle.jit.save(
static_model,
os.path.join(
output_dir, '{}_{}'.format(self.__class__.__name__.lower(),
net["name"])))
......@@ -24,11 +24,6 @@ from ppgan.engine.trainer import Trainer
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--export_model",
default=None,
type=str,
help="The path prefix of inference model to be used.", )
parser.add_argument('-c',
'--config-file',
metavar="FILE",
......@@ -50,6 +45,12 @@ def parse_args():
default=None,
required=True,
help="the inputs size")
parser.add_argument(
"--output_dir",
default=None,
type=str,
help="The path prefix of inference model to be used.",
)
args = parser.parse_args()
return args
......@@ -63,7 +64,7 @@ def main(args, cfg):
for net_name, net in model.nets.items():
if net_name in state_dicts:
net.set_state_dict(state_dicts[net_name])
model.export_model(cfg.export_model, args.export_model, inputs_size)
model.export_model(cfg.export_model, args.output_dir, inputs_size)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册