未验证 提交 a13fe40b 编写于 作者: L lijianshe02 提交者: GitHub

fix some prune API related bugs test=develop (#232)

上级 7160d704
......@@ -17,7 +17,7 @@
import logging
import numpy as np
from ..common import get_logger
from ..core import Registry
from ..core import Registry, GraphWrapper
__all__ = ["l1_norm", "CRITERION"]
......@@ -61,13 +61,16 @@ def geometry_median(group, graph):
channel_num = value.shape[0]
w.shape = value.shape[0], np.product(value.shape[1:])
x = w.repeat(channel_num, axis=0)
y = np.tile(channel_num, (channel_num, 1))
y = np.zeros_like(x)
for i in range(channel_num):
y[i * channel_num:(i + 1) * channel_num] = np.tile(channel_num,
(channel_num, 1))
tmp = np.sqrt(np.sum((x - y)**2, -1))
tmp = tmp.reshape((channel_num, channel_num))
tmp = np.sum(tmp, -1)
for name, value, axis in group:
scores.append(name, axis, tmp)
scores.append((name, axis, tmp))
return scores
......
......@@ -96,7 +96,7 @@ def optimal_threshold(group, ratio):
name, axis, score = group[
0] # sort channels by the first convolution's score
score[scoew < 1e-18] = 1e-18
score[score < 1e-18] = 1e-18
score_sorted = np.sort(score)
score_square = score_sorted**2
total_sum = score_square.sum()
......
......@@ -49,7 +49,8 @@ class TestPrune(unittest.TestCase):
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
criterion = 'optimal_threshold'
criterion = 'bn_scale'
idx_selector = 'optimal_threshold'
pruner = Pruner(criterion)
main_program, _, _ = pruner.prune(
main_program,
......
......@@ -15,7 +15,7 @@ import sys
sys.path.append("../")
import unittest
import paddle.fluid as fluid
from paddleslim.prune.walk_pruner import Pruner
from paddleslim.prune import Pruner
from layers import conv_bn_layer
......@@ -72,7 +72,8 @@ class TestPrune(unittest.TestCase):
for param in main_program.global_block().all_parameters():
if "weights" in param.name:
print("param: {}; param shape: {}".format(param.name, param.shape))
print("param: {}; param shape: {}".format(param.name,
param.shape))
self.assertTrue(param.shape == shapes[param.name])
......
......@@ -49,7 +49,7 @@ class TestPrune(unittest.TestCase):
exe = fluid.Executor(place)
scope = fluid.Scope()
exe.run(startup_program, scope=scope)
criterion = 'batch_norm_scale'
criterion = 'bn_scale'
pruner = Pruner(criterion)
main_program, _, _ = pruner.prune(
main_program,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册