未验证 提交 7790b333 编写于 作者: W whs 提交者: GitHub

Add code demo for pruning api. (#54)

上级 66497ea6
...@@ -307,6 +307,28 @@ paddleslim.prune.merge_sensitive(sensitivities)[源代码](https://github.com/Pa ...@@ -307,6 +307,28 @@ paddleslim.prune.merge_sensitive(sensitivities)[源代码](https://github.com/Pa
示例: 示例:
```
from paddleslim.prune import merge_sensitive
sen0 = {"weight_0":
{0.1: 0.22,
0.2: 0.33
},
"weight_1":
{0.1: 0.21,
0.2: 0.4
}
}
sen1 = {"weight_0":
{0.3: 0.41,
},
"weight_2":
{0.1: 0.10,
0.2: 0.35
}
}
sensitivities = merge_sensitive([sen0, sen1])
print(sensitivities)
```
## load_sensitivities ## load_sensitivities
paddleslim.prune.load_sensitivities(sensitivities_file)[源代码](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L184) paddleslim.prune.load_sensitivities(sensitivities_file)[源代码](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L184)
...@@ -323,6 +345,24 @@ paddleslim.prune.load_sensitivities(sensitivities_file)[源代码](https://githu ...@@ -323,6 +345,24 @@ paddleslim.prune.load_sensitivities(sensitivities_file)[源代码](https://githu
示例: 示例:
```
import pickle
from paddleslim.prune import load_sensitivities
sen = {"weight_0":
{0.1: 0.22,
0.2: 0.33
},
"weight_1":
{0.1: 0.21,
0.2: 0.4
}
}
sensitivities_file = "sensitive_api_demo.data"
with open(sensitivities_file, 'w') as f:
pickle.dump(sen, f)
sensitivities = load_sensitivities(sensitivities_file)
print(sensitivities)
```
## get_ratios_by_loss ## get_ratios_by_loss
paddleslim.prune.get_ratios_by_loss(sensitivities, loss)[源代码](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L206) paddleslim.prune.get_ratios_by_loss(sensitivities, loss)[源代码](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L206)
...@@ -338,3 +378,22 @@ paddleslim.prune.get_ratios_by_loss(sensitivities, loss)[源代码](https://gith ...@@ -338,3 +378,22 @@ paddleslim.prune.get_ratios_by_loss(sensitivities, loss)[源代码](https://gith
返回: 返回:
- **ratios(dict)** - 一组剪切率。`key`是待剪裁参数的名称。`value`是对应参数的剪裁率。 - **ratios(dict)** - 一组剪切率。`key`是待剪裁参数的名称。`value`是对应参数的剪裁率。
示例:
```
from paddleslim.prune import get_ratios_by_loss
sen = {"weight_0":
{0.1: 0.22,
0.2: 0.33
},
"weight_1":
{0.1: 0.21,
0.2: 0.4
}
}
ratios = get_ratios_by_loss(sen, 0.3)
print(ratios)
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册