diff --git a/docs/docs/api/prune_api.md b/docs/docs/api/prune_api.md index 0f72ef54aee35b1097b09b3eaf934a35645e1af4..eb36e86233eb3ef4dd88f46fefe9c23c20a4ebed 100644 --- a/docs/docs/api/prune_api.md +++ b/docs/docs/api/prune_api.md @@ -307,6 +307,28 @@ paddleslim.prune.merge_sensitive(sensitivities)[源代码](https://github.com/Pa 示例: +``` +from paddleslim.prune import merge_sensitive +sen0 = {"weight_0": + {0.1: 0.22, + 0.2: 0.33 + }, + "weight_1": + {0.1: 0.21, + 0.2: 0.4 + } +} +sen1 = {"weight_0": + {0.3: 0.41, + }, + "weight_2": + {0.1: 0.10, + 0.2: 0.35 + } +} +sensitivities = merge_sensitive([sen0, sen1]) +print(sensitivities) +``` ## load_sensitivities paddleslim.prune.load_sensitivities(sensitivities_file)[源代码](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L184) @@ -323,6 +345,24 @@ paddleslim.prune.load_sensitivities(sensitivities_file)[源代码](https://githu 示例: +``` +import pickle +from paddleslim.prune import load_sensitivities +sen = {"weight_0": + {0.1: 0.22, + 0.2: 0.33 + }, + "weight_1": + {0.1: 0.21, + 0.2: 0.4 + } +} +sensitivities_file = "sensitive_api_demo.data" +with open(sensitivities_file, 'w') as f: + pickle.dump(sen, f) +sensitivities = load_sensitivities(sensitivities_file) +print(sensitivities) +``` ## get_ratios_by_loss paddleslim.prune.get_ratios_by_loss(sensitivities, loss)[源代码](https://github.com/PaddlePaddle/PaddleSlim/blob/develop/paddleslim/prune/sensitive.py#L206) @@ -338,3 +378,22 @@ paddleslim.prune.get_ratios_by_loss(sensitivities, loss)[源代码](https://gith 返回: - **ratios(dict)** - 一组剪切率。`key`是待剪裁参数的名称。`value`是对应参数的剪裁率。 + +示例: + +``` +from paddleslim.prune import get_ratios_by_loss +sen = {"weight_0": + {0.1: 0.22, + 0.2: 0.33 + }, + "weight_1": + {0.1: 0.21, + 0.2: 0.4 + } +} + +ratios = get_ratios_by_loss(sen, 0.3) +print(ratios) + +```