From bb224bf71040511738fd597224bffdb491fce9b8 Mon Sep 17 00:00:00 2001 From: cc <52520497+juncaipeng@users.noreply.github.com> Date: Tue, 18 May 2021 20:45:24 +0800 Subject: [PATCH] Add CE test for dygraph qat (#748) (#749) --- ce_tests/dygraph/qat/readme.md | 41 +++ ce_tests/dygraph/qat/run_test.sh | 11 + ce_tests/dygraph/qat/run_train.sh | 22 ++ ce_tests/dygraph/qat/src/eval.py | 100 ++++++++ ce_tests/dygraph/qat/src/imagenet_dataset.py | 249 +++++++++++++++++++ ce_tests/dygraph/qat/src/qat.py | 121 +++++++++ ce_tests/dygraph/qat/src/save_quant_model.py | 119 +++++++++ ce_tests/dygraph/qat/src/utility.py | 63 +++++ 8 files changed, 726 insertions(+) create mode 100644 ce_tests/dygraph/qat/readme.md create mode 100644 ce_tests/dygraph/qat/run_test.sh create mode 100644 ce_tests/dygraph/qat/run_train.sh create mode 100644 ce_tests/dygraph/qat/src/eval.py create mode 100644 ce_tests/dygraph/qat/src/imagenet_dataset.py create mode 100644 ce_tests/dygraph/qat/src/qat.py create mode 100644 ce_tests/dygraph/qat/src/save_quant_model.py create mode 100644 ce_tests/dygraph/qat/src/utility.py diff --git a/ce_tests/dygraph/qat/readme.md b/ce_tests/dygraph/qat/readme.md new file mode 100644 index 00000000..ed8786ca --- /dev/null +++ b/ce_tests/dygraph/qat/readme.md @@ -0,0 +1,41 @@ +1. 准备 + +安装需要测试的Paddle版本和PaddleSlim版本。 + +准备ImageNet数据集,假定解压到`/dataset/ILSVRC2012`文件夹,该文件夹下有`train文件夹、val文件夹、train_list.txt和val_list.txt文件`。 + +2. 产出量化模型 + +在`run_train.sh`文件中设置`data_path`为上述ImageNet数据集的路径`/dataset/ILSVRC2012`。 + +根据实际情况,在`run_train.sh`文件中设置使用GPU的id等参数。 + +执行`sh run_train.sh` 会对几个分类模型使用动态图量化训练功能进行量化,其中只执行一个epoch。 +执行完后,在`output_models/quant_dygraph`目录下有产出的量化模型。 + +3. 转换量化模型 + +在Intel CPU上部署量化模型,需要使用`test/save_quant_model.py`脚本进行模型转换。 + +如下是对`mobilenet_v1`模型进行转换的示例。 +``` +python src/save_quant_model.py --load_model_path output_models/quant_dygraph/mobilenet_v1 --save_model_path int8_models/mobilenet_v1 +``` + +4. 测试量化模型 + +在`run_test.sh`脚本中设置`data_path`为上述ImageNet数据集的路径`/dataset/ILSVRC2012`。 + +根据实际情况,在`run_test.sh`文件中设置使用GPU的id等参数。 + +使用`run_test.sh`脚本测试转换前和转换后的量化模型精度。 + +比如: +``` +sh run_test.sh output_models/quant_dygraph/mobilenet_v1 +sh run_test.sh int8_models/mobilenet_v1 +``` + +5. 测试目标 + +使用动态图量化训练功能,产出`mobilenet_v1`,`mobilenet_v2`,`resnet50`,`vgg16`量化模型,测试转换前后量化模型精度在1%误差范围内。 diff --git a/ce_tests/dygraph/qat/run_test.sh b/ce_tests/dygraph/qat/run_test.sh new file mode 100644 index 00000000..8fffd38c --- /dev/null +++ b/ce_tests/dygraph/qat/run_test.sh @@ -0,0 +1,11 @@ +model_path=$1 +test_samples=1000 # if set as -1, use all test samples +data_path='/dataset/ILSVRC2012/' +batch_size=16 + +echo "--------eval model: ${model_name}-------------" +python ./src/eval.py \ + --model_path=$model_path \ + --data_dir=${data_path} \ + --test_samples=${test_samples} \ + --batch_size=${batch_size} diff --git a/ce_tests/dygraph/qat/run_train.sh b/ce_tests/dygraph/qat/run_train.sh new file mode 100644 index 00000000..5f973a2f --- /dev/null +++ b/ce_tests/dygraph/qat/run_train.sh @@ -0,0 +1,22 @@ +export CUDA_VISIBLE_DEVICES=5 + +data_path="/dataset/ILSVRC2012" +epoch=1 +lr=0.0001 +batch_size=32 +num_workers=3 +output_dir=$PWD/output_models + +for model in mobilenet_v1 mobilenet_v2 resnet50 vgg16 +do + python ./src/qat.py \ + --arch=${model} \ + --data=${data_path} \ + --epoch=${epoch} \ + --batch_size=${batch_size} \ + --num_workers=${num_workers} \ + --lr=${lr} \ + --output_dir=${output_dir} \ + --enable_quant + #--use_pact +done diff --git a/ce_tests/dygraph/qat/src/eval.py b/ce_tests/dygraph/qat/src/eval.py new file mode 100644 index 00000000..9185bf8c --- /dev/null +++ b/ce_tests/dygraph/qat/src/eval.py @@ -0,0 +1,100 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import time +import sys +import argparse +import functools +import math + +import paddle +import paddle.inference as paddle_infer +from utility import add_arguments, print_arguments +import imagenet_dataset as dataset + + +def eval(args): + model_file = os.path.join(args.model_path, args.model_filename) + params_file = os.path.join(args.model_path, args.params_filename) + config = paddle_infer.Config(model_file, params_file) + config.enable_mkldnn() + + predictor = paddle_infer.create_predictor(config) + + input_names = predictor.get_input_names() + input_handle = predictor.get_input_handle(input_names[0]) + output_names = predictor.get_output_names() + output_handle = predictor.get_output_handle(output_names[0]) + + val_dataset = dataset.ImageNetDataset(data_dir=args.data_dir, mode='val') + eval_loader = paddle.io.DataLoader( + val_dataset, batch_size=args.batch_size, drop_last=True) + + cost_time = 0. + total_num = 0. + correct_1_num = 0 + correct_5_num = 0 + for batch_id, data in enumerate(eval_loader()): + img_np = np.array([tensor.numpy() for tensor in data[0]]) + label_np = np.array([tensor.numpy() for tensor in data[1]]) + + input_handle.reshape(img_np.shape) + input_handle.copy_from_cpu(img_np) + + t1 = time.time() + predictor.run() + t2 = time.time() + cost_time += (t2 - t1) + + output_data = output_handle.copy_to_cpu() + + for i in range(len(label_np)): + label = label_np[i][0] + result = output_data[i, :] + index = result.argsort() + total_num += 1 + if index[-1] == label: + correct_1_num += 1 + if label in index[-5:]: + correct_5_num += 1 + + if batch_id % 10 == 0: + acc1 = correct_1_num / total_num + acc5 = correct_5_num / total_num + avg_time = cost_time / total_num + print( + "batch_id {}, acc1 {:.3f}, acc5 {:.3f}, avg time {:.5f} sec/img". + format(batch_id, acc1, acc5, avg_time)) + + if args.test_samples > 0 and \ + (batch_id + 1)* args.batch_size >= args.test_samples: + break + + acc1 = correct_1_num / total_num + acc5 = correct_5_num / total_num + print("End test: test_acc1 {:.3f}, test_acc5 {:.5f}".format(acc1, acc5)) + + +def main(): + parser = argparse.ArgumentParser(description=__doc__) + add_arg = functools.partial(add_arguments, argparser=parser) + add_arg('model_path', str, "", "The inference model path.") + add_arg('model_filename', str, "int8_infer.pdmodel", "model filename") + add_arg('params_filename', str, "int8_infer.pdiparams", "params filename") + add_arg('data_dir', str, "/dataset/ILSVRC2012/", + "The ImageNet dataset root dir.") + add_arg('test_samples', int, -1, + "Test samples. If set -1, use all test samples") + add_arg('batch_size', int, 16, "Batch size.") + + args = parser.parse_args() + print_arguments(args) + + eval(args) + + +if __name__ == '__main__': + main() diff --git a/ce_tests/dygraph/qat/src/imagenet_dataset.py b/ce_tests/dygraph/qat/src/imagenet_dataset.py new file mode 100644 index 00000000..80fb4cc5 --- /dev/null +++ b/ce_tests/dygraph/qat/src/imagenet_dataset.py @@ -0,0 +1,249 @@ +import os +import math +import random +import functools +import numpy as np +import paddle +from PIL import Image, ImageEnhance +from paddle.io import Dataset + +random.seed(0) +np.random.seed(0) + +DATA_DIM = 224 + +THREAD = 16 +BUF_SIZE = 10240 + +DATA_DIR = './data/ILSVRC2012/' +DATA_DIR = os.path.join(os.path.split(os.path.realpath(__file__))[0], DATA_DIR) + +img_mean = np.array([0.485, 0.456, 0.406]).reshape((3, 1, 1)) +img_std = np.array([0.229, 0.224, 0.225]).reshape((3, 1, 1)) + + +def resize_short(img, target_size): + percent = float(target_size) / min(img.size[0], img.size[1]) + resized_width = int(round(img.size[0] * percent)) + resized_height = int(round(img.size[1] * percent)) + img = img.resize((resized_width, resized_height), Image.LANCZOS) + return img + + +def crop_image(img, target_size, center): + width, height = img.size + size = target_size + if center == True: + w_start = (width - size) / 2 + h_start = (height - size) / 2 + else: + w_start = np.random.randint(0, width - size + 1) + h_start = np.random.randint(0, height - size + 1) + w_end = w_start + size + h_end = h_start + size + img = img.crop((w_start, h_start, w_end, h_end)) + return img + + +def random_crop(img, size, scale=[0.08, 1.0], ratio=[3. / 4., 4. / 3.]): + aspect_ratio = math.sqrt(np.random.uniform(*ratio)) + w = 1. * aspect_ratio + h = 1. / aspect_ratio + + bound = min((float(img.size[0]) / img.size[1]) / (w**2), + (float(img.size[1]) / img.size[0]) / (h**2)) + scale_max = min(scale[1], bound) + scale_min = min(scale[0], bound) + + target_area = img.size[0] * img.size[1] * np.random.uniform(scale_min, + scale_max) + target_size = math.sqrt(target_area) + w = int(target_size * w) + h = int(target_size * h) + + i = np.random.randint(0, img.size[0] - w + 1) + j = np.random.randint(0, img.size[1] - h + 1) + + img = img.crop((i, j, i + w, j + h)) + img = img.resize((size, size), Image.LANCZOS) + return img + + +def rotate_image(img): + angle = np.random.randint(-10, 11) + img = img.rotate(angle) + return img + + +def distort_color(img): + def random_brightness(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Brightness(img).enhance(e) + + def random_contrast(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Contrast(img).enhance(e) + + def random_color(img, lower=0.5, upper=1.5): + e = np.random.uniform(lower, upper) + return ImageEnhance.Color(img).enhance(e) + + ops = [random_brightness, random_contrast, random_color] + np.random.shuffle(ops) + + img = ops[0](img) + img = ops[1](img) + img = ops[2](img) + + return img + + +def process_image(sample, mode, color_jitter, rotate): + img_path = sample[0] + + try: + img = Image.open(img_path) + except: + print(img_path, "not exists!") + return None + if mode == 'train': + if rotate: img = rotate_image(img) + img = random_crop(img, DATA_DIM) + else: + img = resize_short(img, target_size=256) + img = crop_image(img, target_size=DATA_DIM, center=True) + if mode == 'train': + if color_jitter: + img = distort_color(img) + if np.random.randint(0, 2) == 1: + img = img.transpose(Image.FLIP_LEFT_RIGHT) + + if img.mode != 'RGB': + img = img.convert('RGB') + + img = np.array(img).astype('float32').transpose((2, 0, 1)) / 255 + img -= img_mean + img /= img_std + + if mode == 'train' or mode == 'val': + return img, sample[1] + elif mode == 'test': + return [img] + + +def _reader_creator(file_list, + mode, + shuffle=False, + color_jitter=False, + rotate=False, + data_dir=DATA_DIR, + batch_size=1): + def reader(): + try: + with open(file_list) as flist: + full_lines = [line.strip() for line in flist] + if shuffle: + np.random.shuffle(full_lines) + if mode == 'train' and os.getenv('PADDLE_TRAINING_ROLE'): + # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) + per_node_lines = len(full_lines) // trainer_count + lines = full_lines[trainer_id * per_node_lines:( + trainer_id + 1) * per_node_lines] + print( + "read images from %d, length: %d, lines length: %d, total: %d" + % (trainer_id * per_node_lines, per_node_lines, + len(lines), len(full_lines))) + else: + lines = full_lines + + for line in lines: + if mode == 'train' or mode == 'val': + img_path, label = line.split() + img_path = os.path.join(data_dir, img_path) + yield img_path, int(label) + elif mode == 'test': + img_path = os.path.join(data_dir, line) + yield [img_path] + except Exception as e: + print("Reader failed!\n{}".format(str(e))) + os._exit(1) + + mapper = functools.partial( + process_image, mode=mode, color_jitter=color_jitter, rotate=rotate) + + return paddle.reader.xmap_readers(mapper, reader, THREAD, BUF_SIZE) + + +def train(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'train_list.txt') + return _reader_creator( + file_list, + 'train', + shuffle=True, + color_jitter=False, + rotate=False, + data_dir=data_dir) + + +def val(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'val_list.txt') + return _reader_creator(file_list, 'val', shuffle=False, data_dir=data_dir) + + +def test(data_dir=DATA_DIR): + file_list = os.path.join(data_dir, 'test_list.txt') + return _reader_creator(file_list, 'test', shuffle=False, data_dir=data_dir) + + +class ImageNetDataset(Dataset): + def __init__(self, data_dir=DATA_DIR, mode='train'): + super(ImageNetDataset, self).__init__() + self._data_dir = data_dir + train_file_list = os.path.join(data_dir, 'train_list.txt') + val_file_list = os.path.join(data_dir, 'val_list.txt') + test_file_list = os.path.join(data_dir, 'test_list.txt') + self.mode = mode + if mode == 'train': + with open(train_file_list) as flist: + full_lines = [line.strip() for line in flist] + np.random.shuffle(full_lines) + if os.getenv('PADDLE_TRAINING_ROLE'): + # distributed mode if the env var `PADDLE_TRAINING_ROLE` exits + trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + trainer_count = int(os.getenv("PADDLE_TRAINERS", "1")) + per_node_lines = len(full_lines) // trainer_count + lines = full_lines[trainer_id * per_node_lines:( + trainer_id + 1) * per_node_lines] + print( + "read images from %d, length: %d, lines length: %d, total: %d" + % (trainer_id * per_node_lines, per_node_lines, + len(lines), len(full_lines))) + else: + lines = full_lines + self.data = [line.split() for line in lines] + else: + with open(val_file_list) as flist: + lines = [line.strip() for line in flist] + self.data = [line.split() for line in lines] + + def __getitem__(self, index): + sample = self.data[index] + data_path = os.path.join(self._data_dir, sample[0]) + if self.mode == 'train': + data, label = process_image( + [data_path, sample[1]], + mode='train', + color_jitter=False, + rotate=False) + if self.mode == 'val': + data, label = process_image( + [data_path, sample[1]], + mode='val', + color_jitter=False, + rotate=False) + return data, np.array([label]).astype('int64') + + def __len__(self): + return len(self.data) diff --git a/ce_tests/dygraph/qat/src/qat.py b/ce_tests/dygraph/qat/src/qat.py new file mode 100644 index 00000000..41604fe0 --- /dev/null +++ b/ce_tests/dygraph/qat/src/qat.py @@ -0,0 +1,121 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import division +from __future__ import print_function + +import argparse +import os +import time +import math +import numpy as np + +import paddle +import paddle.hapi as hapi +from paddle.hapi.model import Input +from paddle.metric.metrics import Accuracy +import paddle.vision.models as models + +from paddleslim.dygraph.quant import QAT + +import imagenet_dataset as dataset + + +def main(): + model_list = [x for x in models.__dict__["__all__"]] + assert FLAGS.arch in model_list, "Expected FLAGS.arch in {}, but received {}".format( + model_list, FLAGS.arch) + model = models.__dict__[FLAGS.arch](pretrained=True) + + if FLAGS.enable_quant: + print("quantize model") + quant_config = { + 'weight_preprocess_type': None, + 'activation_preprocess_type': 'PACT' if FLAGS.use_pact else None, + 'weight_quantize_type': "channel_wise_abs_max", + 'activation_quantize_type': 'moving_average_abs_max', + 'weight_bits': 8, + 'activation_bits': 8, + 'window_size': 10000, + 'moving_rate': 0.9, + 'quantizable_layer_type': ['Conv2D', 'Linear'], + } + dygraph_qat = QAT(quant_config) + dygraph_qat.quantize(model) + + model = hapi.Model(model) + + train_dataset = dataset.ImageNetDataset(data_dir=FLAGS.data, mode='train') + val_dataset = dataset.ImageNetDataset(data_dir=FLAGS.data, mode='val') + + optim = paddle.optimizer.SGD(learning_rate=FLAGS.lr, + parameters=model.parameters(), + weight_decay=FLAGS.weight_decay) + + model.prepare(optim, paddle.nn.CrossEntropyLoss(), Accuracy(topk=(1, 5))) + + checkpoint_dir = os.path.join( + FLAGS.output_dir, "checkpoint", FLAGS.arch + "_checkpoint", + time.strftime('%Y-%m-%d-%H-%M', time.localtime())) + model.fit(train_dataset, + val_dataset, + batch_size=FLAGS.batch_size, + epochs=FLAGS.epoch, + save_dir=checkpoint_dir, + num_workers=FLAGS.num_workers) + + if FLAGS.enable_quant: + quant_output_dir = os.path.join(FLAGS.output_dir, "quant_dygraph", + FLAGS.arch, "int8_infer") + input_spec = paddle.static.InputSpec( + shape=[None, 3, 224, 224], dtype='float32') + dygraph_qat.save_quantized_model(model.network, quant_output_dir, + [input_spec]) + print("Save quantized inference model in " + quant_output_dir) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser("Training on ImageNet") + + # model + parser.add_argument( + "--arch", type=str, default='mobilenet_v1', help="model arch") + parser.add_argument( + "--output_dir", type=str, default='output', help="output dir") + + # data + parser.add_argument( + '--data', + default="", + help='path to dataset (should have subdirectories named "train" and "val"' + ) + + # train + parser.add_argument("--epoch", default=1, type=int, help="number of epoch") + parser.add_argument("--batch_size", default=10, type=int, help="batch size") + parser.add_argument( + "--num_workers", default=2, type=int, help="dataloader workers") + parser.add_argument( + '--lr', default=0.0001, type=float, help='initial learning rate') + parser.add_argument( + "--weight-decay", default=1e-4, type=float, help="weight decay") + + # quant + parser.add_argument( + "--enable_quant", action='store_true', help="enable quant model") + parser.add_argument("--use_pact", action='store_true', help="use pact") + + FLAGS = parser.parse_args() + + main() diff --git a/ce_tests/dygraph/qat/src/save_quant_model.py b/ce_tests/dygraph/qat/src/save_quant_model.py new file mode 100644 index 00000000..29b886e0 --- /dev/null +++ b/ce_tests/dygraph/qat/src/save_quant_model.py @@ -0,0 +1,119 @@ +# copyright (c) 2019 paddlepaddle authors. all rights reserved. +# +# licensed under the apache license, version 2.0 (the "license"); +# you may not use this file except in compliance with the license. +# you may obtain a copy of the license at +# +# http://www.apache.org/licenses/license-2.0 +# +# unless required by applicable law or agreed to in writing, software +# distributed under the license is distributed on an "as is" basis, +# without warranties or conditions of any kind, either express or implied. +# see the license for the specific language governing permissions and +# limitations under the license. + +import unittest +import os +import sys +import argparse +import logging +import struct +import six +import numpy as np +import time +import paddle +import paddle.fluid as fluid +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import Quant2Int8MkldnnPass +from paddle.fluid import core + +paddle.enable_static() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + '--load_model_path', + type=str, + default='', + help='A path to a Quant model.') + parser.add_argument( + '--save_model_path', + type=str, + default='', + help='Saved optimized and quantized INT8 model') + parser.add_argument( + '--ops_to_quantize', + type=str, + default='', + help='A comma separated list of operators to quantize. Only quantizable operators are taken into account. If the option is not used, an attempt to quantize all quantizable operators will be made.' + ) + parser.add_argument( + '--op_ids_to_skip', + type=str, + default='', + help='A comma separated list of operator ids to skip in quantization.') + parser.add_argument( + '--debug', + action='store_true', + help='If used, the graph of Quant model is drawn.') + + test_args, args = parser.parse_known_args(namespace=unittest) + return test_args, sys.argv[:1] + args + + +def transform_and_save_int8_model(original_path, save_path): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + inference_scope = fluid.executor.global_scope() + model_filename = 'int8_infer.pdmodel' + params_filename = 'int8_infer.pdiparams' + + with fluid.scope_guard(inference_scope): + if os.path.exists(os.path.join(original_path, '__model__')): + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model(original_path, exe) + else: + [inference_program, feed_target_names, + fetch_targets] = fluid.io.load_inference_model( + original_path, exe, model_filename, params_filename) + + ops_to_quantize = set() + if len(test_args.ops_to_quantize) > 0: + ops_to_quantize = set(test_args.ops_to_quantize.split(',')) + + op_ids_to_skip = set([-1]) + if len(test_args.op_ids_to_skip) > 0: + op_ids_to_skip = set(map(int, test_args.op_ids_to_skip.split(','))) + + graph = IrGraph(core.Graph(inference_program.desc), for_test=True) + if (test_args.debug): + graph.draw('.', 'quant_orig', graph.all_op_nodes()) + transform_to_mkldnn_int8_pass = Quant2Int8MkldnnPass( + ops_to_quantize, + _op_ids_to_skip=op_ids_to_skip, + _scope=inference_scope, + _place=place, + _core=core, + _debug=test_args.debug) + graph = transform_to_mkldnn_int8_pass.apply(graph) + inference_program = graph.to_program() + with fluid.scope_guard(inference_scope): + fluid.io.save_inference_model( + save_path, + feed_target_names, + fetch_targets, + exe, + inference_program, + model_filename=model_filename, + params_filename=params_filename) + print( + "Success! INT8 model obtained from the Quant model can be found at {}\n" + .format(save_path)) + + +if __name__ == '__main__': + global test_args + test_args, remaining_args = parse_args() + transform_and_save_int8_model(test_args.load_model_path, + test_args.save_model_path) diff --git a/ce_tests/dygraph/qat/src/utility.py b/ce_tests/dygraph/qat/src/utility.py new file mode 100644 index 00000000..0f3e1ba5 --- /dev/null +++ b/ce_tests/dygraph/qat/src/utility.py @@ -0,0 +1,63 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import numpy as np +import six +import logging + + +def print_arguments(args): + """Print argparse's arguments. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + parser.add_argument("name", default="Jonh", type=str, help="User name.") + args = parser.parse_args() + print_arguments(args) + + :param args: Input argparse.Namespace for printing. + :type args: argparse.Namespace + """ + print("----------- Configuration Arguments -----------") + for arg, value in sorted(six.iteritems(vars(args))): + print("%s: %s" % (arg, value)) + print("------------------------------------------------") + + +def add_arguments(argname, type, default, help, argparser, **kwargs): + """Add argparse's argument. + + Usage: + + .. code-block:: python + + parser = argparse.ArgumentParser() + add_argument("name", str, "Jonh", "User name.", parser) + args = parser.parse_args() + """ + type = distutils.util.strtobool if type == bool else type + argparser.add_argument( + "--" + argname, + default=default, + type=type, + help=help + ' Default: %(default)s.', + **kwargs) -- GitLab