Unverified commit 13017580 authored by Bai Yifan and committed by GitHub

Add automatic calculation of pact clip threshold (#398)

* add adaptive pact threshold matching

* update readme

* fix coverage

* fix coverage
Parent a419fe10
......@@ -179,7 +179,10 @@ python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/Mob
Train with PACT quantization
```
python train.py --model MobileNetV3_large_x1_0 --pretrained_model ./pretrain/MobileNetV3_large_x1_0_ssld_pretrained --num_epochs 30 --lr 0.0001 --use_pact True --batch_size 128 --lr_strategy=piecewise_decay --step_epochs 20 --l2_decay 1e-5
# First, analyze the activation distributions of the MobileNetV3 model to initialize the PACT clip thresholds
python train.py --analysis=True
# Then start PACT quantization-aware training
python train.py
```
The output is
......
import sys
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
import numpy as np
from paddle.fluid.layer_helper import LayerHelper
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
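
# Note: the two relu terms above implement a symmetric clip,
# x - relu(x - u) + relu(-u - x) == clip(x, -u, u).
# A minimal sanity check with numpy (illustrative values only):
#
#   import numpy as np
#   x = np.array([-30., -5., 0., 5., 30.])
#   u = 20.
#   x - np.maximum(x - u, 0.) + np.maximum(-u - x, 0.)
#   # -> array([-20., -5., 0., 5., 20.])
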
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
......@@ -7,45 +7,49 @@ import functools
import math
import time
import numpy as np
from collections import defaultdict
import paddle.fluid as fluid
sys.path.append(os.path.dirname(__file__))
sys.path.append(
    os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))
from paddleslim.common import get_logger
from paddleslim.common import get_logger, VarCollector
from paddleslim.analysis import flops
from paddleslim.quant import quant_aware, quant_post, convert
import models
from utility import add_arguments, print_arguments
from pact import *
from paddle.fluid.layer_helper import LayerHelper
quantization_model_save_dir = './quantization_models/'
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass
_logger = get_logger(__name__, level=logging.INFO)
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64 * 4,
add_arg('batch_size', int, 128,
"Minibatch size.")
add_arg('use_gpu', bool, True,
"Whether to use GPU or not.")
add_arg('model', str, "MobileNet",
add_arg('model', str, "MobileNetV3_large_x1_0",
"The target model.")
add_arg('pretrained_model', str, "../pretrained_model/MobileNetV1_pretrained",
add_arg('pretrained_model', str, "./pretrain/MobileNetV3_large_x1_0_ssld_pretrained",
"Whether to use pretrained model.")
add_arg('lr', float, 0.0001,
add_arg('lr', float, 0.001,
"The learning rate used to fine-tune pruned model.")
add_arg('lr_strategy', str, "piecewise_decay",
"The learning rate decay strategy.")
add_arg('l2_decay', float, 3e-5,
add_arg('l2_decay', float, 1e-5,
"The l2_decay parameter.")
add_arg('momentum_rate', float, 0.9,
"The value of momentum_rate.")
add_arg('num_epochs', int, 1,
add_arg('num_epochs', int, 30,
"The number of total epochs.")
add_arg('total_images', int, 1281167,
"The number of total training images.")
parser.add_argument('--step_epochs', nargs='+', type=int,
default=[30, 60, 90],
default=[20],
help="piecewise decay step")
add_arg('config_file', str, None,
"The config file for compression with yaml format.")
......@@ -61,6 +65,8 @@ add_arg('output_dir', str, "output/MobileNetV3_large_x1_0",
"model save dir")
add_arg('use_pact', bool, True,
"Whether to use PACT or not.")
add_arg('analysis', bool, False,
"Whether analysis variables distribution.")
# yapf: enable
......@@ -68,7 +74,9 @@ model_list = [m for m in dir(models) if "__" not in m]
def piecewise_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
step = int(
math.ceil(float(args.total_images) / (args.batch_size * len(places))))
bd = [step * e for e in args.step_epochs]
lr = [args.lr * (0.1**i) for i in range(len(bd) + 1)]
learning_rate = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
......@@ -76,18 +84,20 @@ def piecewise_decay(args):
learning_rate=learning_rate,
momentum=args.momentum_rate,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
return optimizer
return learning_rate, optimizer
def cosine_decay(args):
step = int(math.ceil(float(args.total_images) / args.batch_size))
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
step = int(
math.ceil(float(args.total_images) / (args.batch_size * len(places))))
learning_rate = fluid.layers.cosine_decay(
learning_rate=args.lr, step_each_epoch=step, epochs=args.num_epochs)
optimizer = fluid.optimizer.Momentum(
learning_rate=learning_rate,
momentum=args.momentum_rate,
regularization=fluid.regularizer.L2Decay(args.l2_decay))
return optimizer
return learning_rate, optimizer
def create_optimizer(args):
......@@ -98,30 +108,7 @@ def create_optimizer(args):
def compress(args):
# 1. quantization configs
quant_config = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
    # window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
}
train_reader = None
test_reader = None
if args.data == "mnist":
import paddle.dataset.mnist as reader
train_reader = reader.train()
......@@ -155,18 +142,126 @@ def compress(args):
train_prog = fluid.default_main_program()
val_program = fluid.default_main_program().clone(for_test=True)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
opt = create_optimizer(args)
opt.minimize(avg_cost)
if not args.analysis:
learning_rate, opt = create_optimizer(args)
opt.minimize(avg_cost)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
train_reader = paddle.fluid.io.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
train_loader.set_sample_list_generator(train_reader, places)
val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
valid_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
valid_loader.set_sample_list_generator(val_reader, places[0])
if args.analysis:
        # names of all activations to analyze
activates = [
'pool2d_1.tmp_0', 'tmp_35', 'batch_norm_21.tmp_2', 'tmp_26',
'elementwise_mul_5.tmp_0', 'pool2d_5.tmp_0',
'elementwise_add_5.tmp_0', 'relu_2.tmp_0', 'pool2d_3.tmp_0',
'conv2d_40.tmp_2', 'elementwise_mul_0.tmp_0', 'tmp_62',
'elementwise_add_8.tmp_0', 'batch_norm_39.tmp_2', 'conv2d_32.tmp_2',
'tmp_17', 'tmp_5', 'elementwise_add_9.tmp_0', 'pool2d_4.tmp_0',
'relu_0.tmp_0', 'tmp_53', 'relu_3.tmp_0', 'elementwise_add_4.tmp_0',
'elementwise_add_6.tmp_0', 'tmp_11', 'conv2d_36.tmp_2',
'relu_8.tmp_0', 'relu_5.tmp_0', 'pool2d_7.tmp_0',
'elementwise_add_2.tmp_0', 'elementwise_add_7.tmp_0',
'pool2d_2.tmp_0', 'tmp_47', 'batch_norm_12.tmp_2',
'elementwise_mul_6.tmp_0', 'elementwise_mul_7.tmp_0',
'pool2d_6.tmp_0', 'relu_6.tmp_0', 'elementwise_add_0.tmp_0',
'elementwise_mul_3.tmp_0', 'conv2d_12.tmp_2',
'elementwise_mul_2.tmp_0', 'tmp_8', 'tmp_2', 'conv2d_8.tmp_2',
'elementwise_add_3.tmp_0', 'elementwise_mul_1.tmp_0',
'pool2d_8.tmp_0', 'conv2d_28.tmp_2', 'image', 'conv2d_16.tmp_2',
'batch_norm_33.tmp_2', 'relu_1.tmp_0', 'pool2d_0.tmp_0', 'tmp_20',
'conv2d_44.tmp_2', 'relu_10.tmp_0', 'tmp_41', 'relu_4.tmp_0',
'elementwise_add_1.tmp_0', 'tmp_23', 'batch_norm_6.tmp_2', 'tmp_29',
'elementwise_mul_4.tmp_0', 'tmp_14'
]
var_collector = VarCollector(train_prog, activates, use_ema=True)
values = var_collector.abs_max_run(
train_loader, exe, step=None, loss_name=avg_cost.name)
np.save('pact_thres.npy', values)
_logger.info(values)
_logger.info("PACT threshold have been saved as pact_thres.npy")
# Draw Histogram in 'dist_pdf/result.pdf'
# var_collector.pdf(values)
return
    values = defaultdict(lambda: 20)
    try:
        tmp = np.load("pact_thres.npy", allow_pickle=True).item()
        values.update(tmp)
        _logger.info("pact_thres.npy info loaded.")
    except Exception:
        _logger.info(
            "cannot find pact_thres.npy. Using the default PACT threshold of 20.")
_logger.info(values)
# 1. quantization configs
quant_config = {
# weight quantize type, default is 'channel_wise_abs_max'
'weight_quantize_type': 'channel_wise_abs_max',
# activation quantize type, default is 'moving_average_abs_max'
'activation_quantize_type': 'moving_average_abs_max',
# weight quantize bit num, default is 8
'weight_bits': 8,
# activation quantize bit num, default is 8
'activation_bits': 8,
# ops of name_scope in not_quant_pattern list, will not be quantized
'not_quant_pattern': ['skip_quant'],
# ops of type in quantize_op_types, will be quantized
'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
# data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
'dtype': 'int8',
        # window size for 'range_abs_max' quantization. default is 10000
'window_size': 10000,
# The decay coefficient of moving average, default is 0.9
'moving_rate': 0.9,
}
    # 2. quantization transform programs (training aware)
    # Insert quantization transforms into the graph before training and testing.
    # According to the weight and activation quantization types, fake quantize
    # and fake dequantize operators are inserted into the graph.
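    # As a rough illustration (a sketch, not PaddleSlim's exact kernel), an
    # 8-bit fake quantize-dequantize of a tensor x with scale s computes:
    #
    #   q = np.round(np.clip(x, -s, s) / s * 127)  # map to the int8 grid
    #   x_fq = q / 127. * s                        # map back to float
    #
    # so the forward pass observes the rounding error while gradients still
    # flow to the float parameters (straight-through estimator).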
def pact(x):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
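        # The variable handed to the preprocess function carries the original
        # activation name plus a '_tmp_input' suffix; strip it to look up the
        # collected abs-max value (the defaultdict falls back to 20).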
init_thres = values[x.name.split('_tmp_input')[0]]
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
part_a = fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param))
part_b = fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x))
x = x - part_a + part_b
return x
def get_optimizer():
return fluid.optimizer.MomentumOptimizer(args.lr, 0.9)
if args.use_pact:
act_preprocess_func = pact
optimizer_func = get_optimizer
......@@ -205,28 +300,18 @@ def compress(args):
fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
train_reader = paddle.fluid.io.batch(
train_reader, batch_size=args.batch_size, drop_last=True)
train_feeder = feeder = fluid.DataFeeder([image, label], place)
val_feeder = feeder = fluid.DataFeeder(
[image, label], place, program=val_program)
def test(epoch, program):
batch_id = 0
acc_top1_ns = []
acc_top5_ns = []
for data in val_reader():
for data in valid_loader():
start_time = time.time()
acc_top1_n, acc_top5_n = exe.run(
program,
feed=train_feeder.feed(data),
fetch_list=[acc_top1.name, acc_top5.name])
program, feed=data, fetch_list=[acc_top1.name, acc_top5.name])
end_time = time.time()
if batch_id % args.log_period == 0:
_logger.info(
"Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
"Eval epoch[{}] batch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
format(epoch, batch_id,
np.mean(acc_top1_n),
np.mean(acc_top5_n), end_time - start_time))
......@@ -234,30 +319,35 @@ def compress(args):
acc_top5_ns.append(np.mean(acc_top5_n))
batch_id += 1
_logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
_logger.info(
"Final eval epoch[{}] - acc_top1: {:.6f}; acc_top5: {:.6f}".format(
epoch,
np.mean(np.array(acc_top1_ns)), np.mean(np.array(acc_top5_ns))))
return np.mean(np.array(acc_top1_ns))
def train(epoch, compiled_train_prog):
batch_id = 0
for data in train_reader():
for data in train_loader():
start_time = time.time()
loss_n, acc_top1_n, acc_top5_n = exe.run(
lr_n, loss_n, acc_top1_n, acc_top5_n = exe.run(
compiled_train_prog,
feed=train_feeder.feed(data),
fetch_list=[avg_cost.name, acc_top1.name, acc_top5.name])
feed=data,
fetch_list=[
learning_rate.name, avg_cost.name, acc_top1.name,
acc_top5.name
])
end_time = time.time()
lr_n = np.mean(lr_n)
loss_n = np.mean(loss_n)
acc_top1_n = np.mean(acc_top1_n)
acc_top5_n = np.mean(acc_top5_n)
if batch_id % args.log_period == 0:
_logger.info(
"epoch[{}]-batch[{}] - loss: {}; acc_top1: {}; acc_top5: {}; time: {}".
format(epoch, batch_id, loss_n, acc_top1_n, acc_top5_n,
end_time - start_time))
"epoch[{}]-batch[{}] lr: {:.6f} - loss: {:.6f}; acc_top1: {:.6f}; acc_top5: {:.6f}; time: {:.3f}".
format(epoch, batch_id, lr_n, loss_n, acc_top1_n,
acc_top5_n, end_time - start_time))
if args.use_pact and batch_id % 1000 == 0:
threshold = {}
......@@ -266,15 +356,12 @@ def compress(args):
array = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
threshold[var.name] = array[0]
print(threshold)
_logger.info(threshold)
batch_id += 1
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
build_strategy.fuse_all_reduce_ops = False
build_strategy.sync_batch_norm = False
exec_strategy = fluid.ExecutionStrategy()
compiled_train_prog = compiled_train_prog.with_data_parallel(
loss_name=avg_cost.name,
......@@ -297,9 +384,16 @@ def compress(args):
v = fluid.global_scope().find_var('@LR_DECAY_COUNTER@').get_tensor()
v.set(np.array([start_step]).astype(np.float32), place)
best_eval_acc1 = 0
best_acc1_epoch = 0
for i in range(start_epoch, args.num_epochs):
train(i, compiled_train_prog)
acc1 = test(i, val_program)
if acc1 > best_eval_acc1:
best_eval_acc1 = acc1
best_acc1_epoch = i
_logger.info("Best Validation Acc1: {:.6f}, at epoch {}".format(
best_eval_acc1, best_acc1_epoch))
fluid.io.save_persistables(
exe,
dirname=os.path.join(args.output_dir, str(i)),
......@@ -311,25 +405,28 @@ def compress(args):
exe,
dirname=os.path.join(args.output_dir, 'best_model'),
main_program=val_program)
if os.path.exists(os.path.join(args.output_dir, 'best_model')):
fluid.io.load_persistables(
exe,
dirname=os.path.join(args.output_dir, 'best_model'),
main_program=val_program)
# 3. Freeze the graph after training by adjusting the quantize
# operators' order for the inference.
# The dtype of float_program's weights is float32, but in int8 range.
float_program, int8_program = convert(val_program, place, quant_config, \
scope=None, \
save_int8=True)
print("eval best_model after convert")
_logger.info("eval best_model after convert")
final_acc1 = test(best_epoch, float_program)
_logger.info("final acc:{}".format(final_acc1))
# 4. Save inference model
model_path = os.path.join(quantization_model_save_dir, args.model,
'act_' + quant_config['activation_quantize_type']
+ '_w_' + quant_config['weight_quantize_type'])
float_path = os.path.join(model_path, 'float')
int8_path = os.path.join(model_path, 'int8')
if not os.path.isdir(model_path):
os.makedirs(model_path)
......@@ -342,15 +439,6 @@ def compress(args):
model_filename=float_path + '/model',
params_filename=float_path + '/params')
fluid.io.save_inference_model(
dirname=int8_path,
feeded_var_names=[image.name],
target_vars=[out],
executor=exe,
main_program=int8_program,
model_filename=int8_path + '/model',
params_filename=int8_path + '/params')
def main():
paddle.enable_static()
......
......@@ -21,10 +21,10 @@ from .cached_reader import cached_reader
from .server import Server
from .client import Client
from .meter import AvgrageMeter
from .analyze_helper import pdf
from .analyze_helper import VarCollector
__all__ = [
'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
'Server', 'Client', 'RLBaseController', 'pdf'
'Server', 'Client', 'RLBaseController', 'VarCollector'
]
......@@ -12,116 +12,167 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import matplotlib
matplotlib.use('Agg')
import logging
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.pyplot as plt
import os
import types
import paddle
import paddle.fluid as fluid
import numpy as np
from collections import defaultdict
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import logging
from ..common import get_logger
_logger = get_logger(__name__, level=logging.INFO)
def pdf(program,
var_names,
executor=None,
batch_generator=None,
data_loader=None,
feed_vars=None,
fetch_list=None,
scope=None,
pdf_save_dir='tmp_pdf'):
"""
    Draw histograms for the distributions of the variables whose names are in var_names.
    Args:
        program(fluid.Program): program to analyze.
        var_names(list): names of the variables to analyze. When var_names contains
                         activation names, you should set executor, one of
                         batch_generator and data_loader, and feed_vars.
executor(fluid.Executor, optional): The executor to run program. Default is None.
batch_generator(Python Generator, optional): The batch generator provides calibrate data for DataLoader,
and it returns a batch every time. For data_loader and batch_generator,
only one can be set. Default is None.
data_loader(fluid.io.DataLoader, optional): The data_loader provides calibrate data to run program.
Default is None.
feed_vars(list): feed variables for program. When you use batch_generator to provide data,
you should set feed_vars. Default is None.
fetch_list(list): fetch list for program. Default is None.
scope(fluid.Scope, optional): The scope to run program, use it to load variables.
If scope is None, will use fluid.global_scope().
pdf_save_dir(str): dirname to save pdf. Default is 'tmp_pdf'
Returns:
dict: numpy array of variables that name in var_names
"""
scope = fluid.global_scope() if scope is None else scope
    assert isinstance(var_names, list), 'var_names must be a list of variable names'
real_names = []
weight_only = True
for var in program.list_vars():
if var.name in var_names:
if var.persistable == False:
weight_only = False
var.persistable = True
real_names.append(var.name)
if weight_only == False:
if batch_generator is not None:
assert feed_vars is not None, "When using batch_generator, feed_vars must be set"
dataloader = fluid.io.DataLoader.from_generator(
feed_list=feed_vars, capacity=512, iterable=True)
dataloader.set_batch_generator(batch_generator, executor.place)
elif data_loader is not None:
dataloader = data_loader
class Averager(object):
def __init__(self):
self.shadow = {}
self.cnt = 0
def register(self, name, val):
self.shadow[name] = val
self.cnt = 1
def get(self, name):
return self.shadow[name]
def record(self):
return self.shadow
def update(self, name, val):
assert name in self.shadow
new_average = (self.cnt * self.shadow[name] + val) / (self.cnt + 1)
self.cnt += 1
self.shadow[name] = new_average
class EMA(Averager):
def __init__(self, decay):
self.decay = decay
self.shadow = {}
def update(self, name, val):
assert name in self.shadow
new_average = (1.0 - self.decay) * val + self.decay * self.shadow[name]
self.shadow[name] = new_average
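
# A short usage sketch for the two trackers (values are illustrative):
#
#   avg = Averager()
#   avg.register('act0', 1.0)   # first observation
#   avg.update('act0', 3.0)     # running mean: (1 * 1.0 + 3.0) / 2 = 2.0
#
#   ema = EMA(decay=0.999)
#   ema.register('act0', 1.0)
#   ema.update('act0', 2.0)     # 0.001 * 2.0 + 0.999 * 1.0 = 1.001
#   ema.record()                # {'act0': 1.001}
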
class VarCollector(object):
def __init__(self,
program,
var_names,
use_ema=False,
ema_decay=0.999,
scope=None):
self.program = program
self.var_names = var_names
self.scope = fluid.global_scope() if scope is None else scope
self.use_ema = use_ema
self.set_up()
if self.use_ema:
self.stats = EMA(decay=ema_decay)
else:
_logger.info(
"When both batch_generator and data_loader is None, var_names can only include weight names"
)
return
    assert executor is not None, "when var_names includes activation names, executor must be set"
    assert fetch_list is not None, "when var_names includes activation names, fetch_list must be set"
for data in dataloader:
executor.run(program=program,
feed=data,
fetch_list=fetch_list,
return_numpy=False)
break
res_np = {}
for name in real_names:
var = fluid.global_scope().find_var(name)
if var is not None:
res_np[name] = np.array(var.get_tensor())
self.stats = Averager()
def set_up(self):
self.real_names = []
if hasattr(self.program, '_program'):
program = self.program._program
else:
_logger.info(
"can't find var {}. Maybe you should set one of batch_generator and data_loader".
format(name))
numbers = len(real_names)
if pdf_save_dir is not None:
if not os.path.exists(pdf_save_dir):
os.mkdir(pdf_save_dir)
pdf_path = os.path.join(pdf_save_dir, 'result.pdf')
with PdfPages(pdf_path) as pdf:
idx = 1
for name in res_np.keys():
if idx % 10 == 0:
_logger.info("plt {}/{}".format(idx, numbers))
arr = res_np[name]
arr = arr.flatten()
weights = np.ones_like(arr) / len(arr)
plt.hist(arr, bins=1000, weights=weights)
plt.xlabel(name)
plt.ylabel("frequency")
plt.title("Hist of variable {}".format(name))
plt.show()
pdf.savefig()
plt.close()
idx += 1
return res_np
program = self.program
for var in program.list_vars():
if var.name in self.var_names:
self.real_names.append(var.name)
def update(self, vars_np):
for name in self.real_names:
val = vars_np[name]
if val is not None:
try:
self.stats.update(name, val)
except:
self.stats.register(name, val)
else:
_logger.info("can't find var {}.".format(name))
return self.stats.record()
def run(self, reader, exe, step=None, loss_name=None):
        if not hasattr(self.program, '_program'):
            # Compile the native program to speed up
            program = fluid.CompiledProgram(self.program).with_data_parallel(
                loss_name=loss_name)
        else:
            program = self.program
for idx, data in enumerate(reader):
vars_np = exe.run(program=program,
feed=data,
fetch_list=self.real_names)
mapped_vars_np = dict(zip(self.real_names, vars_np))
values = self.update(mapped_vars_np)
if idx % 10 == 0:
_logger.info("Collecting..., Step: {}".format(idx))
if step is not None and idx + 1 >= step:
break
return values
def abs_max_run(self, reader, exe, step=None, loss_name=None):
fetch_list = []
with fluid.program_guard(self.program):
for act_name in self.real_names:
act = self.program.global_block().var(act_name)
act = fluid.layers.reduce_max(
fluid.layers.abs(act), name=act_name + "_reduced")
fetch_list.append(act_name + "_reduced.tmp_0")
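                # reduce_max writes its result to a new variable named
                # '<name>.tmp_0', hence the '_reduced.tmp_0' fetch target above.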
        if not hasattr(self.program, '_program'):
            # Compile the native program to speed up
            program = fluid.CompiledProgram(self.program).with_data_parallel(
                loss_name=loss_name)
        else:
            program = self.program
for idx, data in enumerate(reader):
vars_np = exe.run(program=program, feed=data, fetch_list=fetch_list)
vars_np = [np.max(var) for var in vars_np]
mapped_vars_np = dict(zip(self.real_names, vars_np))
values = self.update(mapped_vars_np)
if idx % 10 == 0:
_logger.info("Collecting..., Step: {}".format(idx))
if step is not None and idx + 1 >= step:
break
return values
@staticmethod
def pdf(var_dist, save_dir='dist_pdf'):
"""
        Draw histograms for the distributions of the variables in var_dist.
        Args:
            var_dist(dict): dict mapping variable names to their numpy arrays.
            save_dir(str): dirname to save the pdf. Default is 'dist_pdf'.
"""
numbers = len(var_dist)
if save_dir is not None:
if not os.path.exists(save_dir):
os.mkdir(save_dir)
pdf_path = os.path.join(save_dir, 'result.pdf')
with PdfPages(pdf_path) as pdf:
for i, name in enumerate(var_dist.keys()):
if i % 10 == 0:
_logger.info("plt {}/{}".format(i, numbers))
arr = var_dist[name]
arr = arr.flatten()
weights = np.ones_like(arr) / len(arr)
plt.hist(arr, bins=1000, weights=weights)
plt.xlabel(name)
plt.ylabel("frequency")
plt.title("Hist of variable {}".format(name))
plt.show()
pdf.savefig()
plt.close()
_logger.info("variables histogram have been saved as {}".format(
pdf_path))
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import unittest
import paddle
import paddle.fluid as fluid
from paddleslim.common import VarCollector
from static_case import StaticCase
sys.path.append("../demo")
from models import MobileNet
from layers import conv_bn_layer
import paddle.dataset.mnist as reader
import numpy as np
class TestAnalysisHelper(StaticCase):
def test_analysis_helper(self):
image = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
model = MobileNet()
out = model.net(input=image, class_dim=10)
cost = fluid.layers.cross_entropy(input=out, label=label)
avg_cost = fluid.layers.mean(x=cost)
acc_top1 = fluid.layers.accuracy(input=out, label=label, k=1)
acc_top5 = fluid.layers.accuracy(input=out, label=label, k=5)
optimizer = fluid.optimizer.Momentum(
momentum=0.9,
learning_rate=0.01,
regularization=fluid.regularizer.L2Decay(4e-5))
optimizer.minimize(avg_cost)
main_prog = fluid.default_main_program()
places = fluid.cuda_places() if fluid.is_compiled_with_cuda(
) else fluid.cpu_places()
exe = fluid.Executor(places[0])
train_reader = paddle.fluid.io.batch(
paddle.dataset.mnist.train(), batch_size=64)
train_loader = fluid.io.DataLoader.from_generator(
feed_list=[image, label],
capacity=512,
use_double_buffer=True,
iterable=True)
train_loader.set_sample_list_generator(train_reader, places)
exe.run(fluid.default_startup_program())
vars = ['conv2d_0.tmp_0', 'fc_0.tmp_0', 'fc_0.tmp_1', 'fc_0.tmp_2']
var_collector1 = VarCollector(main_prog, vars, use_ema=True)
values = var_collector1.abs_max_run(
train_loader, exe, step=None, loss_name=avg_cost.name)
vars = [v.name for v in main_prog.list_vars() if v.persistable]
var_collector2 = VarCollector(main_prog, vars, use_ema=False)
values = var_collector2.run(train_loader,
exe,
step=None,
loss_name=avg_cost.name)
var_collector2.pdf(values)
if __name__ == '__main__':
TestAnalysisHelper('test_analysis_helper').test_analysis_helper()