提交 6992fd6c 编写于 作者: B baiyfbupt

Add automatic calculation of pact clip threshold

上级 271fbf44
import sys
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
    """Apply a PACT-style learnable symmetric clip to the activation `x`.

    A single-element float32 parameter named ``x.name + '_pact'`` is created
    (initialised to 20, L2-regularised with coefficient 0.0001) and used as
    the clip threshold ``u``. The result is built as
    ``y = x - relu(x - u)`` followed by ``y + relu(-u - y)``.

    Args:
        x: the activation variable to clip; its ``name`` is used to name
            the threshold parameter.
        name: forwarded to ``LayerHelper`` only. Default is None.

    Returns:
        The clipped activation variable.
    """
    helper = LayerHelper("pact", x=x, name=name)
    threshold_attr = fluid.ParamAttr(
        name=x.name + '_pact',
        initializer=fluid.initializer.ConstantInitializer(value=20),
        regularizer=fluid.regularizer.L2Decay(0.0001),
        learning_rate=1)
    threshold = helper.create_parameter(
        attr=threshold_attr, shape=[1], dtype='float32')

    # Remove the part of x that exceeds +threshold.
    overshoot = fluid.layers.relu(fluid.layers.elementwise_sub(x, threshold))
    upper_clipped = fluid.layers.elementwise_sub(x, overshoot)

    # Add back the part that falls below -threshold.
    undershoot = fluid.layers.relu(
        fluid.layers.elementwise_sub(-threshold, upper_clipped))
    return fluid.layers.elementwise_add(upper_clipped, undershoot)
def get_optimizer():
    """Build the momentum optimizer (lr 0.0001, momentum 0.9) returned to callers."""
    optimizer = fluid.optimizer.MomentumOptimizer(
        learning_rate=0.0001, momentum=0.9)
    return optimizer
......@@ -10,13 +10,14 @@ import numpy as np
import paddle.fluid as fluid
sys.path[0] = os.path.join(
os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
from paddleslim.common import get_logger
from paddleslim.common import get_logger, get_distribution, pdf
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
from paddleslim.quant import pact_thres
import models
from utility import add_arguments, print_arguments
sys.path.append('./')
from pact import *
from paddle.fluid.layer_helper import LayerHelper
quantization_model_save_dir = './quantization_models/'
_logger = get_logger(__name__, level=logging.INFO)
......@@ -158,11 +159,63 @@ def compress(args):
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
train_reader = paddle.fluid.io.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
valid_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
train_loader.set_sample_list_generator(train_reader, place)
valid_loader.set_sample_list_generator(val_reader, place)
# get all activations distribution
act_names = [
var.name for var in list(train_prog.list_vars())
if not var.persistable and 'generated_var' not in var.name and
'@GRAD' not in var.name
]
var_dist = get_distribution(train_prog, act_names, exe, train_loader)
train_loader.set_sample_list_generator(train_reader, places)
# draw histogram
pdf(var_dist, pdf_save_dir='var_dist_pdf')
# calculate appropriate pact clip threshold
pact_alphas = pact_thres(var_dist)
# 2. quantization transform programs (training aware)
# Make some quantization transforms in the graph before training and testing.
# According to the weight and activation quantization type, the graph will be added
# some fake quantize operators and fake dequantize operators.
def pact(x):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = pact_alphas[x.name.split('_tmp_input')[0]]
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
if args.use_pact:
act_preprocess_func = pact
optimizer_func = get_optimizer
......@@ -201,25 +254,6 @@ def compress(args):
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
train_reader = paddle.fluid.io.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
valid_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
train_loader.set_sample_list_generator(train_reader, places)
valid_loader.set_sample_list_generator(val_reader, place)
def test(epoch, program):
batch_id = 0
acc_top1_ns = []
......@@ -270,8 +304,7 @@ def compress(args):
array = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
threshold[var.name] = array[0]
print(threshold)
_logger.info(threshold)
batch_id += 1
build_strategy = fluid.BuildStrategy()
......@@ -307,6 +340,7 @@ def compress(args):
exe,
dirname=os.path.join(args.checkpoint_dir, 'best_model'),
main_program=val_program)
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
......@@ -315,6 +349,8 @@ def compress(args):
save_int8=True)
print("eval best_model after convert")
final_acc1 = test(best_epoch, float_program)
_logger.info("final acc:{}".format(final_acc1))
# 4. Save inference model
model_path = os.path.join(quantization_model_save_dir, args.model,
'act_' + quant_config['activation_quantize_type']
......
......@@ -21,10 +21,10 @@ from .cached_reader import cached_reader
from .server import Server
from .client import Client
from .meter import AvgrageMeter
from .analyze_helper import pdf
from .analyze_helper import pdf, get_distribution
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
'Server', 'Client', 'RLBaseController', 'pdf'
'Server', 'Client', 'RLBaseController', 'pdf', 'get_distribution'
]
......@@ -12,55 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
matplotlib.use('Agg')
import logging
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import os
import types
import paddle
import paddle.fluid as fluid
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import logging
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
def pdf(program,
var_names,
executor=None,
batch_generator=None,
data_loader=None,
feed_vars=None,
fetch_list=None,
scope=None,
pdf_save_dir='tmp_pdf'):
def get_distribution(program,
var_names,
executor,
reader=None,
feed_vars=None,
scope=None):
"""
Draw hist for distributtion of variables in that name is in var_names
Get the variables distribution in the var_names list
Args:
program(fluid.Program): program to analyze.
var_names(list): name of variables to analyze. When there is activation name in var_names,
you should set executor, one of batch_generator and data_loader, feed_list.
var_names(list): name of variables to analyze. When there is activation name in var_names,
you should set executor.
executor(fluid.Executor, optional): The executor to run program. Default is None.
batch_generator(Python Generator, optional): The batch generator provides calibrate data for DataLoader,
and it returns a batch every time. For data_loader and batch_generator,
only one can be set. Default is None.
data_loader(fluid.io.DataLoader, optional): The data_loader provides calibrate data to run program.
Default is None.
feed_vars(list): feed variables for program. When you use batch_generator to provide data,
reader(Python Generator, fluid.io.DataLoader, optional): If you only want to get the distribution of weight parameters,
you do not need to provide a reader. Otherwise, a reader must be provided. The reader provides calibrate data,
and it returns a batch every time. It must be either a python generator or an iterable fluid dataloader.
When you use a python generator, please ensure that its behavior is consistent with `batch_generator`.
You can get more detail about batch_generator at https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/io_cn/DataLoader_cn.html#id1
feed_vars(list): feed variables for program. When you use python generator reader to provide data,
you should set feed_vars. Default is None.
fetch_list(list): fetch list for program. Default is None.
scope(fluid.Scope, optional): The scope to run program, use it to load variables.
scope(fluid.Scope, optional): The scope to run program, use it to load variables.
If scope is None, will use fluid.global_scope().
pdf_save_dir(str): dirname to save pdf. Default is 'tmp_pdf'
Returns:
dict: numpy array of variables that name in var_names
dict: numpy array of variables distribution that name in var_names
"""
scope = fluid.global_scope() if scope is None else scope
assert isinstance(var_names, list), 'var_names is a list of variable name'
var_changed = []
real_names = []
weight_only = True
for var in program.list_vars():
......@@ -68,52 +64,70 @@ def pdf(program,
if var.persistable == False:
weight_only = False
var.persistable = True
var_changed.append(var)
real_names.append(var.name)
if weight_only == False:
if batch_generator is not None:
def update_var_dist(var_dist):
for name in real_names:
var = scope.find_var(name)
if var is not None:
var_array = np.array(var.get_tensor())
var_dist[name] = var_array
else:
_logger.info("can't find var {} in scope.".format(name))
return var_dist
var_dist = {}
if weight_only:
var_dist = update_var_dist(var_dist)
else:
assert isinstance(reader, types.GeneratorType) or isinstance(
reader, fluid.reader.DataLoaderBase
), "when var_names include activations'name, reader must be either a python generator or a fluid dataloader."
assert executor is not None, "when var_names include activations'name, executor must be set"
if isinstance(reader, types.GeneratorType):
assert feed_vars is not None, "When using batch_generator, feed_vars must be set"
dataloader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=512, iterable=True)
dataloader.set_batch_generator(batch_generator, executor.place)
elif data_loader is not None:
dataloader = data_loader
feed_list=feed_vars, capacity=128, iterable=True)
dataloader.set_batch_generator(reader, executor.place)
elif isinstance(reader, fluid.reader.DataLoaderBase):
dataloader = reader
else:
_logger.info(
"When both batch_generator and data_loader is None, var_names can only include weight names"
)
return
assert executor is not None, "when var_names include activations'name, executor must be set"
assert fetch_list is not None, "when var_names include activations'name,, executor must be set"
for data in dataloader:
executor.run(program=program,
feed=data,
fetch_list=fetch_list,
return_numpy=False)
executor.run(program=program, feed=data)
var_dist = update_var_dist(var_dist)
break
res_np = {}
for name in real_names:
var = fluid.global_scope().find_var(name)
if var is not None:
res_np[name] = np.array(var.get_tensor())
else:
_logger.info(
"can't find var {}. Maybe you should set one of batch_generator and data_loader".
format(name))
numbers = len(real_names)
for var in var_changed:
var.persistable = False
return var_dist
def pdf(var_dist, pdf_save_dir='var_dist_pdf'):
"""
Draw histograms of the distributions of the variables in var_dist.
Args:
var_dist(dict): numpy array of variables distribution.
pdf_save_dir(str): dirname to save pdf. Default is 'var_dist_pdf'
"""
numbers = len(var_dist)
if pdf_save_dir is not None:
if not os.path.exists(pdf_save_dir):
os.mkdir(pdf_save_dir)
pdf_path = os.path.join(pdf_save_dir, 'result.pdf')
with PdfPages(pdf_path) as pdf:
idx = 1
for name in res_np.keys():
if idx % 10 == 0:
_logger.info("plt {}/{}".format(idx, numbers))
arr = res_np[name]
for i, name in enumerate(var_dist.keys()):
if i % 10 == 0:
_logger.info("plt {}/{}".format(i, numbers))
arr = var_dist[name]
arr = arr.flatten()
weights = np.ones_like(arr) / len(arr)
plt.hist(arr, bins=1000, weights=weights)
......@@ -123,5 +137,4 @@ def pdf(program,
plt.show()
pdf.savefig()
plt.close()
idx += 1
return res_np
_logger.info("variables histogram have been saved as {}".format(pdf_path))
......@@ -29,3 +29,4 @@ except Exception as e:
"please use Paddle >= 2.0.0 or develop version")
from .quant_embedding import quant_embedding
from .utility import pact_thres
\ No newline at end of file
import logging
import numpy as np
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
def pact_thres(var_dist, q=100):
    """Compute the q-th percentile of the absolute values of each variable.

    Args:
        var_dist(dict): maps a variable name to the array of values collected
            for that variable (anything ``np.asarray`` accepts).
        q(float): percentile to compute, between 0 and 100 inclusive.
            Default is 100, i.e. the largest absolute value.

    Returns:
        dict: maps each variable name whose array is non-empty to the q-th
            percentile of the absolute values of its elements. Variables
            with an empty array are logged and left out of the result.

    Raises:
        ValueError: propagated from ``np.percentile`` when ``q`` is outside
            [0, 100] and at least one variable is non-empty.
    """
    var_percentile = {}
    for var_name, var in var_dist.items():
        magnitudes = np.abs(np.asarray(var)).ravel()
        # Check for emptiness explicitly instead of relying on whatever
        # np.percentile does with an empty array, so that genuine errors
        # (e.g. an invalid q) are not swallowed and mislabelled as "empty".
        if magnitudes.size == 0:
            _logger.info('{} is empty in this program'.format(var_name))
            continue
        var_percentile[var_name] = np.percentile(magnitudes, q)
    return var_percentile
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册