From 343890b2af1eeb520d3f9354be9b1a592383fd58 Mon Sep 17 00:00:00 2001 From: lijianshe02 <48898730+lijianshe02@users.noreply.github.com> Date: Fri, 15 May 2020 18:27:49 +0800 Subject: [PATCH] fix fpgm pruning memory bug runing on Windows test=develop (#285) --- paddleslim/prune/criterion.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/paddleslim/prune/criterion.py b/paddleslim/prune/criterion.py index e224930f..6a8fb7d9 100644 --- a/paddleslim/prune/criterion.py +++ b/paddleslim/prune/criterion.py @@ -57,17 +57,21 @@ def geometry_median(group, graph): scores = [] name, value, axis = group[0] assert (len(value.shape) == 4) - w = value.view() - channel_num = value.shape[0] - w.shape = value.shape[0], np.product(value.shape[1:]) - x = w.repeat(channel_num, axis=0) - y = np.zeros_like(x) - for i in range(channel_num): - y[i * channel_num:(i + 1) * channel_num] = np.tile(w[i], - (channel_num, 1)) - tmp = np.sqrt(np.sum((x - y)**2, -1)) - tmp = tmp.reshape((channel_num, channel_num)) - tmp = np.sum(tmp, -1) + + def get_distance_sum(value, out_idx): + w = value.view() + w.shape = value.shape[0], np.product(value.shape[1:]) + selected_filter = np.tile(w[out_idx], (w.shape[0], 1)) + x = w - selected_filter + x = np.sqrt(np.sum(x * x, -1)) + return x.sum() + + dist_sum_list = [] + for out_i in range(value.shape[0]): + dist_sum = get_distance_sum(value, out_i) + dist_sum_list.append(dist_sum) + + tmp = np.array(dist_sum_list) for name, value, axis in group: scores.append((name, axis, tmp)) -- GitLab