diff --git a/paddleslim/prune/criterion.py b/paddleslim/prune/criterion.py index 5f1f78c1baaa4740036d7bf1780ee29dbc95789a..e224930fb1d9c1a463ef769c7b30f1a9e5e5a0a7 100644 --- a/paddleslim/prune/criterion.py +++ b/paddleslim/prune/criterion.py @@ -63,7 +63,7 @@ def geometry_median(group, graph): x = w.repeat(channel_num, axis=0) y = np.zeros_like(x) for i in range(channel_num): - y[i * channel_num:(i + 1) * channel_num] = np.tile(channel_num, + y[i * channel_num:(i + 1) * channel_num] = np.tile(w[i], (channel_num, 1)) tmp = np.sqrt(np.sum((x - y)**2, -1)) tmp = tmp.reshape((channel_num, channel_num))