Unverified commit 285f3a71, authored by zhouzj, committed by GitHub

[auto-compression] solve prediction problems and fix docs. (#1142)

* solve prediction problems.

* fix docs.
Parent 2695a087
@@ -45,7 +45,8 @@ pip install paddlepaddle-gpu
 Install paddleslim:
 ```shell
-pip install paddleslim
+git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
+python setup.py install
 ```
 Install paddledet:
......
@@ -43,7 +43,8 @@ pip install paddlepaddle-gpu
 Install paddleslim:
 ```shell
-pip install paddleslim
+git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
+python setup.py install
 ```
 #### 3.2 Prepare the dataset
......
@@ -56,7 +56,8 @@ pip install paddlepaddle-gpu
 Install paddleslim:
 ```shell
-pip install paddleslim
+git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
+python setup.py install
 ```
 Install paddlenlp:
......
@@ -48,7 +48,8 @@ pip install paddlepaddle-gpu
 Install paddleslim:
 ```shell
-pip install paddleslim
+git clone https://github.com/PaddlePaddle/PaddleSlim.git && cd PaddleSlim
+python setup.py install
 ```
 Install paddleseg:
......
@@ -27,9 +27,9 @@ def predict_compressed_model(model_dir,
         latency_dict(dict): The latency of the model under various compression strategies.
     """
     local_rank = paddle.distributed.get_rank()
-    quant_model_path = f'quant_model/rank_{local_rank}'
-    prune_model_path = f'prune_model/rank_{local_rank}'
-    sparse_model_path = f'sparse_model/rank_{local_rank}'
+    quant_model_path = f'quant_model_rank_{local_rank}_tmp'
+    prune_model_path = f'prune_model_rank_{local_rank}_tmp'
+    sparse_model_path = f'sparse_model_rank_{local_rank}_tmp'
     latency_dict = {}
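A minimal sketch of why the rename matters (the `scratch_dir` helper is illustrative, not part of the commit): each distributed worker derives its scratch directory from its own rank, so concurrent exports never collide, and the flat `*_tmp` names avoid the nested directories that the old `quant_model/rank_N` layout produced.

```python
import paddle

def scratch_dir(prefix: str) -> str:
    # paddle.distributed.get_rank() returns 0 on single-card runs,
    # and a distinct rank per process under multi-card launches.
    local_rank = paddle.distributed.get_rank()
    return f'{prefix}_rank_{local_rank}_tmp'

print(scratch_dir('quant_model'))  # e.g. 'quant_model_rank_0_tmp'
```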
@@ -116,7 +116,7 @@ def predict_compressed_model(model_dir,
         model_dir=sparse_model_path,
         model_filename=model_filename,
         params_filename=params_filename,
-        save_model_path='quant_model',
+        save_model_path=quant_model_path,
         quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
         is_full_quantize=False,
         activation_bits=8,
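This one-line change closes the gap left by the previous hunk: the quantized copy of the sparse model was still written to a hardcoded `'quant_model'` directory, escaping both the per-rank isolation and the cleanup below. A hedged sketch of the bug class, with `export_quant_model` as a hypothetical stand-in for the quantization call in the diff:

```python
quant_model_path = 'quant_model_rank_0_tmp'

def export_quant_model(save_model_path: str) -> None:
    # Stand-in for the post-training-quantization export in the diff.
    print(f'writing quantized model to {save_model_path}')

# Before: export_quant_model(save_model_path='quant_model')  # never cleaned up
export_quant_model(save_model_path=quant_model_path)  # matches the cleanup paths
```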
@@ -131,10 +131,10 @@ def predict_compressed_model(model_dir,
         latency_dict.update({f'sparse_{sparse_ratio}_int8': latency})
     # NOTE: Delete temporary model files
-    if os.path.exists('quant_model'):
-        shutil.rmtree('quant_model', ignore_errors=True)
-    if os.path.exists('prune_model'):
-        shutil.rmtree('prune_model', ignore_errors=True)
-    if os.path.exists('sparse_model'):
-        shutil.rmtree('sparse_model', ignore_errors=True)
+    if os.path.exists(quant_model_path):
+        shutil.rmtree(quant_model_path, ignore_errors=True)
+    if os.path.exists(prune_model_path):
+        shutil.rmtree(prune_model_path, ignore_errors=True)
+    if os.path.exists(sparse_model_path):
+        shutil.rmtree(sparse_model_path, ignore_errors=True)
     return latency_dict
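The cleanup now targets the per-rank paths defined earlier, so one worker can no longer delete another worker's files. As a side note, `shutil.rmtree` with `ignore_errors=True` already tolerates a missing directory, so the `os.path.exists` guards are belt-and-braces; a condensed equivalent (a sketch, not the commit's literal code):

```python
import shutil

# ignore_errors=True makes rmtree a no-op on missing directories,
# so no existence check is needed.
for path in ('quant_model_rank_0_tmp',
             'prune_model_rank_0_tmp',
             'sparse_model_rank_0_tmp'):
    shutil.rmtree(path, ignore_errors=True)
```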
@@ -122,7 +122,7 @@ def get_prune_model(model_file, param_file, ratio, save_path):
     main_prog = static.Program()
     startup_prog = static.Program()
     place = paddle.CPUPlace()
-    exe = paddle.static.Executor()
+    exe = paddle.static.Executor(place)
     scope = static.global_scope()
     exe.run(startup_prog)
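A self-contained sketch of the executor fix under static-graph mode: passing the `place` created on the previous line binds execution explicitly to the CPU instead of relying on Paddle's build-dependent default device.

```python
import paddle
from paddle import static

paddle.enable_static()  # the pruning utility operates on static graphs

place = paddle.CPUPlace()
exe = static.Executor(place)  # device is now explicit
exe.run(static.default_startup_program())
```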
......