Created by: chenwhql
PR types
Function optimization
PR changes
APIs
Describe
Refine jit.save implement to adapt InputSpec using cases
使用场景梳理:
- 不对模型进行剪枝
- Layer.forward装饰了to_static
- 训练后存储部署模型
- jit.save(layer, model_path),input_spec可以不指定
- 不训练直接存储部署模型
- InputSpec必须在to_static时指定,jit.save(layer, model_path),input_spec可以不指定
- 训练后存储部署模型
- Layer.forward没装饰to_static
- 训练后存储部署模型
- jit.save(layer, model_path, input_spec=[InputSpec/example_input]): input_spec必须指定
- 不训练直接存储部署模型
- jit.save(layer, model_path, input_spec=[InputSpec/example_input]): input_spec必须指定
- 训练后存储部署模型
- Layer.forward装饰了to_static
- 对模型进行剪枝:
- Layer.forward装饰了to_static,要求to_static传入完整的input_spec
- 训练后存储部署模型
- jit.save(layer, model_path, input_spec=[InputSpec/example_input]): input_spec必须指定
- 不训练直接存储部署模型,裁剪输入需要依赖于concrete_program.outputs
- jit.save(layer, model_path, input_spec=[InputSpec/example_input]): input_spec必须指定
- 训练后存储部署模型
- Layer.forward没装饰to_static
- 不支持剪枝,由于input_spec不全,会在to_static转换的时候报错
- Layer.forward装饰了to_static,要求to_static传入完整的input_spec
TODO:
- 可能需要支持jit.save传入list[var_name]作为input_spec