Commit 14dfa73e authored by baiyfbupt

fix issue

Parent 2ac3ec96
@@ -53,6 +53,37 @@ from paddleslim.quant import quant_aware, convert
from paddle.fluid.layer_helper import LayerHelper
def pact(x):
"""
Process a variable using the pact method you define
Args:
x(Tensor): Paddle Tensor, need to be preprocess before quantization
Returns:
The processed Tensor x.
"""
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
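# Editorial note (a sketch, not part of this commit): the two elementwise ops
# above implement a symmetric clip. Subtracting relu(x - u_param) caps values
# above the threshold, and adding relu(-u_param - x) floors values below
# -u_param, so pact(x) is equivalent to clipping x to [-u_param, u_param].
# Worked check with init_thres = 20:
#
#     x = 30:   30 - relu(30 - 20) = 20;    20 + relu(-20 - 20)  = 20
#     x = -30: -30 - relu(-30 - 20) = -30; -30 + relu(-20 + 30) = -20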
def get_optimizer():
"""
Build a program using a model and an optimizer
"""
return fluid.optimizer.AdamOptimizer(0.001)
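# Editorial sketch (assumed PaddleSlim API, not part of this commit): pact and
# get_optimizer are typically handed to paddleslim.quant.quant_aware, so that
# activations are preprocessed by PACT and its thresholds get their own
# optimizer. Verify the keyword names against your PaddleSlim version:
#
#     quant_train_program = quant_aware(
#         train_program, place, quant_config, for_test=False,
#         act_preprocess_func=pact,      # clip activations before quantization
#         optimizer_func=get_optimizer,  # optimizer for the PACT thresholds
#         executor=exe)                  # used to initialize PACT parameters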
def main():
train_build_outputs = program.build(
config, train_program, startup_program, mode='train')
@@ -77,26 +108,6 @@ def main():
exe = fluid.Executor(place)
exe.run(startup_program)
def pact(x, name=None):
helper = LayerHelper("pact", **locals())
dtype = 'float32'
init_thres = 20
u_param_attr = fluid.ParamAttr(
name=x.name + '_pact',
initializer=fluid.initializer.ConstantInitializer(value=init_thres),
regularizer=fluid.regularizer.L2Decay(0.0001),
learning_rate=1)
u_param = helper.create_parameter(
attr=u_param_attr, shape=[1], dtype=dtype)
x = fluid.layers.elementwise_sub(
x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
x = fluid.layers.elementwise_add(
x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
return x
def get_optimizer():
return fluid.optimizer.AdamOptimizer(0.001)
# 1. quantization configs
quant_config = {
# weight quantize type, default is 'channel_wise_abs_max'
@@ -151,14 +162,6 @@ def main():
train_compile_program = program.create_multi_devices_program(
quant_train_program, train_opt_loss_name, for_quant=True)
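# Editorial sketch: the diff elides most of quant_config above. A typical
# PaddleSlim quantization config looks roughly like the following (keys and
# values are assumptions drawn from the PaddleSlim docs, not from this commit):
#
#     quant_config = {
#         'weight_quantize_type': 'channel_wise_abs_max',
#         'activation_quantize_type': 'moving_average_abs_max',
#         'weight_bits': 8,
#         'activation_bits': 8,
#         'not_quant_pattern': ['skip_quant'],
#         'quantize_op_types': ['conv2d', 'depthwise_conv2d', 'mul'],
#         'dtype': 'int8',
#     }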
# dump model structure
if config['Global']['debug']:
if train_alg_type == 'rec' and 'attention' in config['Global'][
'loss_type']:
logger.warning('Does not support dumping attention...')
else:
summary(quant_train_program)
init_model(config, quant_train_program, exe)
train_info_dict = {'compile_program':train_compile_program,\
......