Commit 6992fd6c authored by baiyfbupt

Add automatic calculation of the PACT clip threshold

Parent 271fbf44
import sys
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
    """Clip activation x into [-u_param, u_param] with a learnable threshold."""
    helper = LayerHelper("pact", **locals())
    dtype = 'float32'
    # Fixed initial clip threshold; this commit replaces it with a
    # per-activation value computed from the activation distribution.
    init_thres = 20
    u_param_attr = fluid.ParamAttr(
        name=x.name + '_pact',
        initializer=fluid.initializer.ConstantInitializer(value=init_thres),
        regularizer=fluid.regularizer.L2Decay(0.0001),
        learning_rate=1)
    u_param = helper.create_parameter(
        attr=u_param_attr, shape=[1], dtype=dtype)
    # x - relu(x - u) caps values above u; adding relu(-u - x) floors values at -u.
    x = fluid.layers.elementwise_sub(
        x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
    x = fluid.layers.elementwise_add(
        x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
    return x


def get_optimizer():
    return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
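
The two elementwise steps implement a symmetric clip with a learnable bound: subtracting relu(x - u) caps values above u, and adding relu(-u - x) lifts values below -u. A minimal NumPy sketch (illustrative only, not part of the commit) checks that the composition equals np.clip(x, -u, u):

import numpy as np

def relu(v):
    return np.maximum(v, 0.0)

def pact_clip(x, u):
    # Same algebra as the fluid ops above: cap at u, then floor at -u.
    x = x - relu(x - u)
    x = x + relu(-u - x)
    return x

x = np.random.randn(1000) * 30
assert np.allclose(pact_clip(x, 20.0), np.clip(x, -20.0, 20.0))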
@@ -10,13 +10,14 @@ import numpy as np
import paddle.fluid as fluid
sys.path[0] = os.path.join(
    os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
-from paddleslim.common import get_logger
+from paddleslim.common import get_logger, get_distribution, pdf
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
+from paddleslim.quant import pact_thres
import models
from utility import add_arguments, print_arguments
-sys.path.append('./')
-from pact import *
+from paddle.fluid.layer_helper import LayerHelper

quantization_model_save_dir = './quantization_models/'
_logger = get_logger(__name__, level=logging.INFO)
@@ -158,11 +159,63 @@ def compress(args):
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

+    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
+    train_reader = paddle.fluid.io.batch(
+        train_reader, batch_size=args.batch_size, drop_last=True)
+
+    train_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[image, label],
+        capacity=512,
+        use_double_buffer=True,
+        iterable=True)
+    valid_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[image, label],
+        capacity=512,
+        use_double_buffer=True,
+        iterable=True)
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    train_loader.set_sample_list_generator(train_reader, place)
+    valid_loader.set_sample_list_generator(val_reader, place)
+
+    # get the distribution of all activations
+    act_names = [
+        var.name for var in list(train_prog.list_vars())
+        if not var.persistable and 'generated_var' not in var.name and
+        '@GRAD' not in var.name
+    ]
+    var_dist = get_distribution(train_prog, act_names, exe, train_loader)
+    # rebind the loader to all places for the actual (possibly multi-card) training
+    train_loader.set_sample_list_generator(train_reader, places)
+    # draw a histogram of each activation's distribution
+    pdf(var_dist, pdf_save_dir='var_dist_pdf')
+    # calculate an appropriate PACT clip threshold for each activation
+    pact_alphas = pact_thres(var_dist)
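
Here var_dist maps each activation name to a NumPy array sampled from one calibration batch, and pact_alphas maps those names to scalar clip thresholds: the qth percentile of the absolute values, with q defaulting to 100. A hedged sketch of the shapes involved (the activation name is illustrative, not from the commit):

# Illustrative only; real keys depend on the model graph.
# var_dist    == {'conv2d_0.tmp_0': <np.ndarray of activation samples>, ...}
# pact_alphas == {'conv2d_0.tmp_0': 5.73, ...}
pact_alphas = pact_thres(var_dist, q=99.9)  # a smaller q ignores rare outliers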
    # 2. quantization transform programs (training aware)
    # Make some quantization transforms in the graph before training and testing.
    # According to the weight and activation quantization type, the graph will be added
    # some fake quantize operators and fake dequantize operators.
+    def pact(x):
+        helper = LayerHelper("pact", **locals())
+        dtype = 'float32'
+        # Use the pre-computed per-activation threshold instead of a fixed value.
+        # The quantization pass hands this function a variable whose name is the
+        # original activation name with a '_tmp_input' suffix, so strip it to
+        # look up the threshold.
+        init_thres = pact_alphas[x.name.split('_tmp_input')[0]]
+        u_param_attr = fluid.ParamAttr(
+            name=x.name + '_pact',
+            initializer=fluid.initializer.ConstantInitializer(value=init_thres),
+            regularizer=fluid.regularizer.L2Decay(0.0001),
+            learning_rate=1)
+        u_param = helper.create_parameter(
+            attr=u_param_attr, shape=[1], dtype=dtype)
+        x = fluid.layers.elementwise_sub(
+            x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
+        x = fluid.layers.elementwise_add(
+            x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
+        return x
+
+    def get_optimizer():
+        return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
    if args.use_pact:
        act_preprocess_func = pact
        optimizer_func = get_optimizer
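
These two hooks are later handed to quant_aware, which wraps every activation with pact() before its fake-quantize op and uses get_optimizer() to train the new u_param thresholds. A minimal sketch of such a call, assuming the act_preprocess_func, optimizer_func, and executor keyword arguments of paddleslim.quant.quant_aware (the actual call site is outside this hunk):

# Sketch; keyword names assumed from the PaddleSlim quant_aware API.
quant_train_prog = quant_aware(
    train_prog, place, quant_config,
    act_preprocess_func=act_preprocess_func,  # inserts pact() before each activation quantizer
    optimizer_func=optimizer_func,            # optimizer for the learnable clip thresholds
    executor=exe,
    for_test=False)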
@@ -201,25 +254,6 @@ def compress(args):
    fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)

-    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
-    train_reader = paddle.fluid.io.batch(
-        train_reader, batch_size=args.batch_size, drop_last=True)
-
-    train_loader = fluid.io.DataLoader.from_generator(
-        feed_list=[image, label],
-        capacity=512,
-        use_double_buffer=True,
-        iterable=True)
-    valid_loader = fluid.io.DataLoader.from_generator(
-        feed_list=[image, label],
-        capacity=512,
-        use_double_buffer=True,
-        iterable=True)
-    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
-    train_loader.set_sample_list_generator(train_reader, places)
-    valid_loader.set_sample_list_generator(val_reader, place)

    def test(epoch, program):
        batch_id = 0
        acc_top1_ns = []
@@ -270,8 +304,7 @@ def compress(args):
                array = np.array(fluid.global_scope().find_var(var.name)
                                 .get_tensor())
                threshold[var.name] = array[0]
-                print(threshold)
+                _logger.info(threshold)
            batch_id += 1

    build_strategy = fluid.BuildStrategy()
@@ -307,6 +340,7 @@ def compress(args):
            exe,
            dirname=os.path.join(args.checkpoint_dir, 'best_model'),
            main_program=val_program)

    # 3. Freeze the graph after training by adjusting the quantize
    # operators' order for the inference.
    # The dtype of float_program's weights is float32, but in int8 range.
@@ -315,6 +349,8 @@ def compress(args):
        save_int8=True)
    print("eval best_model after convert")
    final_acc1 = test(best_epoch, float_program)
+    _logger.info("final acc:{}".format(final_acc1))

    # 4. Save inference model
    model_path = os.path.join(quantization_model_save_dir, args.model,
                              'act_' + quant_config['activation_quantize_type']
...
@@ -21,10 +21,10 @@ from .cached_reader import cached_reader
from .server import Server
from .client import Client
from .meter import AvgrageMeter
-from .analyze_helper import pdf
+from .analyze_helper import pdf, get_distribution

__all__ = [
    'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
    'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
-    'Server', 'Client', 'RLBaseController', 'pdf'
+    'Server', 'Client', 'RLBaseController', 'pdf', 'get_distribution'
]
@@ -12,55 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-import matplotlib
-matplotlib.use('Agg')
-import logging
-import numpy as np
-from matplotlib.backends.backend_pdf import PdfPages
-import matplotlib.pyplot as plt
import os
+import types
import paddle
import paddle.fluid as fluid
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from matplotlib.backends.backend_pdf import PdfPages
+import logging

from ..common import get_logger

_logger = get_logger(__name__, level=logging.INFO)
-def pdf(program,
-        var_names,
-        executor=None,
-        batch_generator=None,
-        data_loader=None,
-        feed_vars=None,
-        fetch_list=None,
-        scope=None,
-        pdf_save_dir='tmp_pdf'):
+def get_distribution(program,
+                     var_names,
+                     executor,
+                     reader=None,
+                     feed_vars=None,
+                     scope=None):
    """
-    Draw a histogram of the distribution of every variable whose name is in var_names.
+    Get the distribution of every variable whose name is in var_names.
    Args:
        program(fluid.Program): program to analyze.
        var_names(list): names of the variables to analyze. When var_names contains
-            an activation name, you should set executor, one of batch_generator and
-            data_loader, and feed_vars.
+            an activation name, you should set executor.
        executor(fluid.Executor, optional): The executor to run program. Default is None.
-        batch_generator(Python Generator, optional): The batch generator provides
-            calibration data for the DataLoader and returns a batch every time. Only
-            one of data_loader and batch_generator can be set. Default is None.
-        data_loader(fluid.io.DataLoader, optional): The data_loader provides calibration
-            data to run the program. Default is None.
+        reader(Python Generator, fluid.io.DataLoader, optional): If you only want the
+            distribution of weight parameters, you do not need to provide a reader.
+            Otherwise a reader must be provided; it supplies calibration data and
+            returns a batch every time. It must be either a python generator or an
+            iterable fluid dataloader. When you use a python generator, please ensure
+            its behavior is consistent with `batch_generator`; see
+            https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/io_cn/DataLoader_cn.html#id1
+            for details.
-        feed_vars(list): feed variables for program. When you use batch_generator to
-            provide data, you should set feed_vars. Default is None.
+        feed_vars(list): feed variables for program. When you use a python generator
+            reader to provide data, you should set feed_vars. Default is None.
-        fetch_list(list): fetch list for program. Default is None.
        scope(fluid.Scope, optional): The scope to run program, use it to load variables.
            If scope is None, will use fluid.global_scope().
-        pdf_save_dir(str): dirname to save the pdf. Default is 'tmp_pdf'.
    Returns:
-        dict: numpy arrays of the variables whose name is in var_names
+        dict: numpy arrays of the distributions of the variables whose name is in var_names
    """
    scope = fluid.global_scope() if scope is None else scope
    assert isinstance(var_names, list), 'var_names is a list of variable names'
+    var_changed = []
    real_names = []
    weight_only = True
    for var in program.list_vars():
@@ -68,52 +64,70 @@ def pdf(program,
        if var.persistable == False:
            weight_only = False
            var.persistable = True
+            var_changed.append(var)
            real_names.append(var.name)

-    if weight_only == False:
-        if batch_generator is not None:
+    def update_var_dist(var_dist):
+        # Read the current value of every tracked variable out of the scope.
+        for name in real_names:
+            var = scope.find_var(name)
+            if var is not None:
+                var_array = np.array(var.get_tensor())
+                var_dist[name] = var_array
+            else:
+                _logger.info("can't find var {} in scope.".format(name))
+        return var_dist
+
+    var_dist = {}
+    if weight_only:
+        var_dist = update_var_dist(var_dist)
+    else:
+        assert isinstance(reader, types.GeneratorType) or isinstance(
+            reader, fluid.reader.DataLoaderBase
+        ), "when var_names includes activation names, reader must be either a python generator or a fluid dataloader."
+        assert executor is not None, "when var_names includes activation names, executor must be set"
+        if isinstance(reader, types.GeneratorType):
            assert feed_vars is not None, "When using batch_generator, feed_vars must be set"
            dataloader = fluid.io.DataLoader.from_generator(
-                feed_list=feed_vars, capacity=512, iterable=True)
-            dataloader.set_batch_generator(batch_generator, executor.place)
-        elif data_loader is not None:
-            dataloader = data_loader
+                feed_list=feed_vars, capacity=128, iterable=True)
+            dataloader.set_batch_generator(reader, executor.place)
+        elif isinstance(reader, fluid.reader.DataLoaderBase):
+            dataloader = reader
        else:
            _logger.info(
                "When both batch_generator and data_loader is None, var_names can only include weight names"
            )
            return
-        assert executor is not None, "when var_names includes activation names, executor must be set"
-        assert fetch_list is not None, "when var_names includes activation names, fetch_list must be set"
        for data in dataloader:
-            executor.run(program=program,
-                         feed=data,
-                         fetch_list=fetch_list,
-                         return_numpy=False)
+            # A single forward pass is enough to materialize every activation.
+            executor.run(program=program, feed=data)
+            var_dist = update_var_dist(var_dist)
            break
-    res_np = {}
-    for name in real_names:
-        var = fluid.global_scope().find_var(name)
-        if var is not None:
-            res_np[name] = np.array(var.get_tensor())
-        else:
-            _logger.info(
-                "can't find var {}. Maybe you should set one of batch_generator and data_loader".
-                format(name))
+    # Restore the persistable flags that were flipped above.
+    for var in var_changed:
+        var.persistable = False
+
+    return var_dist
+
+
+def pdf(var_dist, pdf_save_dir='var_dist_pdf'):
+    """
+    Draw a histogram of the distribution of every variable in var_dist.
+    Args:
+        var_dist(dict): numpy arrays of the variables' distributions.
+        pdf_save_dir(str): dirname to save the pdf. Default is 'var_dist_pdf'.
+    """
-    numbers = len(real_names)
+    numbers = len(var_dist)
    if pdf_save_dir is not None:
        if not os.path.exists(pdf_save_dir):
            os.mkdir(pdf_save_dir)
        pdf_path = os.path.join(pdf_save_dir, 'result.pdf')
        with PdfPages(pdf_path) as pdf:
-            idx = 1
-            for name in res_np.keys():
-                if idx % 10 == 0:
-                    _logger.info("plt {}/{}".format(idx, numbers))
-                arr = res_np[name]
+            for i, name in enumerate(var_dist.keys()):
+                if i % 10 == 0:
+                    _logger.info("plt {}/{}".format(i, numbers))
+                arr = var_dist[name]
                arr = arr.flatten()
                weights = np.ones_like(arr) / len(arr)
                plt.hist(arr, bins=1000, weights=weights)
@@ -123,5 +137,4 @@ def pdf(program,
                plt.show()
                pdf.savefig()
                plt.close()
-                idx += 1
-    return res_np
+        _logger.info("variable histograms have been saved to {}".format(pdf_path))
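
Taken together, get_distribution, pdf, and pact_thres form a three-step pipeline: sample activations from one calibration batch, optionally inspect their histograms, then derive per-activation clip thresholds. A minimal wiring sketch, assuming program, exe, and a data loader already exist in the caller's setup:

# Illustrative wiring of the new helpers.
act_names = [v.name for v in program.list_vars()
             if not v.persistable and '@GRAD' not in v.name]
var_dist = get_distribution(program, act_names, exe, reader=loader)
pdf(var_dist, pdf_save_dir='var_dist_pdf')  # optional visual check
alphas = pact_thres(var_dist, q=99.99)      # trim extreme outliers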
@@ -29,3 +29,4 @@ except Exception as e:
        "please use Paddle >= 2.0.0 or develop version")
from .quant_embedding import quant_embedding
+from .utility import pact_thres
\ No newline at end of file
import logging
import numpy as np
from ..common import get_logger

_logger = get_logger(__name__, level=logging.INFO)


def pact_thres(var_dist, q=100):
    """
    Compute the qth percentile threshold of the data in var_dist.
    Args:
        var_dist(dict): numpy arrays of the variables' distributions.
        q(float): Percentile to compute, which must be between 0 and 100 inclusive. Default is 100.
    Returns:
        dict: the qth percentile of the array elements in var_dist, keyed by variable name.
    """
    var_percentile = {}
    for var_name in var_dist.keys():
        var = var_dist[var_name]
        var = var.flatten()
        # PACT clips symmetrically, so the threshold is taken over |values|.
        var = np.abs(var)
        try:
            var_percentile[var_name] = np.percentile(var, q)
        except Exception:
            _logger.info('{} is empty in this program'.format(var_name))
    return var_percentile
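
A quick toy example (data invented purely for illustration) of what pact_thres returns:

import numpy as np

toy_dist = {'relu_0.tmp_0': np.array([-0.5, 0.2, 3.9, -6.0])}
print(pact_thres(toy_dist))        # {'relu_0.tmp_0': 6.0}: max |value| at q=100
print(pact_thres(toy_dist, q=50))  # median of the absolute values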