diff --git a/demo/quant/pact_quant_aware/pact.py b/demo/quant/pact_quant_aware/pact.py deleted file mode 100644 index 26a2a5efd6e9b819db9b7134a62a1ac8c1fc296f..0000000000000000000000000000000000000000 --- a/demo/quant/pact_quant_aware/pact.py +++ /dev/null @@ -1,30 +0,0 @@ -import sys -import paddle -import paddle.fluid as fluid -from paddleslim.quant import quant_aware, convert -import numpy as np - -from paddle.fluid.layer_helper import LayerHelper - - -def pact(x, name=None): - helper = LayerHelper("pact", **locals()) - dtype = 'float32' - init_thres = 20 - u_param_attr = fluid.ParamAttr( - name=x.name + '_pact', - initializer=fluid.initializer.ConstantInitializer(value=init_thres), - regularizer=fluid.regularizer.L2Decay(0.0001), - learning_rate=1) - u_param = helper.create_parameter( - attr=u_param_attr, shape=[1], dtype=dtype) - x = fluid.layers.elementwise_sub( - x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param))) - x = fluid.layers.elementwise_add( - x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x))) - - return x - - -def get_optimizer(): - return fluid.optimizer.MomentumOptimizer(0.0001, 0.9) diff --git a/demo/quant/pact_quant_aware/train.py b/demo/quant/pact_quant_aware/train.py index 8eae690a3281ecfe51f0d4ee2d481a90ffa51f40..2e502f25ef4165e7e613350b2f4476c01d6e1f93 100644 --- a/demo/quant/pact_quant_aware/train.py +++ b/demo/quant/pact_quant_aware/train.py @@ -10,13 +10,14 @@ import numpy as np import paddle.fluid as fluid sys.path[0] = os.path.join( os.path.dirname("__file__"), os.path.pardir, os.path.pardir) -from paddleslim.common import get_logger +from paddleslim.common import get_logger, get_distribution, pdf from paddleslim.analysis import flops from paddleslim.quant import quant_aware, quant_post, convert +from paddleslim.quant import pact_thres import models from utility import add_arguments, print_arguments -sys.path.append('./') -from pact import * + +from paddle.fluid.layer_helper import LayerHelper 
quantization_model_save_dir = './quantization_models/' _logger = get_logger(__name__, level=logging.INFO) @@ -158,11 +159,63 @@ def compress(args): exe = fluid.Executor(place) exe.run(fluid.default_startup_program()) + val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size) + train_reader = paddle.fluid.io.batch( + train_reader, batch_size=args.batch_size, drop_last=True) + train_loader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], + capacity=512, + use_double_buffer=True, + iterable=True) + valid_loader = fluid.io.DataLoader.from_generator( + feed_list=[image, label], + capacity=512, + use_double_buffer=True, + iterable=True) + places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places() + train_loader.set_sample_list_generator(train_reader, place) + valid_loader.set_sample_list_generator(val_reader, place) + + # get all activations distribution + act_names = [ + var.name for var in list(train_prog.list_vars()) + if not var.persistable and 'generated_var' not in var.name and + '@GRAD' not in var.name + ] + var_dist = get_distribution(train_prog, act_names, exe, train_loader) + train_loader.set_sample_list_generator(train_reader, places) + + # draw histogram + pdf(var_dist, pdf_save_dir='var_dist_pdf') + + # calculate appropriate pact clip threshold + pact_alphas = pact_thres(var_dist) + # 2. quantization transform programs (training aware) # Make some quantization transforms in the graph before training and testing. # According to the weight and activation quantization type, the graph will be added # some fake quantize operators and fake dequantize operators. 
+ def pact(x): + helper = LayerHelper("pact", **locals()) + dtype = 'float32' + init_thres = pact_alphas[x.name.split('_tmp_input')[0]] + u_param_attr = fluid.ParamAttr( + name=x.name + '_pact', + initializer=fluid.initializer.ConstantInitializer(value=init_thres), + regularizer=fluid.regularizer.L2Decay(0.0001), + learning_rate=1) + u_param = helper.create_parameter( + attr=u_param_attr, shape=[1], dtype=dtype) + x = fluid.layers.elementwise_sub( + x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param))) + x = fluid.layers.elementwise_add( + x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x))) + return x + + def get_optimizer(): + return fluid.optimizer.MomentumOptimizer(0.0001, 0.9) + if args.use_pact: act_preprocess_func = pact optimizer_func = get_optimizer @@ -201,25 +254,6 @@ def compress(args): fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist) - val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size) - train_reader = paddle.fluid.io.batch( - train_reader, batch_size=args.batch_size, drop_last=True) - - train_loader = fluid.io.DataLoader.from_generator( - feed_list=[image, label], - capacity=512, - use_double_buffer=True, - iterable=True) - valid_loader = fluid.io.DataLoader.from_generator( - feed_list=[image, label], - capacity=512, - use_double_buffer=True, - iterable=True) - - places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places() - train_loader.set_sample_list_generator(train_reader, places) - valid_loader.set_sample_list_generator(val_reader, place) - def test(epoch, program): batch_id = 0 acc_top1_ns = [] @@ -270,8 +304,7 @@ def compress(args): array = np.array(fluid.global_scope().find_var(var.name) .get_tensor()) threshold[var.name] = array[0] - print(threshold) - + _logger.info(threshold) batch_id += 1 build_strategy = fluid.BuildStrategy() @@ -307,6 +340,7 @@ def compress(args): exe, dirname=os.path.join(args.checkpoint_dir, 'best_model'), main_program=val_program) + # 3. 
Freeze the graph after training by adjusting the quantize # operators' order for the inference. # The dtype of float_program's weights is float32, but in int8 range. @@ -315,6 +349,8 @@ def compress(args): save_int8=True) print("eval best_model after convert") final_acc1 = test(best_epoch, float_program) + _logger.info("final acc:{}".format(final_acc1)) + # 4. Save inference model model_path = os.path.join(quantization_model_save_dir, args.model, 'act_' + quant_config['activation_quantize_type'] diff --git a/paddleslim/common/__init__.py b/paddleslim/common/__init__.py index 894d5d5a1a13d9fd1aa1fabcfc6e849df6fa17ca..274781b1a6eb6dec55086667d70fe30e511d9c7d 100644 --- a/paddleslim/common/__init__.py +++ b/paddleslim/common/__init__.py @@ -21,10 +21,10 @@ from .cached_reader import cached_reader from .server import Server from .client import Client from .meter import AvgrageMeter -from .analyze_helper import pdf +from .analyze_helper import pdf, get_distribution __all__ = [ 'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer', 'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter', - 'Server', 'Client', 'RLBaseController', 'pdf' + 'Server', 'Client', 'RLBaseController', 'pdf', 'get_distribution' ] diff --git a/paddleslim/common/analyze_helper.py b/paddleslim/common/analyze_helper.py index d5883bb597d027ce2c836e098a45af1383684c64..6ddc48801402d09f8b272906a87115bf9b5f88b4 100644 --- a/paddleslim/common/analyze_helper.py +++ b/paddleslim/common/analyze_helper.py @@ -12,55 +12,51 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import matplotlib -matplotlib.use('Agg') -import logging -import numpy as np -from matplotlib.backends.backend_pdf import PdfPages -import matplotlib.pyplot as plt import os - +import types import paddle import paddle.fluid as fluid +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from matplotlib.backends.backend_pdf import PdfPages +import logging from ..common import get_logger _logger = get_logger(__name__, level=logging.INFO) -def pdf(program, - var_names, - executor=None, - batch_generator=None, - data_loader=None, - feed_vars=None, - fetch_list=None, - scope=None, - pdf_save_dir='tmp_pdf'): +def get_distribution(program, + var_names, + executor, + reader=None, + feed_vars=None, + scope=None): """ - Draw hist for distributtion of variables in that name is in var_names - + Get the variables distribution in the var_names list + Args: program(fluid.Program): program to analyze. - var_names(list): name of variables to analyze. When there is activation name in var_names, - you should set executor, one of batch_generator and data_loader, feed_list. + var_names(list): name of variables to analyze. When there is activation name in var_names, + you should set executor. executor(fluid.Executor, optional): The executor to run program. Default is None. - batch_generator(Python Generator, optional): The batch generator provides calibrate data for DataLoader, - and it returns a batch every time. For data_loader and batch_generator, - only one can be set. Default is None. - data_loader(fluid.io.DataLoader, optional): The data_loader provides calibrate data to run program. - Default is None. - feed_vars(list): feed variables for program. When you use batch_generator to provide data, + reader(Python Generator, fluid.io.DataLoader, optional): If you only want to get the distribution of weight parameters, + you do not need to provide a reader. Otherwise, a reader must be provided. 
The reader provides calibration data, + and it returns a batch every time. It must be either a python generator or an iterable fluid dataloader. + When you use a python generator, please ensure that its behavior is consistent with `batch_generator`. + You can get more detail about batch_generator at https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/io_cn/DataLoader_cn.html#id1 + feed_vars(list): feed variables for program. When you use python generator reader to provide data, you should set feed_vars. Default is None. - fetch_list(list): fetch list for program. Default is None. - scope(fluid.Scope, optional): The scope to run program, use it to load variables. + scope(fluid.Scope, optional): The scope to run program, use it to load variables. If scope is None, will use fluid.global_scope(). - pdf_save_dir(str): dirname to save pdf. Default is 'tmp_pdf' - + Returns: - dict: numpy array of variables that name in var_names + dict: numpy array of variables distribution that name in var_names """ scope = fluid.global_scope() if scope is None else scope assert isinstance(var_names, list), 'var_names is a list of variable name' + var_changed = [] real_names = [] weight_only = True for var in program.list_vars(): @@ -68,52 +64,70 @@ def pdf(program, if var.persistable == False: weight_only = False var.persistable = True + var_changed.append(var) real_names.append(var.name) - if weight_only == False: - if batch_generator is not None: + def update_var_dist(var_dist): + for name in real_names: + var = scope.find_var(name) + if var is not None: + var_array = np.array(var.get_tensor()) + var_dist[name] = var_array + else: + _logger.info("can't find var {} in scope.".format(name)) + return var_dist + + var_dist = {} + if weight_only: + var_dist = update_var_dist(var_dist) + else: + assert isinstance(reader, types.GeneratorType) or isinstance( + reader, fluid.reader.DataLoaderBase + ), "when var_names include activations'name, reader must be either a python generator or a 
fluid dataloader." + assert executor is not None, "when var_names include activations'name, executor must be set" + + if isinstance(reader, types.GeneratorType): assert feed_vars is not None, "When using batch_generator, feed_vars must be set" dataloader = fluid.io.DataLoader.from_generator( - feed_list=feed_vars, capacity=512, iterable=True) - dataloader.set_batch_generator(batch_generator, executor.place) - elif data_loader is not None: - dataloader = data_loader + feed_list=feed_vars, capacity=128, iterable=True) + dataloader.set_batch_generator(reader, executor.place) + elif isinstance(reader, fluid.reader.DataLoaderBase): + dataloader = reader else: _logger.info( "When both batch_generator and data_loader is None, var_names can only include weight names" ) return - assert executor is not None, "when var_names include activations'name, executor must be set" - assert fetch_list is not None, "when var_names include activations'name,, executor must be set" - for data in dataloader: - executor.run(program=program, - feed=data, - fetch_list=fetch_list, - return_numpy=False) + executor.run(program=program, feed=data) + var_dist = update_var_dist(var_dist) break - res_np = {} - for name in real_names: - var = fluid.global_scope().find_var(name) - if var is not None: - res_np[name] = np.array(var.get_tensor()) - else: - _logger.info( - "can't find var {}. Maybe you should set one of batch_generator and data_loader". - format(name)) - numbers = len(real_names) + for var in var_changed: + var.persistable = False + + return var_dist + + +def pdf(var_dist, pdf_save_dir='var_dist_pdf'): + """ + Draw histograms of the distributions of the variables in var_dist. + + Args: + var_dist(dict): numpy array of variables distribution. + pdf_save_dir(str): dirname to save pdf. 
Default is 'var_dist_pdf' + """ + numbers = len(var_dist) if pdf_save_dir is not None: if not os.path.exists(pdf_save_dir): os.mkdir(pdf_save_dir) pdf_path = os.path.join(pdf_save_dir, 'result.pdf') with PdfPages(pdf_path) as pdf: - idx = 1 - for name in res_np.keys(): - if idx % 10 == 0: - _logger.info("plt {}/{}".format(idx, numbers)) - arr = res_np[name] + for i, name in enumerate(var_dist.keys()): + if i % 10 == 0: + _logger.info("plt {}/{}".format(i, numbers)) + arr = var_dist[name] arr = arr.flatten() weights = np.ones_like(arr) / len(arr) plt.hist(arr, bins=1000, weights=weights) @@ -123,5 +137,4 @@ def pdf(program, plt.show() pdf.savefig() plt.close() - idx += 1 - return res_np + _logger.info("variables histogram have been saved as {}".format(pdf_path)) diff --git a/paddleslim/quant/__init__.py b/paddleslim/quant/__init__.py index 38ca531c57b44bb29730f529e990a1547818d147..d00fbd6fa609c551fde86d392597bd1d83f0ced4 100644 --- a/paddleslim/quant/__init__.py +++ b/paddleslim/quant/__init__.py @@ -29,3 +29,4 @@ except Exception as e: "please use Paddle >= 2.0.0 or develop version") from .quant_embedding import quant_embedding +from .utility import pact_thres \ No newline at end of file diff --git a/paddleslim/quant/utility.py b/paddleslim/quant/utility.py new file mode 100755 index 0000000000000000000000000000000000000000..bed8eddbb73ffa8a729abe5f88f83503583e1a04 --- /dev/null +++ b/paddleslim/quant/utility.py @@ -0,0 +1,27 @@ +import logging +import numpy as np +from ..common import get_logger +_logger = get_logger(__name__, level=logging.INFO) + + +def pact_thres(var_dist, q=100): + """ + Compute the qth percentile threshold of the data in var_dist. + + Args: + var_dist(dict): numpy array of variables distribution. + q(float): Percentile to compute which must be between 0 and 100 inclusive. Default is 100. + + Returns: + dict: the qth percentile of the array element in var_dist. 
+ """ + var_percentile = {} + for var_name in var_dist.keys(): + var = var_dist[var_name] + var = var.flatten() + var = np.abs(var) + try: + var_percentile[var_name] = np.percentile(var, q) + except: + _logger.info('{} is empty in this program'.format(var_name)) + return var_percentile