diff --git a/demo/quant/pact_quant_aware/README.md b/demo/quant/pact_quant_aware/README.md
index cf64d7e21b071263dc7cdb2810e87fde4a7d0dc8..69e8ee3d66c8a65a798d06259c184138193bd4fd 100644
--- a/demo/quant/pact_quant_aware/README.md
+++ b/demo/quant/pact_quant_aware/README.md
@@ -179,7 +179,10 @@ python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/Mob
 Train with PACT quantization:
 ```
-python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
+# First, analyze the activation distributions of MobileNetV3 to initialize the PACT clipping thresholds
+python train.py --analysis=True
+# Then launch PACT quantization-aware training
+python train.py
 ```
 
 The output is
diff --git a/demo/quant/pact_quant_aware/pact.py b/demo/quant/pact_quant_aware/pact.py
deleted file mode 100644
index 26a2a5efd6e9b819db9b7134a62a1ac8c1fc296f..0000000000000000000000000000000000000000
--- a/demo/quant/pact_quant_aware/pact.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import sys
-import paddle
-import paddle.fluid as fluid
-from paddleslim.quant import quant_aware, convert
-import numpy as np
-
-from paddle.fluid.layer_helper import LayerHelper
-
-
-def pact(x, name=None):
-    helper = LayerHelper("pact", **locals())
-    dtype = 'float32'
-    init_thres = 20
-    u_param_attr = fluid.ParamAttr(
-        name=x.name + '_pact',
-        initializer=fluid.initializer.ConstantInitializer(value=init_thres),
-        regularizer=fluid.regularizer.L2Decay(0.0001),
-        learning_rate=1)
-    u_param = helper.create_parameter(
-        attr=u_param_attr, shape=[1], dtype=dtype)
-    x = fluid.layers.elementwise_sub(
-        x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
-    x = fluid.layers.elementwise_add(
-        x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
-
-    return x
-
-
-def get_optimizer():
-    return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
diff --git a/demo/quant/pact_quant_aware/train.py b/demo/quant/pact_quant_aware/train.py
index d2658d3c51939a5bfa5c4e0db4c1324251265bea..c873c1814e38a505938bb03c704e7ab3862cb092 100644
--- a/demo/quant/pact_quant_aware/train.py
+++ b/demo/quant/pact_quant_aware/train.py
@@ -7,45 +7,49 @@ import functools
 import math
 import time
 import numpy as np
+from collections import defaultdict
+
 import paddle.fluid as fluid
 sys.path.append(os.path.dirname("__file__"))
 sys.path.append(
     os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
-from paddleslim.common import get_logger
+from paddleslim.common import get_logger, VarCollector
 from paddleslim.analysis import flops
 from paddleslim.quant import quant_aware, quant_post, convert
 import models
 from utility import add_arguments, print_arguments
-from pact import *
+from paddle.fluid.layer_helper import LayerHelper
 quantization_model_save_dir = './quantization_models/'
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
+
 _logger = get_logger(__name__, level=logging.INFO)
 
 parser = argparse.ArgumentParser(description=__doc__)
 add_arg = functools.partial(add_arguments, argparser=parser)
 # yapf: disable
-add_arg('batch_size', int, 64 * 4,
+add_arg('batch_size', int, 128,
         "Minibatch size.")
 add_arg('use_gpu', bool, True,
         "Whether to use GPU or not.")
-add_arg('model', str, "MobileNet",
+add_arg('model', str, "MobileNetV3_large_x1_0",
         "The target model.")
-add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained",
+add_arg('pretrained_model', str, "./pretrain/MobileNetV3_large_x1_0_ssld_pretrained",
        "Whether to use pretrained model.")
-add_arg('lr', float, 0.0001,
+add_arg('lr', float, 0.001,
         "The learning rate used to fine-tune pruned model.")
 add_arg('lr_strategy', str, "piecewise_decay",
         "The learning rate decay strategy.")
-add_arg('l2_decay', float, 3e-5,
+add_arg('l2_decay', float, 1e-5,
         "The l2_decay parameter.")
 add_arg('momentum_rate', float, 0.9,
         "The value of momentum_rate.")
-add_arg('num_epochs', int, 1,
+add_arg('num_epochs', int, 30,
         "The number of total epochs.")
 add_arg('total_images', int, 1281167,
         "The number of total training images.")
 parser.add_argument('--step_epochs', nargs='+', type=int,
-                    default=[30, 60, 90],
+                    default=[20],
                     help="piecewise decay step")
 add_arg('config_file', str, None,
         "The config file for compression with yaml format.")
@@ -61,6 +65,8 @@ add_arg('output_dir', str, "output/MobileNetV3_large_x1_0",
         "model save dir")
 add_arg('use_pact', bool, True,
         "Whether to use PACT or not.")
+add_arg('analysis', bool, False,
+        "Whether to analyze the distributions of variables.")
 # yapf: enable
@@ -68,7 +74,9 @@ model_list = [m for m in dir(models) if "__" not in m]
 
 
 def piecewise_decay(args):
-    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    step = int(
+        math.ceil(float(args.total_images) / (args.batch_size * len(places))))
     bd = [step * e for e in args.step_epochs]
     lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
     learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
@@ -76,18 +84,20 @@ def piecewise_decay(args):
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
         regularization=fluid.regularizer.L2Decay(args.l2_decay))
-    return optimizer
+    return learning_rate, optimizer
 
 
 def cosine_decay(args):
-    step = int(math.ceil(float(args.total_images) / args.batch_size))
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    step = int(
+        math.ceil(float(args.total_images) / (args.batch_size * len(places))))
     learning_rate = fluid.layers.cosine_decay(
         learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
     optimizer = fluid.optimizer.Momentum(
         learning_rate=learning_rate,
         momentum=args.momentum_rate,
         regularization=fluid.regularizer.L2Decay(args.l2_decay))
-    return optimizer
+    return learning_rate, optimizer
 
 
 def create_optimizer(args):
@@ -98,30 +108,7 @@ def compress(args):
-    # 1. quantization configs
-    quant_config = {
-        # weight quantize type, default is 'channel_wise_abs_max'
-        'weight_quantize_type': 'channel_wise_abs_max',
-        # activation quantize type, default is 'moving_average_abs_max'
-        'activation_quantize_type': 'moving_average_abs_max',
-        # weight quantize bit num, default is 8
-        'weight_bits': 8,
-        # activation quantize bit num, default is 8
-        'activation_bits': 8,
-        # ops of name_scope in not_quant_pattern list, will not be quantized
-        'not_quant_pattern': ['skip_quant'],
-        # ops of type in quantize_op_types, will be quantized
-        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
-        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
-        'dtype': 'int8',
-        # window size for 'range_abs_max' quantization. defaulf is 10000
-        'window_size': 10000,
-        # The decay coefficient of moving average, default is 0.9
-        'moving_rate': 0.9,
-    }
-    train_reader = None
-    test_reader = None
     if args.data == "mnist":
         import paddle.dataset.mnist as reader
         train_reader = reader.train()
@@ -155,18 +142,126 @@
     train_prog = fluid.default_main_program()
     val_program = fluid.default_main_program().clone(for_test=True)
 
-    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
-    opt = create_optimizer(args)
-    opt.minimize(avg_cost)
+    if not args.analysis:
+        learning_rate, opt = create_optimizer(args)
+        opt.minimize(avg_cost)
+    place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    train_reader = paddle.fluid.io.batch(
+        train_reader, batch_size=args.batch_size, drop_last=True)
+    train_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[image, label],
+        capacity=512,
+        use_double_buffer=True,
+        iterable=True)
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    train_loader.set_sample_list_generator(train_reader, places)
+
+    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
+    valid_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[image, label],
+        capacity=512,
+        use_double_buffer=True,
+        iterable=True)
+    valid_loader.set_sample_list_generator(val_reader, places[0])
+
+    if args.analysis:
+        # names of all activations to analyze
+        activates = [
+            'pool2d_1.tmp_0', 'tmp_35', 'batch_norm_21.tmp_2', 'tmp_26',
+            'elementwise_mul_5.tmp_0', 'pool2d_5.tmp_0',
+            'elementwise_add_5.tmp_0', 'relu_2.tmp_0', 'pool2d_3.tmp_0',
+            'conv2d_40.tmp_2', 'elementwise_mul_0.tmp_0', 'tmp_62',
+            'elementwise_add_8.tmp_0', 'batch_norm_39.tmp_2', 'conv2d_32.tmp_2',
+            'tmp_17', 'tmp_5', 'elementwise_add_9.tmp_0', 'pool2d_4.tmp_0',
+            'relu_0.tmp_0', 'tmp_53', 'relu_3.tmp_0', 'elementwise_add_4.tmp_0',
+            'elementwise_add_6.tmp_0', 'tmp_11', 'conv2d_36.tmp_2',
+            'relu_8.tmp_0', 'relu_5.tmp_0', 'pool2d_7.tmp_0',
+            'elementwise_add_2.tmp_0', 'elementwise_add_7.tmp_0',
+            'pool2d_2.tmp_0', 'tmp_47', 'batch_norm_12.tmp_2',
+            'elementwise_mul_6.tmp_0', 'elementwise_mul_7.tmp_0',
+            'pool2d_6.tmp_0', 'relu_6.tmp_0', 'elementwise_add_0.tmp_0',
+            'elementwise_mul_3.tmp_0', 'conv2d_12.tmp_2',
+            'elementwise_mul_2.tmp_0', 'tmp_8', 'tmp_2', 'conv2d_8.tmp_2',
+            'elementwise_add_3.tmp_0', 'elementwise_mul_1.tmp_0',
+            'pool2d_8.tmp_0', 'conv2d_28.tmp_2', 'image', 'conv2d_16.tmp_2',
+            'batch_norm_33.tmp_2', 'relu_1.tmp_0', 'pool2d_0.tmp_0', 'tmp_20',
+            'conv2d_44.tmp_2', 'relu_10.tmp_0', 'tmp_41', 'relu_4.tmp_0',
+            'elementwise_add_1.tmp_0', 'tmp_23', 'batch_norm_6.tmp_2', 'tmp_29',
+            'elementwise_mul_4.tmp_0', 'tmp_14'
+        ]
+        var_collector = VarCollector(train_prog, activates, use_ema=True)
+        values = var_collector.abs_max_run(
+            train_loader, exe, step=None, loss_name=avg_cost.name)
+        np.save('pact_thres.npy', values)
+        _logger.info(values)
+        _logger.info("PACT thresholds have been saved to pact_thres.npy")
+
+        # Draw histograms in 'dist_pdf/result.pdf'
+        # var_collector.pdf(values)
+
+        return
+
+    values = defaultdict(lambda: 20)
+    try:
+        tmp = np.load("pact_thres.npy", allow_pickle=True).item()
+        values.update(tmp)
+        _logger.info("pact_thres.npy info loaded.")
+    except Exception:
+        _logger.info(
+            "cannot find pact_thres.npy. Set init PACT threshold as 20.")
+    _logger.info(values)
+
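+    # PACT learns a clipping threshold u for each activation and replaces the
+    # activation with clip(x, -u, u) before the fake-quantize op; the
+    # thresholds collected above (or the default of 20) initialize u.
+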
+    # 1. quantization configs
+    quant_config = {
+        # weight quantize type, default is 'channel_wise_abs_max'
+        'weight_quantize_type': 'channel_wise_abs_max',
+        # activation quantize type, default is 'moving_average_abs_max'
+        'activation_quantize_type': 'moving_average_abs_max',
+        # weight quantize bit num, default is 8
+        'weight_bits': 8,
+        # activation quantize bit num, default is 8
+        'activation_bits': 8,
+        # ops of name_scope in not_quant_pattern list, will not be quantized
+        'not_quant_pattern': ['skip_quant'],
+        # ops of type in quantize_op_types, will be quantized
+        'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
+        # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
+        'dtype': 'int8',
+        # window size for 'range_abs_max' quantization. default is 10000
+        'window_size': 10000,
+        # The decay coefficient of moving average, default is 0.9
+        'moving_rate': 0.9,
+    }
+
     # 2. quantization transform programs (training aware)
     # Make some quantization transforms in the graph before training and testing.
     # According to the weight and activation quantization type, fake quantize
     # and fake dequantize operators will be inserted into the graph.
+    def pact(x):
+        helper = LayerHelper("pact", **locals())
+        dtype = 'float32'
+        # initialize the clipping threshold u from the collected statistics
+        init_thres = values[x.name.split('_tmp_input')[0]]
+        u_param_attr = fluid.ParamAttr(
+            name=x.name + '_pact',
+            initializer=fluid.initializer.ConstantInitializer(value=init_thres),
+            regularizer=fluid.regularizer.L2Decay(0.0001),
+            learning_rate=1)
+        u_param = helper.create_parameter(
+            attr=u_param_attr, shape=[1], dtype=dtype)
+
+        # clip x to [-u, u]: x - relu(x - u) + relu(-u - x)
+        part_a = fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param))
+        part_b = fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x))
+        x = x - part_a + part_b
+        return x
+
+    def get_optimizer():
+        return fluid.optimizer.MomentumOptimizer(args.lr, 0.9)
+
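+    # pact() is registered below as the activation preprocessing function, and
+    # get_optimizer() supplies the optimizer that trains the clipping
+    # thresholds alongside the model weights; both are handed to quant_aware()
+    # by the unchanged code below.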
     if args.use_pact:
         act_preprocess_func = pact
         optimizer_func = get_optimizer
@@ -205,28 +300,18 @@ def compress(args):
     fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
 
-    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
-    train_reader = paddle.fluid.io.batch(
-        train_reader, batch_size=args.batch_size, drop_last=True)
-
-    train_feeder = feeder = fluid.DataFeeder([image, label], place)
-    val_feeder = feeder = fluid.DataFeeder(
-        [image, label], place, program=val_program)
-
     def test(epoch, program):
         batch_id = 0
         acc_top1_ns = []
         acc_top5_ns = []
-        for data in val_reader():
+        for data in valid_loader():
             start_time = time.time()
             acc_top1_n, acc_top5_n = exe.run(
-                program,
-                feed=train_feeder.feed(data),
-                fetch_list=[acc_top1.name, acc_top5.name])
+                program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
             end_time = time.time()
             if batch_id % args.log_period == 0:
                 _logger.info(
-                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
+                    "Eval epoch[{}] batch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
                     format(epoch, batch_id,
                            np.mean(acc_top1_n),
                            np.mean(acc_top5_n), end_time - start_time))
@@ -234,30 +319,35 @@
             acc_top1_ns.append(np.mean(acc_top1_n))
             acc_top5_ns.append(np.mean(acc_top5_n))
             batch_id += 1
 
-        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
-            epoch,
-            np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
+        _logger.info(
+            "Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
+                epoch,
+                np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
         return np.mean(np.array(acc_top1_ns))
 
     def train(epoch, compiled_train_prog):
         batch_id = 0
-        for data in train_reader():
+        for data in train_loader():
             start_time = time.time()
-            loss_n, acc_top1_n, acc_top5_n = exe.run(
+            lr_n, loss_n, acc_top1_n, acc_top5_n = exe.run(
                 compiled_train_prog,
-                feed=train_feeder.feed(data),
-                fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
+                feed=data,
+                fetch_list=[
+                    learning_rate.name, avg_cost.name, acc_top1.name,
+                    acc_top5.name
+                ])
             end_time = time.time()
+            lr_n = np.mean(lr_n)
             loss_n = np.mean(loss_n)
             acc_top1_n = np.mean(acc_top1_n)
             acc_top5_n = np.mean(acc_top5_n)
             if batch_id % args.log_period == 0:
                 _logger.info(
-                    "epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
-                    format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
-                           end_time - start_time))
+                    "epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
+                    format(epoch, batch_id, lr_n, loss_n, acc_top1_n,
                           acc_top5_n, end_time - start_time))
 
             if args.use_pact and batch_id % 1000 == 0:
                 threshold = {}
@@ -266,15 +356,12 @@ def compress(args):
                 array = np.array(fluid.global_scope().find_var(var.name)
                                  .get_tensor())
                 threshold[var.name] = array[0]
-            print(threshold)
-
+            _logger.info(threshold)
             batch_id += 1
 
     build_strategy = fluid.BuildStrategy()
-    build_strategy.memory_optimize = False
     build_strategy.enable_inplace = False
     build_strategy.fuse_all_reduce_ops = False
-    build_strategy.sync_batch_norm = False
     exec_strategy = fluid.ExecutionStrategy()
     compiled_train_prog = compiled_train_prog.with_data_parallel(
         loss_name=avg_cost.name,
@@ -297,9 +384,16 @@ def compress(args):
         v = fluid.global_scope().find_var('@LR_DECAY_COUNTER@').get_tensor()
         v.set(np.array([start_step]).astype(np.float32), place)
 
+    best_eval_acc1 = 0
+    best_acc1_epoch = 0
     for i in range(start_epoch, args.num_epochs):
         train(i, compiled_train_prog)
         acc1 = test(i, val_program)
+        if acc1 > best_eval_acc1:
+            best_eval_acc1 = acc1
+            best_acc1_epoch = i
+        _logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format(
+            best_eval_acc1, best_acc1_epoch))
         fluid.io.save_persistables(
             exe,
             dirname=os.path.join(args.output_dir, str(i)),
@@ -311,25 +405,28 @@ def compress(args):
             exe,
             dirname=os.path.join(args.output_dir, 'best_model'),
             main_program=val_program)
+
     if os.path.exists(os.path.join(args.output_dir, 'best_model')):
         fluid.io.load_persistables(
             exe,
             dirname=os.path.join(args.output_dir, 'best_model'),
             main_program=val_program)
+
     # 3. Freeze the graph after training by adjusting the quantize
     # operators' order for the inference.
     # The dtype of float_program's weights is float32, but in int8 range.
     float_program, int8_program = convert(val_program, place, quant_config, \
                                           scope=None, \
                                           save_int8=True)
-    print("eval best_model after convert")
+    _logger.info("eval best_model after convert")
     final_acc1 = test(best_epoch, float_program)
+    _logger.info("final acc: {:.6f}".format(final_acc1))
+
+    # 4. Save inference model
     model_path = os.path.join(quantization_model_save_dir, args.model,
                               'act_' + quant_config['activation_quantize_type']
                               + '_w_' + quant_config['weight_quantize_type'])
     float_path = os.path.join(model_path, 'float')
-    int8_path = os.path.join(model_path, 'int8')
     if not os.path.isdir(model_path):
         os.makedirs(model_path)
@@ -342,15 +439,6 @@ def compress(args):
         model_filename=float_path + '/model',
         params_filename=float_path + '/params')
 
-    fluid.io.save_inference_model(
-        dirname=int8_path,
-        feeded_var_names=[image.name],
-        target_vars=[out],
-        executor=exe,
-        main_program=int8_program,
-        model_filename=int8_path + '/model',
-        params_filename=int8_path + '/params')
-
 
 def main():
     paddle.enable_static()
diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py
index 894d5d5a1a13d9fd1aa1fabcfc6e849df6fa17ca..2e9e660e981b6cb9e7fea9c709cb4e31d929c236 100644
--- a/paddleslim/common/__init__.py
+++ b/paddleslim/common/__init__.py
@@ -21,10 +21,10 @@ from .cached_reader import cached_reader
 from .server import Server
 from .client import Client
 from .meter import AvgrageMeter
-from .analyze_helper import pdf
+from .analyze_helper import VarCollector
 
 __all__ = [
     'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
     'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
-    'Server', 'Client', 'RLBaseController', 'pdf'
+    'Server', 'Client', 'RLBaseController', 'VarCollector'
 ]
diff --git a/paddleslim/common/analyze_helper.py b/paddleslim/common/analyze_helper.py
index d5883bb597d027ce2c836e098a45af1383684c64..09879a074407134a7f2e16dd90dff0ee819b4faf 100644
--- a/paddleslim/common/analyze_helper.py
+++ b/paddleslim/common/analyze_helper.py
@@ -12,116 +12,167 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import matplotlib
-matplotlib.use('Agg')
-import logging
-import numpy as np
-from matplotlib.backends.backend_pdf import PdfPages
-import matplotlib.pyplot as plt
 import os
-
+import types
 import paddle
 import paddle.fluid as fluid
+import numpy as np
+from collections import defaultdict
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+import logging
 from ..common import get_logger
 
 _logger = get_logger(__name__, level=logging.INFO)
 
 
-def pdf(program,
-        var_names,
-        executor=None,
-        batch_generator=None,
-        data_loader=None,
-        feed_vars=None,
-        fetch_list=None,
-        scope=None,
-        pdf_save_dir='tmp_pdf'):
-    """
-    Draw hist for distributtion of variables in that name is in var_names
-
-    Args:
-        program(fluid.Program): program to analyze.
-        var_names(list): name of variables to analyze. When there is activation name in var_names,
-                         you should set executor, one of batch_generator and data_loader, feed_list.
-        executor(fluid.Executor, optional): The executor to run program. Default is None.
-        batch_generator(Python Generator, optional): The batch generator provides calibrate data for DataLoader,
-                                                     and it returns a batch every time. For data_loader and batch_generator,
-                                                     only one can be set. Default is None.
-        data_loader(fluid.io.DataLoader, optional): The data_loader provides calibrate data to run program.
-                                                    Default is None.
-        feed_vars(list): feed variables for program. When you use batch_generator to provide data,
-                         you should set feed_vars. Default is None.
-        fetch_list(list): fetch list for program. Default is None.
-        scope(fluid.Scope, optional): The scope to run program, use it to load variables.
-                              If scope is None, will use fluid.global_scope().
-        pdf_save_dir(str): dirname to save pdf. Default is 'tmp_pdf'
-
-    Returns:
-        dict: numpy array of variables that name in var_names
-    """
-    scope = fluid.global_scope() if scope is None else scope
-    assert isinstance(var_names, list), 'var_names is a list of variable name'
-    real_names = []
-    weight_only = True
-    for var in program.list_vars():
-        if var.name in var_names:
-            if var.persistable == False:
-                weight_only = False
-                var.persistable = True
-            real_names.append(var.name)
-
-    if weight_only == False:
-        if batch_generator is not None:
-            assert feed_vars is not None, "When using batch_generator, feed_vars must be set"
-            dataloader = fluid.io.DataLoader.from_generator(
-                feed_list=feed_vars, capacity=512, iterable=True)
-            dataloader.set_batch_generator(batch_generator, executor.place)
-        elif data_loader is not None:
-            dataloader = data_loader
+class Averager(object):
+    # keeps a running mean of every registered value, with a per-name counter
+    def __init__(self):
+        self.shadow = {}
+        self.cnt = {}
+
+    def register(self, name, val):
+        self.shadow[name] = val
+        self.cnt[name] = 1
+
+    def get(self, name):
+        return self.shadow[name]
+
+    def record(self):
+        return self.shadow
+
+    def update(self, name, val):
+        assert name in self.shadow
+        new_average = (self.cnt[name] * self.shadow[name] + val) / (
+            self.cnt[name] + 1)
+        self.cnt[name] += 1
+        self.shadow[name] = new_average
+
+
+class EMA(Averager):
+    # exponential moving average with the given decay
+    def __init__(self, decay):
+        super(EMA, self).__init__()
+        self.decay = decay
+
+    def update(self, name, val):
+        assert name in self.shadow
+        new_average = (1.0 - self.decay) * val + self.decay * self.shadow[name]
+        self.shadow[name] = new_average
+
+
+class VarCollector(object):
+    def __init__(self,
+                 program,
+                 var_names,
+                 use_ema=False,
+                 ema_decay=0.999,
+                 scope=None):
+        self.program = program
+        self.var_names = var_names
+        self.scope = fluid.global_scope() if scope is None else scope
+        self.use_ema = use_ema
+        self.set_up()
+        if self.use_ema:
+            self.stats = EMA(decay=ema_decay)
         else:
-            _logger.info(
-                "When both batch_generator and data_loader is None, var_names can only include weight names"
-            )
-            return
-
-        assert executor is not None, "when var_names include activations'name, executor must be set"
-        assert fetch_list is not None, "when var_names include activations'name,, executor must be set"
-
-        for data in dataloader:
-            executor.run(program=program,
-                         feed=data,
-                         fetch_list=fetch_list,
-                         return_numpy=False)
-            break
-
-    res_np = {}
-    for name in real_names:
-        var = fluid.global_scope().find_var(name)
-        if var is not None:
-            res_np[name] = np.array(var.get_tensor())
+            self.stats = Averager()
+
+    def set_up(self):
+        self.real_names = []
+        if hasattr(self.program, '_program'):
+            program = self.program._program
         else:
-            _logger.info(
-                "can't find var {}. Maybe you should set one of batch_generator and data_loader".
-                format(name))
-    numbers = len(real_names)
-    if pdf_save_dir is not None:
-        if not os.path.exists(pdf_save_dir):
-            os.mkdir(pdf_save_dir)
-        pdf_path = os.path.join(pdf_save_dir, 'result.pdf')
-        with PdfPages(pdf_path) as pdf:
-            idx = 1
-            for name in res_np.keys():
-                if idx % 10 == 0:
-                    _logger.info("plt {}/{}".format(idx, numbers))
-                arr = res_np[name]
-                arr = arr.flatten()
-                weights = np.ones_like(arr) / len(arr)
-                plt.hist(arr, bins=1000, weights=weights)
-                plt.xlabel(name)
-                plt.ylabel("frequency")
-                plt.title("Hist of variable {}".format(name))
-                plt.show()
-                pdf.savefig()
-                plt.close()
-                idx += 1
-    return res_np
+            program = self.program
+
+        for var in program.list_vars():
+            if var.name in self.var_names:
+                self.real_names.append(var.name)
+
+    def update(self, vars_np):
+        for name in self.real_names:
+            val = vars_np.get(name)
+            if val is not None:
+                # the first value registers the variable; later values
+                # update the running statistic
+                if name in self.stats.shadow:
+                    self.stats.update(name, val)
+                else:
+                    self.stats.register(name, val)
+            else:
+                _logger.info("can't find var {}.".format(name))
+        return self.stats.record()
+
+    def run(self, reader, exe, step=None, loss_name=None):
+        if not hasattr(self.program, '_program'):
+            # Compile the native program to speed up
+            program = fluid.CompiledProgram(self.program).with_data_parallel(
+                loss_name=loss_name)
+        else:
+            program = self.program
+
+        for idx, data in enumerate(reader):
+            vars_np = exe.run(program=program,
+                              feed=data,
+                              fetch_list=self.real_names)
+            mapped_vars_np = dict(zip(self.real_names, vars_np))
+            values = self.update(mapped_vars_np)
+
+            if idx % 10 == 0:
+                _logger.info("Collecting..., Step: {}".format(idx))
+            if step is not None and idx + 1 >= step:
+                break
+        return values
+
+    def abs_max_run(self, reader, exe, step=None, loss_name=None):
+        fetch_list = []
+        with fluid.program_guard(self.program):
+            for act_name in self.real_names:
+                act = self.program.global_block().var(act_name)
+                act = fluid.layers.reduce_max(
+                    fluid.layers.abs(act), name=act_name + "_reduced")
+                fetch_list.append(act.name)
+
+        if not hasattr(self.program, '_program'):
+            # Compile the native program to speed up
+            program = fluid.CompiledProgram(self.program).with_data_parallel(
+                loss_name=loss_name)
+        else:
+            program = self.program
+        for idx, data in enumerate(reader):
+            vars_np = exe.run(program=program, feed=data, fetch_list=fetch_list)
+            vars_np = [np.max(var) for var in vars_np]
+            mapped_vars_np = dict(zip(self.real_names, vars_np))
+            values = self.update(mapped_vars_np)
+
+            if idx % 10 == 0:
+                _logger.info("Collecting..., Step: {}".format(idx))
+
+            if step is not None and idx + 1 >= step:
+                break
+        return values
+
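+    # Typical VarCollector usage, as in the PACT demo above: collect
+    # EMA-smoothed abs-max values of selected activations and save them as
+    # initial PACT thresholds:
+    #     collector = VarCollector(train_prog, act_names, use_ema=True)
+    #     thres = collector.abs_max_run(loader, exe, loss_name=loss.name)
+    #     np.save('pact_thres.npy', thres)
+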
+    @staticmethod
+    def pdf(var_dist, save_dir='dist_pdf'):
+        """
+        Draw a histogram of the distribution of each variable in var_dist.
+
+        Args:
+            var_dist(dict): dict mapping variable names to numpy arrays.
+            save_dir(str): directory in which to save the pdf. Default is 'dist_pdf'.
+        """
+        numbers = len(var_dist)
+        if save_dir is not None:
+            if not os.path.exists(save_dir):
+                os.mkdir(save_dir)
+            pdf_path = os.path.join(save_dir, 'result.pdf')
+            with PdfPages(pdf_path) as pdf:
+                for i, name in enumerate(var_dist.keys()):
+                    if i % 10 == 0:
+                        _logger.info("plt {}/{}".format(i, numbers))
+                    arr = var_dist[name]
+                    arr = arr.flatten()
+                    weights = np.ones_like(arr) / len(arr)
+                    plt.hist(arr, bins=1000, weights=weights)
+                    plt.xlabel(name)
+                    plt.ylabel("frequency")
+                    plt.title("Hist of variable {}".format(name))
+                    plt.show()
+                    pdf.savefig()
+                    plt.close()
+            _logger.info("Variable histograms have been saved to {}".format(
+                pdf_path))
diff --git a/tests/test_analysis_helper.py b/tests/test_analysis_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..abce65f699dea789fc5f23af0a352f834870b0e0
--- /dev/null
+++ b/tests/test_analysis_helper.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+sys.path.append("../")
+import unittest
+import paddle
+import paddle.fluid as fluid
+from paddleslim.common import VarCollector
+from static_case import StaticCase
+sys.path.append("../demo")
+from models import MobileNet
+from layers import conv_bn_layer
+import paddle.dataset.mnist as reader
+import numpy as np
+
+
+class TestAnalysisHelper(StaticCase):
+    def test_analysis_helper(self):
+        image = fluid.layers.data(
+            name='image', shape=[1, 28, 28], dtype='float32')
+        label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+        model = MobileNet()
+        out = model.net(input=image, class_dim=10)
+        cost = fluid.layers.cross_entropy(input=out, label=label)
+        avg_cost = fluid.layers.mean(x=cost)
+        acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
+        acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
+        optimizer = fluid.optimizer.Momentum(
+            momentum=0.9,
+            learning_rate=0.01,
+            regularization=fluid.regularizer.L2Decay(4e-5))
+        optimizer.minimize(avg_cost)
+        main_prog = fluid.default_main_program()
+
+        places = fluid.cuda_places() if fluid.is_compiled_with_cuda(
+        ) else fluid.cpu_places()
+        exe = fluid.Executor(places[0])
+        train_reader = paddle.fluid.io.batch(
+            paddle.dataset.mnist.train(), batch_size=64)
+        train_loader = fluid.io.DataLoader.from_generator(
+            feed_list=[image, label],
+            capacity=512,
+            use_double_buffer=True,
+            iterable=True)
+        train_loader.set_sample_list_generator(train_reader, places)
+        exe.run(fluid.default_startup_program())
+
+        vars = ['conv2d_0.tmp_0', 'fc_0.tmp_0', 'fc_0.tmp_1', 'fc_0.tmp_2']
+        var_collector1 = VarCollector(main_prog, vars, use_ema=True)
+        values = var_collector1.abs_max_run(
+            train_loader, exe, step=None, loss_name=avg_cost.name)
+        vars = [v.name for v in main_prog.list_vars() if v.persistable]
+        var_collector2 = VarCollector(main_prog, vars, use_ema=False)
+        values = var_collector2.run(train_loader,
+                                    exe,
+                                    step=None,
+                                    loss_name=avg_cost.name)
+        var_collector2.pdf(values)
+
+
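+# The test exercises both statistics modes of VarCollector: EMA-smoothed
+# abs-max collection for activations and plain averaging over persistable
+# variables, then renders the collected distributions with pdf().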
+if __name__ == '__main__':
+    unittest.main()