From 12a426bec344e09b1e8d6b839a4d44e14c64465c Mon Sep 17 00:00:00 2001 From: Dun Date: Thu, 28 Feb 2019 15:21:00 +0800 Subject: [PATCH] add mem opt settings of deeplabv3+ (#1648) * add mem opt settings * inplace normalize * code polish * chagne url --- fluid/PaddleCV/deeplabv3+/README.md | 11 +- fluid/PaddleCV/deeplabv3+/eval.py | 69 +++++--- fluid/PaddleCV/deeplabv3+/models.py | 21 ++- fluid/PaddleCV/deeplabv3+/train.py | 250 ++++++++++++++------------- fluid/PaddleCV/deeplabv3+/utility.py | 60 +++++++ 5 files changed, 253 insertions(+), 158 deletions(-) create mode 100644 fluid/PaddleCV/deeplabv3+/utility.py diff --git a/fluid/PaddleCV/deeplabv3+/README.md b/fluid/PaddleCV/deeplabv3+/README.md index 97e1600d..3a075c89 100644 --- a/fluid/PaddleCV/deeplabv3+/README.md +++ b/fluid/PaddleCV/deeplabv3+/README.md @@ -40,13 +40,13 @@ data/cityscape/ 如果需要从头开始训练模型,用户需要下载我们的初始化模型 ``` -wget http://paddlemodels.cdn.bcebos.com/deeplab/deeplabv3plus_xception65_initialize.tar.gz -tar -xf deeplabv3plus_xception65_initialize.tar.gz && rm deeplabv3plus_xception65_initialize.tar.gz +wget https://paddle-deeplab.bj.bcebos.com/deeplabv3plus_xception65_initialize.tgz +tar -xf deeplabv3plus_xception65_initialize.tgz && rm deeplabv3plus_xception65_initialize.tgz ``` 如果需要最终训练模型进行fine tune或者直接用于预测,请下载我们的最终模型 ``` -wget http://paddlemodels.cdn.bcebos.com/deeplab/deeplabv3plus.tar.gz -tar -xf deeplabv3plus.tar.gz && rm deeplabv3plus.tar.gz +wget https://paddle-deeplab.bj.bcebos.com/deeplabv3plus.tgz +tar -xf deeplabv3plus.tgz && rm deeplabv3plus.tgz ``` @@ -99,9 +99,10 @@ step: 500, mIoU: 0.7873 ``` ## 其他信息 + |数据集 | pretrained model | trained model | mean IoU |---|---|---|---| -|CityScape | [deeplabv3plus_xception65_initialize.tar.gz](http://paddlemodels.cdn.bcebos.com/deeplab/deeplabv3plus_xception65_initialize.tar.gz) | [deeplabv3plus.tar.gz](http://paddlemodels.cdn.bcebos.com/deeplab/deeplabv3plus.tar.gz) | 0.7873 | +|CityScape | [deeplabv3plus_xception65_initialize.tgz](https://paddle-deeplab.bj.bcebos.com/deeplabv3plus_xception65_initialize.tgz) | [deeplabv3plus.tgz](https://paddle-deeplab.bj.bcebos.com/deeplabv3plus.tgz) | 0.7873 | ## 参考 diff --git a/fluid/PaddleCV/deeplabv3+/eval.py b/fluid/PaddleCV/deeplabv3+/eval.py index 5699f2fa..6137af41 100644 --- a/fluid/PaddleCV/deeplabv3+/eval.py +++ b/fluid/PaddleCV/deeplabv3+/eval.py @@ -2,7 +2,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os -os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = '0.98' +if 'FLAGS_fraction_of_gpu_memory_to_use' not in os.environ: + os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = '0.98' +os.environ['FLAGS_enable_parallel_graph'] = '1' import paddle import paddle.fluid as fluid @@ -12,21 +14,20 @@ from reader import CityscapeDataset import reader import models import sys +import utility +parser = argparse.ArgumentParser() +add_arg = lambda *args: utility.add_arguments(*args, argparser=parser) -def add_argument(name, type, default, help): - parser.add_argument('--' + name, default=default, type=type, help=help) - - -def add_arguments(): - add_argument('total_step', int, -1, - "Number of the step to be evaluated, -1 for full evaluation.") - add_argument('init_weights_path', str, None, - "Path of the weights to evaluate.") - add_argument('dataset_path', str, None, "Cityscape dataset path.") - add_argument('verbose', bool, False, "Print mIoU for each step if verbose.") - add_argument('use_gpu', bool, True, "Whether use GPU or CPU.") - add_argument('num_classes', int, 19, "Number of classes.") +# yapf: disable +add_arg('total_step', int, -1, "Number of the step to be evaluated, -1 for full evaluation.") +add_arg('init_weights_path', str, None, "Path of the weights to evaluate.") +add_arg('dataset_path', str, None, "Cityscape dataset path.") +add_arg('verbose', bool, False, "Print mIoU for each step if verbose.") +add_arg('use_gpu', bool, True, "Whether use GPU or CPU.") +add_arg('num_classes', int, 19, "Number of classes.") +add_arg('use_py_reader', bool, True, "Use py_reader.") +#yapf: enable def mean_iou(pred, label): @@ -43,7 +44,7 @@ def mean_iou(pred, label): def load_model(): - if args.init_weights_path.endswith('/'): + if os.path.isdir(args.init_weights_path): fluid.io.load_params( exe, dirname=args.init_weights_path, main_program=tp) else: @@ -53,9 +54,6 @@ def load_model(): CityscapeDataset = reader.CityscapeDataset -parser = argparse.ArgumentParser() -add_arguments() - args = parser.parse_args() models.clean() @@ -73,8 +71,15 @@ reader.default_config['shuffle'] = False num_classes = args.num_classes with fluid.program_guard(tp, sp): - img = fluid.layers.data(name='img', shape=[3, 0, 0], dtype='float32') - label = fluid.layers.data(name='label', shape=eval_shape, dtype='int32') + if args.use_py_reader: + py_reader = fluid.layers.py_reader(capacity=64, + shapes=[[1, 3, 0, 0], [1] + eval_shape], + dtypes=['float32', 'int32']) + img, label = fluid.layers.read_file(py_reader) + else: + img = fluid.layers.data(name='img', shape=[3, 0, 0], dtype='float32') + label = fluid.layers.data(name='label', shape=eval_shape, dtype='int32') + img = fluid.layers.resize_bilinear(img, image_shape) logit = deeplabv3p(img) logit = fluid.layers.resize_bilinear(logit, eval_shape) @@ -105,16 +110,25 @@ else: total_step = args.total_step batches = dataset.get_batch_generator(batch_size, total_step) +if args.use_py_reader: + py_reader.decorate_tensor_provider(lambda :[ (yield b[1],b[2]) for b in batches]) + py_reader.start() sum_iou = 0 all_correct = np.array([0], dtype=np.int64) all_wrong = np.array([0], dtype=np.int64) -for i, imgs, labels, names in batches: - result = exe.run(tp, - feed={'img': imgs, - 'label': labels}, - fetch_list=[pred, miou, out_wrong, out_correct]) +for i in range(total_step): + if not args.use_py_reader: + _, imgs, labels, names = next(batches) + result = exe.run(tp, + feed={'img': imgs, + 'label': labels}, + fetch_list=[pred, miou, out_wrong, out_correct]) + else: + result = exe.run(tp, + fetch_list=[pred, miou, out_wrong, out_correct]) + wrong = result[2][:-1] + all_wrong right = result[3][:-1] + all_correct all_wrong = wrong.copy() @@ -122,7 +136,6 @@ for i, imgs, labels, names in batches: mp = (wrong + right) != 0 miou2 = np.mean((right[mp] * 1.0 / (right[mp] + wrong[mp]))) if args.verbose: - print('step: %s, mIoU: %s' % (i + 1, miou2)) + print('step: %s, mIoU: %s' % (i + 1, miou2), flush=True) else: - print('\rstep: %s, mIoU: %s' % (i + 1, miou2)) - sys.stdout.flush() + print('\rstep: %s, mIoU: %s' % (i + 1, miou2), end='\r', flush=True) diff --git a/fluid/PaddleCV/deeplabv3+/models.py b/fluid/PaddleCV/deeplabv3+/models.py index c1ea1229..117ab5da 100644 --- a/fluid/PaddleCV/deeplabv3+/models.py +++ b/fluid/PaddleCV/deeplabv3+/models.py @@ -5,6 +5,7 @@ import paddle import paddle.fluid as fluid import contextlib +import os name_scope = "" decode_channel = 48 @@ -146,10 +147,12 @@ def bn_relu(data): def relu(data): - return append_op_result(fluid.layers.relu(data), 'relu') + return append_op_result( + fluid.layers.relu( + data, name=name_scope + 'relu'), 'relu') -def seq_conv(input, channel, stride, filter, dilation=1, act=None): +def seperate_conv(input, channel, stride, filter, dilation=1, act=None): with scope('depthwise'): input = conv( input, @@ -187,14 +190,14 @@ def xception_block(input, with scope('separable_conv' + str(i + 1)): if not activation_fn_in_separable_conv: data = relu(data) - data = seq_conv( + data = seperate_conv( data, channels[i], strides[i], filters[i], dilation=dilation) else: - data = seq_conv( + data = seperate_conv( data, channels[i], strides[i], @@ -273,11 +276,11 @@ def encoder(input): with scope("aspp0"): aspp0 = bn_relu(conv(input, channel, 1, 1, groups=1, padding=0)) with scope("aspp1"): - aspp1 = seq_conv(input, channel, 1, 3, dilation=6, act=relu) + aspp1 = seperate_conv(input, channel, 1, 3, dilation=6, act=relu) with scope("aspp2"): - aspp2 = seq_conv(input, channel, 1, 3, dilation=12, act=relu) + aspp2 = seperate_conv(input, channel, 1, 3, dilation=12, act=relu) with scope("aspp3"): - aspp3 = seq_conv(input, channel, 1, 3, dilation=18, act=relu) + aspp3 = seperate_conv(input, channel, 1, 3, dilation=18, act=relu) with scope("concat"): data = append_op_result( fluid.layers.concat( @@ -300,10 +303,10 @@ def decoder(encode_data, decode_shortcut): [encode_data, decode_shortcut], axis=1) append_op_result(encode_data, 'concat') with scope("separable_conv1"): - encode_data = seq_conv( + encode_data = seperate_conv( encode_data, encode_channel, 1, 3, dilation=1, act=relu) with scope("separable_conv2"): - encode_data = seq_conv( + encode_data = seperate_conv( encode_data, encode_channel, 1, 3, dilation=1, act=relu) return encode_data diff --git a/fluid/PaddleCV/deeplabv3+/train.py b/fluid/PaddleCV/deeplabv3+/train.py index e009f76e..9a0f9f6c 100755 --- a/fluid/PaddleCV/deeplabv3+/train.py +++ b/fluid/PaddleCV/deeplabv3+/train.py @@ -2,7 +2,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os -os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = '0.98' +if 'FLAGS_fraction_of_gpu_memory_to_use' not in os.environ: + os.environ['FLAGS_fraction_of_gpu_memory_to_use'] = '0.98' +os.environ['FLAGS_enable_parallel_graph'] = '1' import paddle import paddle.fluid as fluid @@ -12,105 +14,94 @@ from reader import CityscapeDataset import reader import models import time +import contextlib +import paddle.fluid.profiler as profiler +import utility - -def add_argument(name, type, default, help): - parser.add_argument('--' + name, default=default, type=type, help=help) - - -def add_arguments(): - add_argument('batch_size', int, 2, - "The number of images in each batch during training.") - add_argument('train_crop_size', int, 769, - "'Image crop size during training.") - add_argument('base_lr', float, 0.0001, - "The base learning rate for model training.") - add_argument('total_step', int, 90000, "Number of the training step.") - add_argument('init_weights_path', str, None, - "Path of the initial weights in paddlepaddle format.") - add_argument('save_weights_path', str, None, - "Path of the saved weights during training.") - add_argument('dataset_path', str, None, "Cityscape dataset path.") - add_argument('parallel', bool, False, "using ParallelExecutor.") - add_argument('use_gpu', bool, True, "Whether use GPU or CPU.") - add_argument('num_classes', int, 19, "Number of classes.") - parser.add_argument( - '--enable_ce', - action='store_true', - help='If set, run the task with continuous evaluation logs.') - +parser = argparse.ArgumentParser() +add_arg = lambda *args: utility.add_arguments(*args, argparser=parser) + +# yapf: disable +add_arg('batch_size', int, 2, "The number of images in each batch during training.") +add_arg('train_crop_size', int, 769, "Image crop size during training.") +add_arg('base_lr', float, 0.0001, "The base learning rate for model training.") +add_arg('total_step', int, 90000, "Number of the training step.") +add_arg('init_weights_path', str, None, "Path of the initial weights in paddlepaddle format.") +add_arg('save_weights_path', str, None, "Path of the saved weights during training.") +add_arg('dataset_path', str, None, "Cityscape dataset path.") +add_arg('parallel', bool, True, "using ParallelExecutor.") +add_arg('use_gpu', bool, True, "Whether use GPU or CPU.") +add_arg('num_classes', int, 19, "Number of classes.") +add_arg('load_logit_layer', bool, True, "Load last logit fc layer or not. If you are training with different number of classes, you should set to False.") +add_arg('memory_optimize', bool, True, "Using memory optimizer.") +add_arg('norm_type', str, 'bn', "Normalization type, should be bn or gn.") +add_arg('profile', bool, False, "Enable profiler.") +add_arg('use_py_reader', bool, True, "Use py reader.") +parser.add_argument( + '--enable_ce', + action='store_true', + help='If set, run the task with continuous evaluation logs.') +#yapf: enable + +@contextlib.contextmanager +def profile_context(profile=True): + if profile: + with profiler.profiler('All', 'total', '/tmp/profile_file2'): + yield + else: + yield def load_model(): - myvars = [ - x for x in tp.list_vars() - if isinstance(x, fluid.framework.Parameter) and x.name.find('logit') == - -1 - ] - if args.init_weights_path.endswith('/'): - if args.num_classes == 19: + if os.path.isdir(args.init_weights_path): + load_vars = [ + x for x in tp.list_vars() + if isinstance(x, fluid.framework.Parameter) and x.name.find('logit') == + -1 + ] + if args.load_logit_layer: fluid.io.load_params( exe, dirname=args.init_weights_path, main_program=tp) else: - fluid.io.load_vars(exe, dirname=args.init_weights_path, vars=myvars) + fluid.io.load_vars(exe, dirname=args.init_weights_path, vars=load_vars) else: - if args.num_classes == 19: - fluid.io.load_params( - exe, - dirname="", - filename=args.init_weights_path, - main_program=tp) - else: - fluid.io.load_vars( - exe, dirname="", filename=args.init_weights_path, vars=myvars) + fluid.io.load_params( + exe, + dirname="", + filename=args.init_weights_path, + main_program=tp) + def save_model(): - if args.save_weights_path.endswith('/'): - fluid.io.save_params( - exe, dirname=args.save_weights_path, main_program=tp) - else: - fluid.io.save_params( - exe, dirname="", filename=args.save_weights_path, main_program=tp) + assert not os.path.isfile(args.save_weights_path) + fluid.io.save_params( + exe, dirname=args.save_weights_path, main_program=tp) def loss(logit, label): - label_nignore = (label < num_classes).astype('float32') - label = fluid.layers.elementwise_min( - label, - fluid.layers.assign(np.array( - [num_classes - 1], dtype=np.int32))) + label_nignore = fluid.layers.less_than( + label.astype('float32'), + fluid.layers.assign(np.array([num_classes], 'float32')), + force_cpu=False).astype('float32') logit = fluid.layers.transpose(logit, [0, 2, 3, 1]) logit = fluid.layers.reshape(logit, [-1, num_classes]) label = fluid.layers.reshape(label, [-1, 1]) label = fluid.layers.cast(label, 'int64') label_nignore = fluid.layers.reshape(label_nignore, [-1, 1]) - loss = fluid.layers.softmax_with_cross_entropy(logit, label) - loss = loss * label_nignore - no_grad_set.add(label_nignore.name) - no_grad_set.add(label.name) + loss = fluid.layers.softmax_with_cross_entropy(logit, label, ignore_index=255, numeric_stable_mode=True) + label_nignore.stop_gradient = True + label.stop_gradient = True return loss, label_nignore -def get_cards(args): - if args.enable_ce: - cards = os.environ.get('CUDA_VISIBLE_DEVICES') - num = len(cards.split(",")) - return num - else: - return args.num_devices - - -CityscapeDataset = reader.CityscapeDataset -parser = argparse.ArgumentParser() - -add_arguments() - args = parser.parse_args() +utility.print_arguments(args) models.clean() models.bn_momentum = 0.9997 models.dropout_keep_prop = 0.9 models.label_number = args.num_classes +models.default_norm_type = args.norm_type deeplabv3p = models.deeplabv3p sp = fluid.Program() @@ -133,12 +124,17 @@ weight_decay = 0.00004 base_lr = args.base_lr total_step = args.total_step -no_grad_set = set() - with fluid.program_guard(tp, sp): - img = fluid.layers.data( - name='img', shape=[3] + image_shape, dtype='float32') - label = fluid.layers.data(name='label', shape=image_shape, dtype='int32') + if args.use_py_reader: + batch_size_each = batch_size // fluid.core.get_cuda_device_count() + py_reader = fluid.layers.py_reader(capacity=64, + shapes=[[batch_size_each, 3] + image_shape, [batch_size_each] + image_shape], + dtypes=['float32', 'int32']) + img, label = fluid.layers.read_file(py_reader) + else: + img = fluid.layers.data( + name='img', shape=[3] + image_shape, dtype='float32') + label = fluid.layers.data(name='label', shape=image_shape, dtype='int32') logit = deeplabv3p(img) pred = fluid.layers.argmax(logit, axis=1).astype('int32') loss, mask = loss(logit, label) @@ -154,11 +150,21 @@ with fluid.program_guard(tp, sp): lr, momentum=0.9, regularization=fluid.regularizer.L2DecayRegularizer( - regularization_coeff=weight_decay), ) - retv = opt.minimize(loss_mean, startup_program=sp, no_grad_set=no_grad_set) - -fluid.memory_optimize( - tp, print_log=False, skip_opt_set=set([pred.name, loss_mean.name]), level=1) + regularization_coeff=weight_decay)) + optimize_ops, params_grads = opt.minimize(loss_mean, startup_program=sp) + # ir memory optimizer has some issues, we need to seed grad persistable to + # avoid this issue + for p,g in params_grads: g.persistable = True + + +exec_strategy = fluid.ExecutionStrategy() +exec_strategy.num_threads = fluid.core.get_cuda_device_count() +exec_strategy.num_iteration_per_drop_scope = 100 +build_strategy = fluid.BuildStrategy() +if args.memory_optimize: + build_strategy.fuse_relu_depthwise_conv = True + build_strategy.enable_inplace = True + build_strategy.memory_optimize = True place = fluid.CPUPlace() if args.use_gpu: @@ -170,47 +176,59 @@ if args.init_weights_path: print("load from:", args.init_weights_path) load_model() -dataset = CityscapeDataset(args.dataset_path, 'train') +dataset = reader.CityscapeDataset(args.dataset_path, 'train') if args.parallel: - exe_p = fluid.ParallelExecutor( - use_cuda=True, loss_name=loss_mean.name, main_program=tp) - -batches = dataset.get_batch_generator(batch_size, total_step) - + binary = fluid.compiler.CompiledProgram(tp).with_data_parallel( + loss_name=loss_mean.name, + build_strategy=build_strategy, + exec_strategy=exec_strategy) +else: + binary = fluid.compiler.CompiledProgram(main) + +if args.use_py_reader: + assert(batch_size % fluid.core.get_cuda_device_count() == 0) + def data_gen(): + batches = dataset.get_batch_generator( + batch_size // fluid.core.get_cuda_device_count(), + total_step * fluid.core.get_cuda_device_count()) + for b in batches: + yield b[1], b[2] + py_reader.decorate_tensor_provider(data_gen) + py_reader.start() +else: + batches = dataset.get_batch_generator(batch_size, total_step) total_time = 0.0 epoch_idx = 0 train_loss = 0 -for i, imgs, labels, names in batches: - epoch_idx += 1 - begin_time = time.time() - prev_start_time = time.time() - if args.parallel: - retv = exe_p.run(fetch_list=[pred.name, loss_mean.name], - feed={'img': imgs, - 'label': labels}) - else: - retv = exe.run(tp, - feed={'img': imgs, - 'label': labels}, - fetch_list=[pred, loss_mean]) - end_time = time.time() - total_time += end_time - begin_time - if i % 100 == 0: - print("Model is saved to", args.save_weights_path) - save_model() - print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f}".format( - i, np.mean(retv[1]), end_time - prev_start_time)) - - # only for ce - train_loss = np.mean(retv[1]) +with profile_context(args.profile): + for i in range(total_step): + epoch_idx += 1 + begin_time = time.time() + prev_start_time = time.time() + if not args.use_py_reader: + _, imgs, labels, names = next(batches) + train_loss, = exe.run(binary, + feed={'img': imgs, + 'label': labels}, fetch_list=[loss_mean]) + else: + train_loss, = exe.run(binary, fetch_list=[loss_mean]) + train_loss = np.mean(train_loss) + end_time = time.time() + total_time += end_time - begin_time + if i % 100 == 0: + print("Model is saved to", args.save_weights_path) + save_model() + print("step {:d}, loss: {:.6f}, step_time_cost: {:.3f}".format( + i, train_loss, end_time - prev_start_time)) + +print("Training done. Model is saved to", args.save_weights_path) +save_model() +py_reader.stop() if args.enable_ce: - gpu_num = get_cards(args) + gpu_num = fluid.core.get_cuda_device_count() print("kpis\teach_pass_duration_card%s\t%s" % (gpu_num, total_time / epoch_idx)) print("kpis\ttrain_loss_card%s\t%s" % (gpu_num, train_loss)) - -print("Training done. Model is saved to", args.save_weights_path) -save_model() diff --git a/fluid/PaddleCV/deeplabv3+/utility.py b/fluid/PaddleCV/deeplabv3+/utility.py new file mode 100644 index 00000000..aebb9acb --- /dev/null +++ b/fluid/PaddleCV/deeplabv3+/utility.py @@ -0,0 +1,60 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import distutils.util +import six + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----------- Configuration Arguments -----------") + for arg, value in sorted(six.iteritems(vars(args))): + print("%s: %s" % (arg, value)) + print("------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + add_argument("name", str, "Jonh", "User name.", parser) + args = parser.parse_args() + """ + type = distutils.util.strtobool if type == bool else type + argparser.add_argument( + "--" + argname, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) -- GitLab