From 4f58a37d5e01b5c05b119154a6d618c9b9b7d263 Mon Sep 17 00:00:00 2001
From: SunGaofeng
Date: Sun, 13 Oct 2019 10:28:26 +0000
Subject: [PATCH] fix load pretrain and load params due to the shape check in
 fluid.io

---
 .../models/nonlocal_model/nonlocal_model.py |  22 +-
 .../models/nonlocal_model/nonlocal_utils.py | 219 +++++++++++++++---
 PaddleCV/PaddleVideo/models/stnet/stnet.py  |  78 ++++++-
 3 files changed, 254 insertions(+), 65 deletions(-)

diff --git a/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_model.py b/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_model.py
index aa75c35d..ab816b9b 100644
--- a/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_model.py
+++ b/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_model.py
@@ -18,7 +18,7 @@ import paddle.fluid as fluid

 from ..model import ModelBase
 from . import resnet_video
-from .nonlocal_utils import load_params_from_file
+from .nonlocal_utils import load_pretrain_params_from_file, load_weights_params_from_file

 import logging
 logger = logging.getLogger(__name__)
@@ -40,7 +40,9 @@ class NonLocal(ModelBase):
         self.crop_size = self.get_config_from_sec(self.mode, 'crop_size')

     def build_input(self, use_dataloader=True):
-        input_shape = [None, 3, self.video_length, self.crop_size, self.crop_size]
+        input_shape = [
+            None, 3, self.video_length, self.crop_size, self.crop_size
+        ]
         label_shape = [None, 1]

         data = fluid.data(
@@ -59,7 +61,7 @@ class NonLocal(ModelBase):
             assert self.mode != 'infer', \
                 'dataloader is not recommendated when infer, please set use_dataloader to be false.'
             self.dataloader = fluid.io.DataLoader.from_generator(
-                feed_list=[data, label], capacity=4, iterable=True)
+                feed_list=[data, label], capacity=4, iterable=True)

         self.feature_input = [data]
         self.label_input = label
@@ -140,20 +142,10 @@ class NonLocal(ModelBase):
         )

     def load_pretrain_params(self, exe, pretrain, prog, place):
-        load_params_from_file(exe, prog, pretrain, place)
+        load_pretrain_params_from_file(exe, prog, pretrain, place)

     def load_test_weights(self, exe, weights, prog, place):
-        super(NonLocal, self).load_test_weights(exe, weights, prog, place)
-        pred_w = fluid.global_scope().find_var('pred_w').get_tensor()
-        pred_array = np.array(pred_w)
-        pred_w_shape = pred_array.shape
-        if len(pred_w_shape) == 2:
-            logger.info('reshape for pred_w when test')
-            pred_array = np.transpose(pred_array, (1, 0))
-            pred_w_shape = pred_array.shape
-            pred_array = np.reshape(
-                pred_array, [pred_w_shape[0], pred_w_shape[1], 1, 1, 1])
-            pred_w.set(pred_array.astype('float32'), place)
+        load_weights_params_from_file(exe, prog, weights, place)


 def get_learning_rate_decay_list(base_learning_rate, lr_decay, step_lists):
diff --git a/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_utils.py b/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_utils.py
index 2b6db083..f8deb380 100644
--- a/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_utils.py
+++ b/PaddleCV/PaddleVideo/models/nonlocal_model/nonlocal_utils.py
@@ -19,58 +19,115 @@
 import logging
 logger = logging.getLogger(__name__)


-def load_params_from_file(exe, prog, pretrained_file, place):
-    logger.info('load params from {}'.format(pretrained_file))
+def is_parameter(var):
+    return isinstance(var, fluid.framework.Parameter)
+
+
+def load_pretrain_params_from_file(exe, prog, pretrained_file, place):
+    """
+    The pretrained_file stores ResNet50/101 parameters pretrained on ImageNet.
+    However, the conv_weights of the Nonlocal model are not the same as those in ResNet50/101,
+    because the input shape is [N, C, T, H, W] and the convolution kernels have shape
+    [Cout, Cin, Kt, Kh, Kw], while the convolution kernels of ResNet typically have shape
+    [Cout, Cin, Kh, Kw]. When loading conv_weights from the pretrained file, a shape mismatch
+    error is raised by the check in fluid.io. This check on params' shape was newly added in
+    fluid version 1.6.0, so conv_weights must be treated specially.
+    The process is as follows:
+    1. Collect the params to be loaded, i.e. those with the same name in the target program
+       and in pretrained_file. These params are called common params in this function.
+    2. Create persistable variables in new_scope with the names of the common params. If a
+       param is a convolution weight, the created variable's shape is set to the
+       2D-convolution-kernel form.
+    3. Load params from pretrained_file into the persistable variables created in new_scope.
+    4. Get the value of each common param in new_scope and transform it if it is a conv weight.
+    5. Set the value of the corresponding param in the target program.
+    """
+
+    logger.info('load pretrained params from {}'.format(pretrained_file))
     if os.path.isdir(pretrained_file):
-        param_list = prog.block(0).all_parameters()
+        # get params' list in prog
+        param_list = filter(is_parameter, prog.list_vars())
         param_name_list = [p.name for p in param_list]
-        param_shape = {}
-        for name in param_name_list:
-            param_tensor = fluid.global_scope().find_var(name).get_tensor()
-            param_shape[name] = np.array(param_tensor).shape

+        # get all params' names in pretrained_file
         param_name_from_file = os.listdir(pretrained_file)
+        # get common params of prog and pretrained_file;
+        # only these common params will be loaded from pretrained_file into prog
         common_names = get_common_names(param_name_list, param_name_from_file)

-        logger.info('-------- loading params -----------')
+        # get global scope and block for prog
+        global_scope = fluid.global_scope()
+        global_block = prog.global_block()

-        # load params from file
-        def is_parameter(var):
-            if isinstance(var, fluid.framework.Parameter):
-                return isinstance(var, fluid.framework.Parameter) and \
-                    os.path.exists(os.path.join(pretrained_file, var.name))
+        # save details of common params
+        common_var_map = {}
+        for name in common_names:
+            var = global_block.var(name)
+            var_type = var.type
+            var_dtype = var.dtype
+            var_shape = var.shape
+            if len(var_shape) == 5:
+                # When a param is a conv weight, its shape is [Cout, Cin, Kt, Kh, Kw].
+                # The corresponding param in ResNet50/101 has shape [Cout, Cin, Kh, Kw].
+                var_shape2d = (var_shape[0], var_shape[1], var_shape[3],
+                               var_shape[4])
+            else:
+                var_shape2d = var_shape[:]
+            common_var_map[name] = [var_type, var_dtype, var_shape, var_shape2d]

-        logger.info("Load pretrain weights from file {}".format(
-            pretrained_file))
-        vars = filter(is_parameter, prog.list_vars())
-        fluid.io.load_vars(exe, pretrained_file, vars=vars, main_program=prog)
+        # create new_scope and new_prog to create vars
+        cpu_place = fluid.CPUPlace()
+        exe_cpu = fluid.Executor(cpu_place)
+        new_scope = fluid.Scope()
+        new_prog = fluid.Program()
+        new_start_prog = fluid.Program()
+        new_block = new_prog.global_block()

-        # reset params if necessary
+        # create vars in new_scope
+        created_vars = []
+        with fluid.scope_guard(new_scope):
+            with fluid.program_guard(new_prog, new_start_prog):
+                for name in common_names:
+                    var_type, var_dtype, var_shape, var_shape2d = common_var_map[
+                        name]
+                    new_var = new_block.create_var(
+                        name=name,
+                        type=var_type,
+                        shape=var_shape2d,
+                        dtype=var_dtype,
+                        persistable=True)
+                    created_vars.append(new_var)
+
+        # load pretrained_file into the persistable vars created in new_scope
+        with fluid.scope_guard(new_scope):
+            fluid.io.load_vars(
+                exe_cpu,
+                pretrained_file,
+                main_program=new_prog,
+                vars=created_vars)
+
+        logger.info('-------- loading params -----------')
         for name in common_names:
-            t = fluid.global_scope().find_var(name).get_tensor()
-            t_array = np.array(t)
-            origin_shape = param_shape[name]
-            if t_array.shape == origin_shape:
-                logger.info("load param {}".format(name))
-            elif (t_array.shape[:2] == origin_shape[:2]) and (
-                    t_array.shape[-2:] == origin_shape[-2:]):
-                num_inflate = origin_shape[2]
-                stack_t_array = np.stack(
-                    [t_array] * num_inflate, axis=2) / float(num_inflate)
-                assert origin_shape == stack_t_array.shape, "inflated shape should be the same with tensor {}".format(
-                    name)
-                t.set(stack_t_array.astype('float32'), place)
+            # get the tensor of each var in new_scope
+            new_tensor = new_scope.var(name).get_tensor()
+            new_value = np.array(new_tensor)
+
+            prog_tensor = global_scope.var(name).get_tensor()
+            var_type, var_dtype, var_shape, var_shape2d = common_var_map[name]
+            # set the value of the loaded vars to the params with the same
+            # name in the target program
+            if len(var_shape) == 5:
+                # inflate the loaded 2D conv weights into the [Cout, Cin, Kt, Kh, Kw] format
+                num_inflate = var_shape[2]
+                stacked_array = np.stack(
+                    [new_value] * num_inflate, axis=2) / float(num_inflate)
+                prog_tensor.set(stacked_array.astype('float32'), place)
+                logger.info("load inflated({}) param {}".format(num_inflate,
+                                                                name))
             else:
-                logger.info("Invalid case for name: {}".format(name))
-                raise
-        logger.info("finished loading params from resnet pretrained model")
+                prog_tensor.set(new_value, place)
+                logger.info("load param {}".format(name))
     else:
-        logger.info(
-            "pretrained file is not in a directory, not suitable to load params".
-            format(pretrained_file))
-        pass
+        raise TypeError(
+            "pretrained file {} is not a directory, params can not be loaded".
+            format(pretrained_file))


 def get_common_names(param_name_list, param_name_from_file):
@@ -96,3 +153,89 @@
             file_only_names.append(name)
             logger.info(name)
     return common_names
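As a side note, the inflation performed above can be sketched in a few lines of standalone numpy. The 7x7 kernel shape and Kt=5 below are hypothetical values chosen only for illustration:

import numpy as np

# hypothetical 2D kernel from the ImageNet-pretrained ResNet file: [Cout, Cin, Kh, Kw]
w2d = np.random.rand(64, 3, 7, 7).astype('float32')
num_inflate = 5  # Kt, the temporal kernel size expected by the Nonlocal model (assumed)

# replicate the 2D kernel Kt times along a new time axis and divide by Kt,
# so a constant-in-time input produces the same response as the 2D filter
w3d = np.stack([w2d] * num_inflate, axis=2) / float(num_inflate)
assert w3d.shape == (64, 3, 5, 7, 7)  # [Cout, Cin, Kt, Kh, Kw]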
+
+
+def load_weights_params_from_file(exe, prog, weights, place):
+    """
+    The params of the training process are stored in the file named weights.
+    However, the networks of the training and test processes are slightly different: the layer
+    named "pred" is fc in training but convolution in test. When loading the weights of pred
+    (pred_w) from the saved file, a shape mismatch error is raised by the check in fluid.io.
+    This check on params' shape was newly added in fluid version 1.6.0, so pred_w must be
+    treated specially.
+    The process is as follows:
+    1. Get the details of the param_list in the target program (prog).
+    2. Create persistable vars in new_scope with the same names as those in param_list, using
+       the details stored in step 1. If the name is 'pred_w', the var shape should be [Cin, Cout].
+    3. Get the value of the vars in new_scope. If var.name is 'pred_w', transform it from the
+       fc-weights form to be consistent with convolution.
+    4. Set the values of the corresponding params in prog.
+    """
+
+    logger.info('Load test weights from {}'.format(weights))
+
+    # get the param_list of prog
+    prog_vars = filter(is_parameter, prog.list_vars())
+
+    # save the details of the params in prog
+    var_map = {}
+    for var in prog_vars:
+        var_name = var.name
+        var_type = var.type
+        var_dtype = var.dtype
+        var_shape = var.shape
+        # for pred_w, get the fc-weights shape
+        if var_name == "pred_w":
+            assert len(
+                var_shape
+            ) == 5, "pred_w.shape should be [Cout, Cin, 1, 1, 1] in test mode"
+            var_shape = (var_shape[1], var_shape[0])
+        var_map[var_name] = [var_type, var_dtype, var_shape]
+
+    # create new_scope and new_prog
+    cpu_place = fluid.CPUPlace()
+    exe_cpu = fluid.Executor(cpu_place)
+    new_scope = fluid.Scope()
+    new_prog = fluid.Program()
+    new_start_prog = fluid.Program()
+    new_block = new_prog.global_block()
+    created_vars = []
+    # create persistable variables in new_scope
+    with fluid.scope_guard(new_scope):
+        with fluid.program_guard(new_prog, new_start_prog):
+            for var_name in var_map.keys():
+                var_type, var_dtype, var_shape = var_map[var_name]
+                new_var = new_block.create_var(
+                    name=var_name,
+                    type=var_type,
+                    shape=var_shape,
+                    dtype=var_dtype,
+                    persistable=True)
+                created_vars.append(new_var)
+
+    # load params from the weights file into the vars created in new_scope
+    with fluid.scope_guard(new_scope):
+        fluid.io.load_vars(
+            exe_cpu,
+            '',
+            main_program=new_prog,
+            vars=created_vars,
+            filename=weights)
+
+    # get the global scope of prog
+    global_scope = fluid.global_scope()
+
+    # set the values of the vars in new_scope to the params of prog with the
+    # same name, treating "pred_w" specially
+    for var_name in var_map.keys():
+        global_tensor = global_scope.var(var_name).get_tensor()
+        new_tensor = new_scope.var(var_name).get_tensor()
+        new_value = np.array(new_tensor)
+        if var_name != "pred_w":
+            global_tensor.set(new_value, place)
+        else:
+            # transform the fc weights [Cin, Cout] into the conv kernel [Cout, Cin, 1, 1, 1]
+            pred_array = np.transpose(new_value, (1, 0))
+            pred_array = np.reshape(
+                pred_array,
+                [pred_array.shape[0], pred_array.shape[1], 1, 1, 1])
+            global_tensor.set(pred_array.astype('float32'), place)
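The pred_w transform implemented above amounts to the following standalone numpy sketch. The [2048, 400] fc shape is a hypothetical example, not taken from the patch:

import numpy as np

# hypothetical fc weights saved during training: [Cin, Cout] = [2048, 400]
pred_w_fc = np.random.rand(2048, 400).astype('float32')

# transpose to [Cout, Cin], then add singleton kernel dims so the weights
# fit the 1x1x1 convolution used by the test network
pred_w_conv = np.reshape(np.transpose(pred_w_fc, (1, 0)), [400, 2048, 1, 1, 1])
assert pred_w_conv.shape == (400, 2048, 1, 1, 1)  # [Cout, Cin, Kt, Kh, Kw]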
diff --git a/PaddleCV/PaddleVideo/models/stnet/stnet.py b/PaddleCV/PaddleVideo/models/stnet/stnet.py
index 343862a6..39845237 100644
--- a/PaddleCV/PaddleVideo/models/stnet/stnet.py
+++ b/PaddleCV/PaddleVideo/models/stnet/stnet.py
@@ -57,8 +57,7 @@ class STNET(ModelBase):
         image_shape = [None, self.seg_num] + image_shape
         self.use_dataloader = use_dataloader

-        image = fluid.data(
-            name='image', shape=image_shape, dtype='float32')
+        image = fluid.data(name='image', shape=image_shape, dtype='float32')
         if self.mode != 'infer':
             label = fluid.data(name='label', shape=[None, 1], dtype='int64')
         else:
@@ -68,7 +67,7 @@ class STNET(ModelBase):
             assert self.mode != 'infer', \
                 'dataloader is not recommendated when infer, please set use_dataloader to be false.'
             self.dataloader = fluid.io.DataLoader.from_generator(
-                feed_list=[image, label], capacity=4, iterable=True)
+                feed_list=[image, label], capacity=4, iterable=True)

         self.feature_input = [image]
         self.label_input = label
@@ -149,21 +148,76 @@ class STNET(ModelBase):
         )

     def load_pretrain_params(self, exe, pretrain, prog, place):
+        """
+        The pretrained params are from ResNet50 pretrained on ImageNet.
+        However, the conv1_weights of StNet are not the same as those in ResNet50, because the
+        input is a super-image concatenated from a series of images. When loading conv1_weights
+        from the pretrained file, a shape mismatch error is raised by the check in fluid.io.
+        This check on params' shape was newly added in fluid version 1.6.0, so conv1_weights
+        must be treated specially.
+        The process is as follows:
+        1. Load the params except conv1_weights from pretrain.
+        2. Create a var named 'conv1_weights' in new_scope, and load its value from the
+           pretrain file.
+        3. Get the value of conv1_weights in new_scope and transform it.
+        4. Set the transformed value to conv1_weights in prog.
+        """
+
         def is_parameter(var):
             if isinstance(var, fluid.framework.Parameter):
                 return isinstance(var, fluid.framework.Parameter) and (not ("fc_0" in var.name)) \
-                    and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) and (not ("conv3d" in var.name))
+                    and (not ("batch_norm" in var.name)) and (not ("xception" in var.name)) \
+                    and (not ("conv3d" in var.name)) and (not ("conv1_weights" in var.name))

         logger.info(
-            "Load pretrain weights from {}, exclude fc, batch_norm, xception, conv3d layers.".
+            "Load pretrain weights from {}, exclude conv1, fc, batch_norm, xception, conv3d layers.".
             format(pretrain))
-        vars = filter(is_parameter, prog.list_vars())
-        fluid.io.load_vars(exe, pretrain, vars=vars, main_program=prog)
-        param_tensor = fluid.global_scope().find_var(
-            "conv1_weights").get_tensor()
-        param_numpy = np.array(param_tensor)
-        param_numpy = np.mean(param_numpy, axis=1, keepdims=True) / self.seglen
+        # load params from the pretrained file, excluding conv1, fc, batch_norm, xception, conv3d
+        prog_vars = filter(is_parameter, prog.list_vars())
+        fluid.io.load_vars(exe, pretrain, vars=prog_vars, main_program=prog)
+
+        # get the global scope and conv1_weights' details
+        global_scope = fluid.global_scope()
+        global_block = prog.global_block()
+        conv1_weights_name = "conv1_weights"
+        var_conv1_weights = global_block.var(conv1_weights_name)
+        tensor_conv1_weights = global_scope.var(conv1_weights_name).get_tensor()
+
+        var_type = var_conv1_weights.type
+        var_dtype = var_conv1_weights.dtype
+        var_shape = var_conv1_weights.shape
+        assert var_shape[
+            1] == 3 * self.seglen, "conv1_weights.shape[1] should be 3 x seglen({})".format(
+                self.seglen)
+        # transform the shape to be consistent with conv1_weights of ResNet50
+        var_shape = (var_shape[0], 3, var_shape[2], var_shape[3])
+
+        # create new_scope and new_prog to create a var with the transformed shape
+        cpu_place = fluid.CPUPlace()
+        exe_cpu = fluid.Executor(cpu_place)
+        new_scope = fluid.Scope()
+        new_prog = fluid.Program()
+        new_start_prog = fluid.Program()
+        new_block = new_prog.global_block()
+        with fluid.scope_guard(new_scope):
+            with fluid.program_guard(new_prog, new_start_prog):
+                new_var = new_block.create_var(
+                    name=conv1_weights_name,
+                    type=var_type,
+                    shape=var_shape,
+                    dtype=var_dtype,
+                    persistable=True)
+
+        # load conv1_weights from the pretrain file into the var created in new_scope
+        with fluid.scope_guard(new_scope):
+            fluid.io.load_vars(
+                exe_cpu, pretrain, main_program=new_prog, vars=[new_var])
+
+        # get the value of the loaded conv1_weights and transform it
+        new_tensor = new_scope.var(conv1_weights_name).get_tensor()
+        new_value = np.array(new_tensor)
+        param_numpy = np.mean(new_value, axis=1, keepdims=True) / self.seglen
         param_numpy = np.repeat(param_numpy, 3 * self.seglen, axis=1)
-        param_tensor.set(param_numpy.astype(np.float32), place)
+        # set the value of conv1_weights in the original program
+        tensor_conv1_weights.set(param_numpy.astype(np.float32), place)
+        # All the expected pretrained params are set to prog now
--
GitLab
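For reference, the conv1_weights transform in load_pretrain_params above boils down to this standalone numpy sketch. seglen=5 and the 64x3x7x7 kernel shape are assumptions for illustration only:

import numpy as np

seglen = 5  # hypothetical number of frames concatenated into one super-image
# hypothetical conv1 weights from ImageNet-pretrained ResNet50: [Cout, 3, Kh, Kw]
w = np.random.rand(64, 3, 7, 7).astype('float32')

# average over the RGB axis and scale by 1/seglen, then tile across the
# 3*seglen input channels of the super-image
w_mean = np.mean(w, axis=1, keepdims=True) / seglen
w_super = np.repeat(w_mean, 3 * seglen, axis=1)
assert w_super.shape == (64, 3 * seglen, 7, 7)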