未验证 提交 663b5083 编写于 作者: Z zhouzj 提交者: GitHub

[ACT] fix mkdirs on distributed training. (#1666)

* fix bugs

* fix dead links

* fix bug
上级 e33dc481
...@@ -255,7 +255,7 @@ ac.compress() ...@@ -255,7 +255,7 @@ ac.compress()
## 进阶使用 ## 进阶使用
- ACT可以自动处理常见的预测模型,如果有更特殊的改造需求,可以参考[ACT超参配置教程](./hyperparameter_tutorial.md)来进行单独配置压缩策略。 - ACT可以自动处理常见的预测模型,如果有更特殊的改造需求,可以参考[ACT超参配置教程](./hyperparameter_tutorial.md)来进行单独配置压缩策略。
- ACT接口各个参数详细含义可以参考 [ACT API文档](../docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst) - ACT接口各个参数详细含义可以参考 [ACT API文档](../../docs/zh_cn/api_cn/static/auto-compression/auto_compression_api.rst)
## 社区交流 ## 社区交流
......
...@@ -127,7 +127,7 @@ class AutoCompression: ...@@ -127,7 +127,7 @@ class AutoCompression:
self.final_dir = save_dir self.final_dir = save_dir
if not os.path.exists(self.final_dir): if not os.path.exists(self.final_dir):
os.makedirs(self.final_dir) os.makedirs(self.final_dir, exist_ok=True)
# load config # load config
if isinstance(config, str): if isinstance(config, str):
...@@ -263,7 +263,7 @@ class AutoCompression: ...@@ -263,7 +263,7 @@ class AutoCompression:
op.desc.infer_shape(block.desc) op.desc.infer_shape(block.desc)
save_path = os.path.join(save_path, "infered_shape") save_path = os.path.join(save_path, "infered_shape")
os.makedirs(save_path) os.makedirs(save_path, exist_ok=True)
paddle.static.save_inference_model( paddle.static.save_inference_model(
save_path, save_path,
feed_vars, feed_vars,
...@@ -763,8 +763,13 @@ class AutoCompression: ...@@ -763,8 +763,13 @@ class AutoCompression:
inference_program, feed_target_names, fetch_targets, patterns, inference_program, feed_target_names, fetch_targets, patterns,
strategy, config, train_config) strategy, config, train_config)
if 'unstructure' in strategy: if 'unstructure' in strategy:
if isinstance(test_program_info.program,
paddle.static.CompiledProgram):
test_program_info.program._program = remove_unused_var_nodes( test_program_info.program._program = remove_unused_var_nodes(
test_program_info.program._program) test_program_info.program._program)
else:
test_program_info.program = remove_unused_var_nodes(
test_program_info.program)
test_program_info = self._start_train( test_program_info = self._start_train(
train_program_info, test_program_info, strategy, train_config) train_program_info, test_program_info, strategy, train_config)
if paddle.distributed.get_rank() == 0: if paddle.distributed.get_rank() == 0:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册