Unverified commit 874c1ac6, authored by ceci3, committed by GitHub

fix ernie example to adapt develop paddle (#1644)

* fix ac

* adapt x2paddle version

* adapt develop paddle
Parent: d0a1e2b6
@@ -241,9 +241,9 @@ class AutoCompression:
], f'Type of input_shapes should be in [dict, tuple or list] but got {type(input_shapes)}.'
paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())
-[inference_program, feed_target_names,
- fetch_targets] = load_inference_model(model_dir, exe, model_filename,
-                                       params_filename)
+[inference_program,
+ feed_target_names, fetch_targets] = load_inference_model(
+     model_dir, exe, model_filename, params_filename)
if type(input_shapes) in [list, tuple]:
assert len(
@@ -451,30 +451,29 @@ class AutoCompression:
strategy.build_strategy = build_strategy
if train_config.recompute_config is not None:
strategy.recompute = True
-strategy.recompute_configs = { ** train_config.recompute_config}
+strategy.recompute_configs = {**train_config.recompute_config}
if train_config.sharding_config is not None:
strategy.sharding = True
-strategy.sharding_configs = { ** train_config.sharding_config}
+strategy.sharding_configs = {**train_config.sharding_config}
if train_config.amp_config is not None:
strategy.amp = True
-strategy.amp_configs = { ** train_config.amp_config}
+strategy.amp_configs = {**train_config.amp_config}
if train_config.asp_config is not None:
strategy.asp = True
return strategy
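
Context for the `**`-unpacking lines above: `_prepare_strategy` copies user-supplied config dicts onto fleet's `DistributedStrategy` after flipping the matching feature flag. A minimal standalone sketch of the same pattern (the config values are hypothetical placeholders, not the compressor's defaults):

```python
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()

# Enable recompute and copy the user's dict, exactly as the compressor does.
recompute_config = {"checkpoints": ["relu_0.tmp_0"]}  # hypothetical checkpoint name
strategy.recompute = True
strategy.recompute_configs = {**recompute_config}

# Same pattern for AMP.
amp_config = {"init_loss_scaling": 32768, "use_dynamic_loss_scaling": True}
strategy.amp = True
strategy.amp_configs = {**amp_config}
```
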
def _prepare_program(self, program, feed_target_names, fetch_targets,
patterns, strategy, config, train_config):
-train_program = recover_inference_program(program)
+startup_program = paddle.static.Program()
+train_program = recover_inference_program(program, startup_program)
train_program_info = ProgramInfo(startup_program, train_program,
feed_target_names, fetch_targets)
config_dict = config.__dict__
if "prune_strategy" in config_dict and config_dict[
"prune_strategy"] == "gmp" and config_dict[
'gmp_config'] is None:
if "prune_strategy" in config_dict and config_dict["prune_strategy"] == "gmp" and config_dict['gmp_config'] is None:
-_logger.info(
-    "Calculating the iterations per epoch……(It will take some time)")
+_logger.info(
+    "Calculating the iterations per epoch……(It will take some time)"
+)
# NOTE:XXX: This way of calculating the iters needs to be improved.
if train_config.epochs:
iters_per_epoch = len(list(self.train_dataloader()))
@@ -587,9 +586,8 @@ class AutoCompression:
train_config = None
strategy_idx = None
self.final_metric = -1.0
-for strategy_idx, (
-        strategy, config, train_config
-) in enumerate(zip(self._strategy, self._config, self.train_config)):
+for strategy_idx, (strategy, config, train_config) in enumerate(
+        zip(self._strategy, self._config, self.train_config)):
self.single_strategy_compress(strategy, config, strategy_idx,
train_config)
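The substantive change in this file threads a `startup_program` through `_prepare_program` into `recover_inference_program` (see the function's new signature in the last file of this diff), so the ops that re-create trainable parameters are recorded somewhere the caller can execute. A sketch of the new call-site pattern, assuming `program` is an already-loaded inference program:

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# New convention: hand recover_inference_program a dedicated startup program
# so the parameter-creation ops it emits can be run explicitly afterwards.
startup_program = paddle.static.Program()
train_program = recover_inference_program(program, startup_program)
exe.run(startup_program)
```
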
@@ -296,13 +296,14 @@ class TransformerPruner:
head_num = -1
tmp_mha_ops = patterns['MHA$0']
for op in tmp_mha_ops:
-if op.type() in ['matmul', 'matmul_v2'] and (
-        not has_trainable_var(op)) and head_num == -1:
+if op.type() in [
+        'matmul', 'matmul_v2'
+] and (not has_trainable_var(op)) and head_num == -1:
inp_var = op.inputs("X")
head_num = inp_var[0].shape()[1]
-mha_weight, ffn_weight = preprocess_transformer_patterns(patterns,
-                                                         graph)
+mha_weight, ffn_weight = preprocess_transformer_patterns(
+    patterns, graph)
return input_mask_op, layer_num, head_num, mha_weight, ffn_weight
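
Background for the `head_num` probe above: in a standard multi-head attention block, Q/K/V are reshaped to `[batch, head_num, seq_len, head_dim]` before the parameter-free QK^T matmul, so the second dimension of that matmul's input is the head count. An illustrative check with hypothetical shapes:

```python
import numpy as np

# Q after the multi-head reshape: [batch, head_num, seq_len, head_dim]
q = np.zeros((8, 12, 128, 64))
head_num = q.shape[1]  # 12 -- what the pruner reads via inp_var[0].shape()[1]
assert head_num == 12
```
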
def _program_add_mask(self, program, patterns, layer_num, head_num,
@@ -312,7 +313,7 @@ class TransformerPruner:
for ft in fetch_targets:
fetch_list.append(ft.name)
program = recover_inference_program(program)
-block = program.global_block()
+block = program.current_block()
head_mask = block.create_var(
name='head_mask',
shape=[layer_num, head_num],
@@ -325,11 +326,12 @@ class TransformerPruner:
1.0,
out=head_mask,
stop_gradient=False)
-head_mask = unsqueeze_op(
-    block, -1,
-    unsqueeze_op(block, -1,
-                 unsqueeze_op(block, 1, head_mask, feed_num + 1),
-                 feed_num + 2), feed_num + 3)
+head_mask = unsqueeze_op(block, -1,
+                         unsqueeze_op(block, -1,
+                                      unsqueeze_op(
+                                          block, 1, head_mask,
+                                          feed_num + 1), feed_num + 2),
+                         feed_num + 3)
for pattern_name, pattern in patterns.items():
if 'MHA' in pattern_name:
@@ -432,8 +434,7 @@ class TransformerPruner:
index = np.reshape(
np.take(
np.reshape(
-np.arange(
-    0, head_num * num_per_head, dtype='int64'),
+np.arange(0, head_num * num_per_head, dtype='int64'),
(head_num, num_per_head)),
idx,
axis=0), (-1))
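
The `np.take` expression above computes the flat column indices of the attention heads that survive pruning: enumerate `head_num * num_per_head` positions, view them as a `(head_num, num_per_head)` grid, select the kept head rows, and flatten. A small worked example with hypothetical sizes:

```python
import numpy as np

head_num, num_per_head = 4, 3  # hypothetical: 4 heads, 3 dims per head
idx = [0, 2]                   # keep heads 0 and 2

index = np.reshape(
    np.take(
        np.reshape(
            np.arange(0, head_num * num_per_head, dtype='int64'),
            (head_num, num_per_head)),
        idx,
        axis=0), (-1))

print(index)  # [0 1 2 6 7 8] -- the flat positions owned by heads 0 and 2
```
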
@@ -528,8 +529,8 @@ class TransformerPruner:
if _var is None:
return
param_t = _var.get_tensor()
-pruned_ratio = [pruned_ratio[1]] if len(param_t.shape(
-)) == 1 else pruned_ratio
+pruned_ratio = [pruned_ratio[1]
+                ] if len(param_t.shape()) == 1 else pruned_ratio
origin_shape = param_t.shape()
def process_qkv(qkv_param, pruned_ratio):
@@ -602,12 +603,12 @@ class TransformerPruner:
origin_shape = op.attr('shape')
pruned_shape = origin_shape
if len(origin_shape) == 3:
-pruned_shape[-1] = int(origin_shape[-1] *
-                       self.width_mult)
+pruned_shape[-1] = int(
+    origin_shape[-1] * self.width_mult)
op.set_attr('shape', pruned_shape)
elif len(origin_shape) == 4 or len(origin_shape) == 5:
-pruned_shape[-2] = int(origin_shape[-2] *
-                       self.width_mult)
+pruned_shape[-2] = int(
+    origin_shape[-2] * self.width_mult)
op.set_attr('shape', pruned_shape)
else:
raise IndexError
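A note on the `unsqueeze_op` chain reformatted in the hunk above: reading it inside-out (axis 1, then -1, then -1), it expands `head_mask` from `[layer_num, head_num]` into a shape that broadcasts across the attention scores. A numpy sketch of the equivalent transformation, assuming `unsqueeze_op` behaves like `np.expand_dims` on the given axis:

```python
import numpy as np

layer_num, head_num = 12, 12  # hypothetical sizes
head_mask = np.ones((layer_num, head_num))

head_mask = np.expand_dims(head_mask, 1)   # [layer_num, 1, head_num]
head_mask = np.expand_dims(head_mask, -1)  # [layer_num, 1, head_num, 1]
head_mask = np.expand_dims(head_mask, -1)  # [layer_num, 1, head_num, 1, 1]
print(head_mask.shape)                     # (12, 1, 12, 1, 1)
```
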
@@ -50,12 +50,12 @@ def load_inference_model(path_prefix,
), 'Please check {}, or fix params_filename parameter.'.format(
os.path.join(path_prefix, model_name + '.pdiparams'))
model_path_prefix = os.path.join(path_prefix, model_name)
-[inference_program, feed_target_names, fetch_targets] = (
-    paddle.static.load_inference_model(
+[inference_program, feed_target_names,
+ fetch_targets] = (paddle.static.load_inference_model(
path_prefix=model_path_prefix, executor=executor))
elif model_filename is not None and params_filename is not None:
-[inference_program, feed_target_names, fetch_targets] = (
-    paddle.static.load_inference_model(
+[inference_program, feed_target_names,
+ fetch_targets] = (paddle.static.load_inference_model(
path_prefix=path_prefix,
executor=executor,
model_filename=model_filename,
@@ -65,12 +65,12 @@ def load_inference_model(path_prefix,
[:-1]) if model_filename is not None else 'model'
if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
model_path_prefix = os.path.join(path_prefix, model_name)
-[inference_program, feed_target_names, fetch_targets] = (
-    paddle.static.load_inference_model(
+[inference_program, feed_target_names,
+ fetch_targets] = (paddle.static.load_inference_model(
path_prefix=model_path_prefix, executor=executor))
else:
-[inference_program, feed_target_names, fetch_targets] = (
-    paddle.static.load_inference_model(
+[inference_program, feed_target_names,
+ fetch_targets] = (paddle.static.load_inference_model(
path_prefix=path_prefix, executor=executor))
return [inference_program, feed_target_names, fetch_targets]
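
The wrapper's fallback order, in short: use an explicit `model_filename`/`params_filename` pair when both are given; otherwise probe for `<model_name>.pdmodel` under the prefix; otherwise defer to the bare `path_prefix`. A hedged usage sketch (the path and file names are illustrative):

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# './infer_model' and the file names below are placeholders.
[inference_program, feed_target_names, fetch_targets] = load_inference_model(
    './infer_model', exe,
    model_filename='model.pdmodel',
    params_filename='model.pdiparams')

print(feed_target_names)                # names of inputs to feed
print([v.name for v in fetch_targets])  # output variables to fetch
```
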
@@ -125,13 +125,13 @@ def load_onnx_model(model_path,
version = x2paddle.__version__
v0, v1, v2 = version.split('.')
version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
-if version_sum < 139:
+if version_sum != 139:
    _logger.warning(
-       "x2paddle>=1.3.9 is required, please use \"pip install x2paddle\"."
+       "x2paddle==1.3.9 is required, please use \"pip install x2paddle==1.3.9\"."
    )
-   os.system('python -m pip install -U x2paddle')
+   os.system('python -m pip install -U x2paddle==1.3.9')
except:
-   os.system('python -m pip install -U x2paddle')
+   os.system('python -m pip install -U x2paddle==1.3.9')
# check onnx installation and version
try:
pkg.require('onnx')
@@ -153,7 +153,8 @@ def load_onnx_model(model_path,
time_info = int(time.time())
if not disable_feedback:
ConverterCheck(
task="ONNX", time_info=time_info, convert_state="Start").start()
task="ONNX", time_info=time_info,
convert_state="Start").start()
# support distributed convert model
model_idx = paddle.distributed.get_rank(
) if paddle.distributed.get_world_size() > 1 else 0
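On the version gate changed above: `load_onnx_model` encodes `x2paddle.__version__` as `major*100 + minor*10 + patch`, so `'1.3.9'` maps to 139, and switching `< 139` to `!= 139` turns a minimum-version check into an exact pin with a matching `pip install x2paddle==1.3.9` remedy. Note that this encoding is only unambiguous for single-digit components:

```python
def version_sum(version: str) -> int:
    # Same arithmetic as load_onnx_model's check.
    v0, v1, v2 = version.split('.')
    return int(v0) * 100 + int(v1) * 10 + int(v2)

assert version_sum('1.3.9') == 139   # the pinned release passes
assert version_sum('1.4.0') != 139   # other releases are rejected
assert version_sum('1.2.19') == 139  # collision: a double-digit patch also yields 139
```
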
@@ -41,9 +41,9 @@ def _recover_outputs_attr(program):
if "ReserveSpace" not in op.output_names or len(
op.output("ReserveSpace")) == 0:
reserve_space = block.create_var(
-name=paddle.fluid.unique_name.
-generate_with_ignorable_key(".".join(
-    ["reserve_space", 'tmp'])),
+name=paddle.fluid.
+unique_name.generate_with_ignorable_key(
+    ".".join(["reserve_space", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
persistable=False,
@@ -52,9 +52,9 @@ def _recover_outputs_attr(program):
if op.type == 'transpose2':
if 'XShape' not in op.output_names:
xshape = block.create_var(
-name=paddle.fluid.unique_name.
-generate_with_ignorable_key(".".join(["xshape", 'tmp'
-                                      ])),
+name=paddle.fluid.
+unique_name.generate_with_ignorable_key(
+    ".".join(["xshape", 'tmp'])),
dtype=block.var(op.input("X")[0]).dtype,
type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
shape=(0, ) + block.var(op.input("X")[0]).shape,
@@ -64,24 +64,24 @@ def _recover_outputs_attr(program):
return program
-def _recover_param_attr(program):
+def _recover_param_attr(program, startup_program):
"""recover parameters attribute.
Params in infermodel are stored in the form of variable, which can not be trained."""
all_weights = [param for param in program.list_vars() \
if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
-with paddle.static.program_guard(program):
+with paddle.static.program_guard(program, startup_program):
for w in all_weights:
new_w = paddle.create_parameter(
shape=w.shape, dtype=w.dtype, name=w.name)
new_w.set_value(w.get_value())
-program.block(0).vars[w.name] = new_w
+program.current_block().vars[w.name] = new_w
return program
-def recover_inference_program(inference_program):
+def recover_inference_program(inference_program, startup_program=None):
""" recover inference program to train program which can be trained. """
_remove_fetch_node(inference_program)
-inference_program = _recover_param_attr(inference_program)
+inference_program = _recover_param_attr(inference_program, startup_program)
inference_program = _recover_outputs_attr(inference_program)
for var in inference_program.list_vars():
var.stop_gradient = False
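Taken together, the recover-program changes in this commit imply roughly the following end-to-end flow: load the inference model, recover it into a trainable program while recording parameter-creation ops into an explicit startup program, then run that startup program once before training. A minimal sketch under those assumptions (path and file names are illustrative; both helpers are the functions from this diff):

```python
import paddle

paddle.enable_static()
exe = paddle.static.Executor(paddle.CPUPlace())

# 1) Load the saved inference model ('./infer_model' is a placeholder).
[program, feed_names, fetch_targets] = load_inference_model(
    './infer_model', exe, 'model.pdmodel', 'model.pdiparams')

# 2) Recover it into a trainable program; with the new signature the
#    parameter-creation ops go into startup_program instead of an
#    implicit default program.
startup_program = paddle.static.Program()
train_program = recover_inference_program(program, startup_program)

# 3) Execute the recorded startup ops once; train_program is now a
#    regular static-graph program that optimizers can work on.
exe.run(startup_program)
```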