diff --git a/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml b/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
index 827029b7dd24ff1a28c98db24b31b7e01ec15925..dc47e2269f9fce481eab73178eab99051c726bb4 100644
--- a/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
+++ b/configs/quick_start/MobileNetV3_large_x1_0_finetune.yaml
@@ -2,6 +2,7 @@ mode: 'train'
 ARCHITECTURE:
     name: 'MobileNetV3_large_x1_0'
 pretrained_model: "./pretrained/MobileNetV3_large_x1_0_pretrained"
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml b/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
index 93839e99cbf0631b37820740e6e32ac1fa67fb09..3e483821263c60e1fdc354c098e0c08115e0d4ba 100644
--- a/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
+++ b/configs/quick_start/R50_vd_distill_MV3_large_x1_0.yaml
@@ -5,6 +5,9 @@ ARCHITECTURE:
 pretrained_model:
     - "./pretrained/flowers102_R50_vd_final/ppcls"
     - "./pretrained/MobileNetV3_large_x1_0_pretrained/"
+load_static_weights:
+    - False
+    - True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 7169
diff --git a/configs/quick_start/ResNet50_vd.yaml b/configs/quick_start/ResNet50_vd.yaml
index 913090921da7111d1e0625016158fec8af8c8bcf..76e4d316a1e78c977ed1490c9ce03d71bda6e673 100644
--- a/configs/quick_start/ResNet50_vd.yaml
+++ b/configs/quick_start/ResNet50_vd.yaml
@@ -1,7 +1,10 @@
 mode: 'train'
 ARCHITECTURE:
     name: 'ResNet50_vd'
+
+checkpoints: ""
 pretrained_model: ""
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/ResNet50_vd_finetune.yaml b/configs/quick_start/ResNet50_vd_finetune.yaml
index 415e0e80d1d38da991c6431f406e5d39dc03bb6f..1aa6b19e54074de4045ade5bd69a392de6350c22 100644
--- a/configs/quick_start/ResNet50_vd_finetune.yaml
+++ b/configs/quick_start/ResNet50_vd_finetune.yaml
@@ -2,6 +2,7 @@ mode: 'train'
 ARCHITECTURE:
     name: 'ResNet50_vd'
 pretrained_model: "./pretrained/ResNet50_vd_pretrained"
+load_static_weights: true
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/ResNet50_vd_ssld_finetune.yaml b/configs/quick_start/ResNet50_vd_ssld_finetune.yaml
index d6bf863414a23dd9de3e3c490e5c09a7e20e01c7..511dff0054dd6db1ee64a5fb33ee5574635efb2b 100644
--- a/configs/quick_start/ResNet50_vd_ssld_finetune.yaml
+++ b/configs/quick_start/ResNet50_vd_ssld_finetune.yaml
@@ -4,6 +4,7 @@ ARCHITECTURE:
     params:
         lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
 pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml b/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml
index 629f050eda26edae9efe512980dc7faea28fc39f..0687ea3b7c741ebe21343652b57324f2fbc5ae4b 100644
--- a/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml
+++ b/configs/quick_start/ResNet50_vd_ssld_random_erasing_finetune.yaml
@@ -4,6 +4,7 @@ ARCHITECTURE:
     params:
         lr_mult_list: [0.1, 0.1, 0.2, 0.2, 0.3]
 pretrained_model: "./pretrained/ResNet50_vd_ssld_pretrained"
+load_static_weights: True
 model_save_dir: "./output/"
 classes_num: 102
 total_images: 1020
diff --git a/ppcls/modeling/architectures/__init__.py b/ppcls/modeling/architectures/__init__.py
index ffc0851752f0808d33f352da9001c41c9a682576..82d1b2a2e665e8933d058d2fb6d5346da02932f9 100644
--- a/ppcls/modeling/architectures/__init__.py
+++ b/ppcls/modeling/architectures/__init__.py
@@ -28,3 +28,5 @@ from .mobilenet_v1 import MobileNetV1_x0_25, MobileNetV1_x0_5, MobileNetV1_x0_75
 from .mobilenet_v2 import MobileNetV2_x0_25, MobileNetV2_x0_5, MobileNetV2_x0_75, MobileNetV2, MobileNetV2_x1_5, MobileNetV2_x2_0
 from .mobilenet_v3 import MobileNetV3_small_x0_35, MobileNetV3_small_x0_5, MobileNetV3_small_x0_75, MobileNetV3_small_x1_0, MobileNetV3_small_x1_25, MobileNetV3_large_x0_35, MobileNetV3_large_x0_5, MobileNetV3_large_x0_75, MobileNetV3_large_x1_0, MobileNetV3_large_x1_25
 from .shufflenet_v2 import ShuffleNetV2_x0_25, ShuffleNetV2_x0_33, ShuffleNetV2_x0_5, ShuffleNetV2, ShuffleNetV2_x1_5, ShuffleNetV2_x2_0, ShuffleNetV2_swish
+
+from .distillation_models import ResNet50_vd_distill_MobileNetV3_large_x1_0
\ No newline at end of file
diff --git a/ppcls/modeling/architectures/distillation_models.py b/ppcls/modeling/architectures/distillation_models.py
index f5f24b36a260f7d816a164dd0a8e86266550b0dc..928c1c8c7de337fe2aac60fb1e23519f455bd906 100644
--- a/ppcls/modeling/architectures/distillation_models.py
+++ b/ppcls/modeling/architectures/distillation_models.py
@@ -1,16 +1,16 @@
-#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
 #
-#Licensed under the Apache License, Version 2.0 (the "License");
-#you may not use this file except in compliance with the License.
-#You may obtain a copy of the License at
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
-#Unless required by applicable law or agreed to in writing, software
-#distributed under the License is distributed on an "AS IS" BASIS,
-#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-#See the License for the specific language governing permissions and
-#limitations under the License.
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 
 from __future__ import absolute_import
 from __future__ import division
@@ -32,27 +32,35 @@ __all__ = [
 ]
 
 
-class ResNet50_vd_distill_MobileNetV3_large_x1_0():
-    def net(self, input, class_dim=1000):
-        # student
-        student = MobileNetV3_large_x1_0()
-        out_student = student.net(input, class_dim=class_dim)
-        # teacher
-        teacher = ResNet50_vd()
-        out_teacher = teacher.net(input, class_dim=class_dim)
-        out_teacher.stop_gradient = True
+class ResNet50_vd_distill_MobileNetV3_large_x1_0(fluid.dygraph.Layer):
+    def __init__(self, class_dim=1000, **args):
+        super(ResNet50_vd_distill_MobileNetV3_large_x1_0, self).__init__()
 
-        return out_teacher, out_student
+        self.teacher = ResNet50_vd(class_dim=class_dim, **args)
+        self.student = MobileNetV3_large_x1_0(class_dim=class_dim, **args)
 
 
-class ResNeXt101_32x16d_wsl_distill_ResNet50_vd():
-    def net(self, input, class_dim=1000):
-        # student
-        student = ResNet50_vd()
-        out_student = student.net(input, class_dim=class_dim)
-        # teacher
-        teacher = ResNeXt101_32x16d_wsl()
-        out_teacher = teacher.net(input, class_dim=class_dim)
-        out_teacher.stop_gradient = True
+    def forward(self, input):
+        teacher_label = self.teacher(input)
+        teacher_label.stop_gradient = True
 
-        return out_teacher, out_student
+        student_label = self.student(input)
+
+        return teacher_label, student_label
+
+
+class ResNeXt101_32x16d_wsl_distill_ResNet50_vd(fluid.dygraph.Layer):
+    def __init__(self, class_dim=1000, **args):
+        super(ResNeXt101_32x16d_wsl_distill_ResNet50_vd, self).__init__()
+
+        self.teacher = ResNeXt101_32x16d_wsl(class_dim=class_dim, **args)
+
+        self.student = ResNet50_vd(class_dim=class_dim, **args)
+
+    def forward(self, input):
+        teacher_label = self.teacher(input)
+        teacher_label.stop_gradient = True
+
+        student_label = self.student(input)
+
+        return teacher_label, student_label
\ No newline at end of file
diff --git a/ppcls/optimizer/learning_rate.py b/ppcls/optimizer/learning_rate.py
index 0227ae43e0f5c99894dcaeb1b51425fbdc9a8c82..87c08e747b52f91f555a5787390d5d55c617278f 100644
--- a/ppcls/optimizer/learning_rate.py
+++ b/ppcls/optimizer/learning_rate.py
@@ -112,35 +112,19 @@ class CosineWarmup(object):
         self.lr = lr
         self.step_each_epoch = step_each_epoch
         self.epochs = epochs
-        self.warmup_epoch = fluid.layers.fill_constant(
-            shape=[1],
-            value=float(warmup_epoch),
-            dtype='float32',
-            force_cpu=True)
+        self.warmup_epoch = warmup_epoch
 
     def __call__(self):
-        global_step = _decay_step_counter()
-        learning_rate = fluid.layers.tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype='float32',
-            persistable=True,
-            name="learning_rate")
-        epoch = ops.floor(global_step / self.step_each_epoch)
-        with fluid.layers.control_flow.Switch() as switch:
-            with switch.case(epoch < self.warmup_epoch):
-                decayed_lr = self.lr * \
-                    (global_step / (self.step_each_epoch * self.warmup_epoch))
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
-            with switch.default():
-                current_step = global_step - self.warmup_epoch * self.step_each_epoch
-                total_step = (
-                    self.epochs - self.warmup_epoch) * self.step_each_epoch
-                decayed_lr = self.lr * \
-                    (ops.cos(current_step * math.pi / total_step) + 1) / 2
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
+        learning_rate = fluid.layers.cosine_decay(
+            learning_rate=self.lr,
+            step_each_epoch=self.step_each_epoch,
+            epochs=self.epochs)
+
+        learning_rate = fluid.layers.linear_lr_warmup(
+            learning_rate,
+            warmup_steps=self.warmup_epoch * self.step_each_epoch,
+            start_lr=0.0,
+            end_lr=self.lr)
 
         return learning_rate
@@ -169,37 +153,22 @@ class ExponentialWarmup(object):
         super(ExponentialWarmup, self).__init__()
         self.lr = lr
         self.step_each_epoch = step_each_epoch
-        self.decay_epochs = decay_epochs * self.step_each_epoch
+        self.decay_epochs = decay_epochs
         self.decay_rate = decay_rate
-        self.warmup_epoch = fluid.layers.fill_constant(
-            shape=[1],
-            value=float(warmup_epoch),
-            dtype='float32',
-            force_cpu=True)
+        self.warmup_epoch = warmup_epoch
 
     def __call__(self):
-        global_step = _decay_step_counter()
-        learning_rate = fluid.layers.tensor.create_global_var(
-            shape=[1],
-            value=0.0,
-            dtype='float32',
-            persistable=True,
-            name="learning_rate")
-
-        epoch = ops.floor(global_step / self.step_each_epoch)
-        with fluid.layers.control_flow.Switch() as switch:
-            with switch.case(epoch < self.warmup_epoch):
-                decayed_lr = self.lr * \
-                    (global_step / (self.step_each_epoch * self.warmup_epoch))
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
-            with switch.default():
-                rest_step = global_step - self.warmup_epoch * self.step_each_epoch
-                div_res = ops.floor(rest_step / self.decay_epochs)
-
-                decayed_lr = self.lr * (self.decay_rate**div_res)
-                fluid.layers.tensor.assign(
-                    input=decayed_lr, output=learning_rate)
+        learning_rate = fluid.layers.exponential_decay(
+            learning_rate=self.lr,
+            decay_steps=self.decay_epochs * self.step_each_epoch,
+            decay_rate=self.decay_rate,
+            staircase=False)
+
+        learning_rate = fluid.layers.linear_lr_warmup(
+            learning_rate,
+            warmup_steps=self.warmup_epoch * self.step_each_epoch,
+            start_lr=0.0,
+            end_lr=self.lr)
 
         return learning_rate
diff --git a/ppcls/utils/check.py b/ppcls/utils/check.py
index c8f13eb4ab97cd8d412aba9e50f9368015278e41..b09a2498c970270d0ecd6a3022a2f3e08c16c88c 100644
--- a/ppcls/utils/check.py
+++ b/ppcls/utils/check.py
@@ -31,12 +31,12 @@ def check_version():
     Log error and exit when the installed version of paddlepaddle is
     not satisfied.
     """
-    err = "PaddlePaddle version 1.7 or higher is required, " \
+    err = "PaddlePaddle version 2.0.0 or higher is required, " \
           "or a suitable develop version is satisfied as well. \n" \
          "Please make sure the version is good with your code." \
 
     try:
-        fluid.require_version('1.7.0')
+        fluid.require_version('2.0.0')
     except Exception:
         logger.error(err)
         sys.exit(1)
diff --git a/ppcls/utils/logger.py b/ppcls/utils/logger.py
index 12789c7c893a9d48a189b43dfd251c1a88e45f76..c6efa17e053b724d90564b64208101c1a6bc4c32 100644
--- a/ppcls/utils/logger.py
+++ b/ppcls/utils/logger.py
@@ -16,9 +16,6 @@ import logging
 import os
 import datetime
 
-from imp import reload
-reload(logging)
-
 logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s %(levelname)s: %(message)s",
diff --git a/ppcls/utils/save_load.py b/ppcls/utils/save_load.py
index 1d1e505b9b719058b626b83b0e9949383c9c95cd..e2bc8b9e5d5436c0aecd001e53d6710a57fcf3ac 100644
--- a/ppcls/utils/save_load.py
+++ b/ppcls/utils/save_load.py
@@ -26,7 +26,7 @@ import paddle.fluid as fluid
 
 from ppcls.utils import logger
 
-__all__ = ['init_model', 'save_model']
+__all__ = ['init_model', 'save_model', 'load_dygraph_pretrain']
 
 
 def _mkdir_if_not_exist(path):
@@ -45,71 +45,34 @@ def _mkdir_if_not_exist(path):
             raise OSError('Failed to mkdir {}'.format(path))
 
 
-def _load_state(path):
-    if os.path.exists(path + '.pdopt'):
-        # XXX another hack to ignore the optimizer state
-        tmp = tempfile.mkdtemp()
-        dst = os.path.join(tmp, os.path.basename(os.path.normpath(path)))
-        shutil.copy(path + '.pdparams', dst + '.pdparams')
-        state = fluid.io.load_program_state(dst)
-        shutil.rmtree(tmp)
-    else:
-        state = fluid.io.load_program_state(path)
-    return state
-
-
-def load_params(exe, prog, path, ignore_params=None):
-    """
-    Load model from the given path.
-    Args:
-        exe (fluid.Executor): The fluid.Executor object.
-        prog (fluid.Program): load weight to which Program object.
-        path (string): URL string or loca model path.
-        ignore_params (list): ignore variable to load when finetuning.
-            It can be specified by finetune_exclude_pretrained_params
-            and the usage can refer to the document
-            docs/advanced_tutorials/TRANSFER_LEARNING.md
-    """
+def load_dygraph_pretrain(
+        model,
+        path=None,
+        load_static_weights=False, ):
     if not (os.path.isdir(path) or os.path.exists(path + '.pdparams')):
         raise ValueError("Model pretrain path {} does not "
                          "exists.".format(path))
+    if load_static_weights:
+        pre_state_dict = fluid.load_program_state(path)
+        param_state_dict = {}
+        model_dict = model.state_dict()
+        for key in model_dict.keys():
+            weight_name = model_dict[key].name
+            if weight_name in pre_state_dict.keys():
+                print('Load weight: {}, shape: {}'.format(
+                    weight_name, pre_state_dict[weight_name].shape))
+                param_state_dict[key] = pre_state_dict[weight_name]
+            else:
+                param_state_dict[key] = model_dict[key]
+        model.set_dict(param_state_dict)
+        return
 
-    logger.info(
-        logger.coloring('Loading parameters from {}...'.format(path),
-                        'HEADER'))
-
-    ignore_set = set()
-    state = _load_state(path)
-
-    # ignore the parameter which mismatch the shape
-    # between the model and pretrain weight.
-    all_var_shape = {}
-    for block in prog.blocks:
-        for param in block.all_parameters():
-            all_var_shape[param.name] = param.shape
-    ignore_set.update([
-        name for name, shape in all_var_shape.items()
-        if name in state and shape != state[name].shape
-    ])
-
-    if ignore_params:
-        all_var_names = [var.name for var in prog.list_vars()]
-        ignore_list = filter(
-            lambda var: any([re.match(name, var) for name in ignore_params]),
-            all_var_names)
-        ignore_set.update(list(ignore_list))
-
-    if len(ignore_set) > 0:
-        for k in ignore_set:
-            if k in state:
-                logger.warning(
-                    'variable {} is already excluded automatically'.format(k))
-                del state[k]
-
-    fluid.io.set_program_state(prog, state)
+    param_state_dict, optim_state_dict = fluid.load_dygraph(path)
+    model.set_dict(param_state_dict)
+    return
 
 
-def init_model(config, net, optimizer):
+def init_model(config, net, optimizer=None):
     """
     load model from checkpoint or pretrained_model
     """
@@ -128,16 +91,24 @@ def init_model(config, net, optimizer):
         return
 
     pretrained_model = config.get('pretrained_model')
+    load_static_weights = config.get('load_static_weights', False)
+    use_distillation = config.get('use_distillation', False)
    if pretrained_model:
         if not isinstance(pretrained_model, list):
             pretrained_model = [pretrained_model]
-        # TODO: load pretrained_model
-        raise NotImplementedError
-        for pretrain in pretrained_model:
-            load_params(exe, program, pretrain)
-        logger.info(
-            logger.coloring("Finish initing model from {}".format(
-                pretrained_model), "HEADER"))
+        if not isinstance(load_static_weights, list):
+            load_static_weights = [load_static_weights] * len(pretrained_model)
+        for idx, pretrained in enumerate(pretrained_model):
+            load_static = load_static_weights[idx]
+            model = net
+            if use_distillation and not load_static:
+                model = net.teacher
+            load_dygraph_pretrain(
+                model, path=pretrained, load_static_weights=load_static)
+
+        logger.info(
+            logger.coloring("Finish initializing model from {}".format(
+                pretrained_model), "HEADER"))
 
 
 def save_model(net, optimizer, model_path, epoch_id, prefix='ppcls'):
diff --git a/tools/infer/infer.py b/tools/infer/infer.py
index 95aba7f39559baf1a1ea59c3767d982e32f4df64..410a90476590a70b651977b55d463457fa7710ac 100644
--- a/tools/infer/infer.py
+++ b/tools/infer/infer.py
@@ -18,6 +18,8 @@ import numpy as np
 import paddle.fluid as fluid
 
 from ppcls.modeling import architectures
+from ppcls.utils.save_load import load_dygraph_pretrain
+
 
 def parse_args():
     def str2bool(v):
@@ -28,9 +30,11 @@ def parse_args():
     parser.add_argument("-m", "--model", type=str)
     parser.add_argument("-p", "--pretrained_model", type=str)
     parser.add_argument("--use_gpu", type=str2bool, default=True)
+    parser.add_argument("--load_static_weights", type=str2bool, default=True)
 
     return parser.parse_args()
 
+
 def create_operators():
     size = 224
     img_mean = [0.485, 0.456, 0.406]
@@ -66,32 +70,32 @@ def main():
     args = parse_args()
     operators = create_operators()
     # assign the place
-    gpu_id = fluid.dygraph.parallel.Env().dev_id
-    place = fluid.CUDAPlace(gpu_id)
-
-    pre_weights_dict = fluid.load_program_state(args.pretrained_model)
+    if args.use_gpu:
+        gpu_id = fluid.dygraph.parallel.Env().dev_id
+        place = fluid.CUDAPlace(gpu_id)
+    else:
+        place = fluid.CPUPlace()
+
     with fluid.dygraph.guard(place):
         net = architectures.__dict__[args.model]()
         data = preprocess(args.image_file, operators)
         data = np.expand_dims(data, axis=0)
         data = fluid.dygraph.to_variable(data)
-        dy_weights_dict = net.state_dict()
-        pre_weights_dict_new = {}
-        for key in dy_weights_dict:
-            weights_name = dy_weights_dict[key].name
-            pre_weights_dict_new[key] = pre_weights_dict[weights_name]
-        net.set_dict(pre_weights_dict_new)
+        load_dygraph_pretrain(net, args.pretrained_model,
+                              args.load_static_weights)
         net.eval()
         outputs = net(data)
         outputs = fluid.layers.softmax(outputs)
         outputs = outputs.numpy()
-
+
         probs = postprocess(outputs)
         rank = 1
         for idx, prob in probs:
-            print("top{:d}, class id: {:d}, probability: {:.4f}".format(
-                rank, idx, prob))
+            print("top{:d}, class id: {:d}, probability: {:.4f}".format(rank, idx,
+                                                                        prob))
             rank += 1
+    return
+
 
 if __name__ == "__main__":
     main()
diff --git a/tools/program.py b/tools/program.py
index 55900b98599ecc15055f7368a9fea9b1430e852f..05a14ec97b868a3eb12bced1f36842abf251bd6d 100644
--- a/tools/program.py
+++ b/tools/program.py
@@ -21,6 +21,7 @@ import time
 
 from collections import OrderedDict
 
+import paddle
 import paddle.fluid as fluid
 
 from ppcls.optimizer import LearningRateBuilder
@@ -278,7 +279,7 @@ def mixed_precision_optimizer(config, optimizer):
 
 
 def create_feeds(batch, use_mix):
-    image = to_variable(batch[0].numpy().astype("float32"))
+    image = batch[0]
     if use_mix:
         y_a = to_variable(batch[1].numpy().astype("int64").reshape(-1, 1))
         y_b = to_variable(batch[2].numpy().astype("int64").reshape(-1, 1))
diff --git a/tools/train.py b/tools/train.py
index 976136e359f6631235cc009ac90e7ee9b72e594e..afb78354874250019d22386d8226a62ba979e60a 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -57,13 +57,14 @@ def main(args):
 
     with fluid.dygraph.guard(place):
         net = program.create_model(config.ARCHITECTURE, config.classes_num)
-        if config["use_data_parallel"]:
-            strategy = fluid.dygraph.parallel.prepare_context()
-            net = fluid.dygraph.parallel.DataParallel(net, strategy)
 
         optimizer = program.create_optimizer(
             config, parameter_list=net.parameters())
 
+        if config["use_data_parallel"]:
+            strategy = fluid.dygraph.parallel.prepare_context()
+            net = fluid.dygraph.parallel.DataParallel(net, strategy)
+
         # load model from checkpoint or pretrained model
         init_model(config, net, optimizer)
 
@@ -102,7 +103,7 @@ def main(args):
                 model_path = os.path.join(
                     config.model_save_dir, config.ARCHITECTURE["name"])
                 save_model(net, optimizer, model_path,
-                           "best_model_in_epoch_" + str(epoch_id))
+                           "best_model")
 
         # 3. save the persistable model
         if epoch_id % config.save_interval == 0:
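
Usage note (supplementary, not part of the patch): load_dygraph_pretrain becomes the single entry point for both weight formats. With load_static_weights=True it reads a static-graph checkpoint via fluid.load_program_state and remaps it onto the dygraph state_dict through each parameter's .name attribute; with False it goes through fluid.load_dygraph. A minimal sketch of calling it under a dygraph guard follows, mirroring the new tools/infer/infer.py flow; the model name and checkpoint path are illustrative only.

    import paddle.fluid as fluid

    from ppcls.modeling import architectures
    from ppcls.utils.save_load import load_dygraph_pretrain

    with fluid.dygraph.guard(fluid.CPUPlace()):
        # build the dygraph model (class_dim defaults to 1000)
        net = architectures.__dict__["MobileNetV3_large_x1_0"]()
        # static-graph checkpoint: parameter names are matched against the
        # dygraph state_dict; unmatched keys keep their initialized values
        load_dygraph_pretrain(
            net,
            path="./pretrained/MobileNetV3_large_x1_0_pretrained",
            load_static_weights=True)
        net.eval()

For a dygraph checkpoint saved by save_model (a .pdparams file), the same call with load_static_weights=False applies; in the distillation configs, init_model routes such weights into net.teacher.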