diff --git a/surgery.py b/surgery.py index f3a270821cfb60a8878531d54085076adb970a2f..01c37290e63a5c18a7de0b7f54494737e3e98ff9 100644 --- a/surgery.py +++ b/surgery.py @@ -3,6 +3,18 @@ import caffe import numpy as np def transplant(new_net, net, suffix=''): + """ + Transfer weights by copying matching parameters, coercing parameters of + incompatible shape, and dropping unmatched parameters. + + The coercion is useful to convert fully connected layers to their + equivalent convolutional layers, since the weights are the same and only + the shapes are different. In particular, equivalent fully connected and + convolution layers have shapes O x I and O x I x H x W respectively for O + output channels, I input channels, H kernel height, and W kernel width. + + Both `net` and `new_net` arguments must be instantiated `caffe.Net`s. + """ for p in net.params: p_new = p + suffix if p_new not in new_net.params: @@ -18,12 +30,10 @@ def transplant(new_net, net, suffix=''): print 'copying', p, ' -> ', p_new, i new_net.params[p_new][i].data.flat = net.params[p][i].data.flat -def expand_score(new_net, new_layer, net, layer): - old_cl = net.params[layer][0].num - new_net.params[new_layer][0].data[:old_cl][...] = net.params[layer][0].data - new_net.params[new_layer][1].data[0,0,0,:old_cl][...] = net.params[layer][1].data - def upsample_filt(size): + """ + Make a 2D bilinear kernel suitable for upsampling of the given (h, w) size. + """ factor = (size + 1) // 2 if size % 2 == 1: center = factor - 1 @@ -34,6 +44,9 @@ def upsample_filt(size): (1 - abs(og[1] - center) / factor) def interp(net, layers): + """ + Set weights of each layer in layers to bilinear kernels for interpolation. 
+ """ for l in layers: m, k, h, w = net.params[l][0].data.shape if m != k and k != 1: @@ -44,3 +57,12 @@ def interp(net, layers): raise filt = upsample_filt(h) net.params[l][0].data[range(m), range(k), :, :] = filt + +def expand_score(new_net, new_layer, net, layer): + """ + Transplant an old score layer's parameters, with k < k' classes, into a new + score layer with k classes s.t. the first k' are the old classes. + """ + old_cl = net.params[layer][0].num + new_net.params[new_layer][0].data[:old_cl][...] = net.params[layer][0].data + new_net.params[new_layer][1].data[0,0,0,:old_cl][...] = net.params[layer][1].data