Unverified commit 874c1ac6, authored by ceci3, committed by GitHub

fix ernie example to adapt develop paddle (#1644)

* fix ac

* adapt x2paddle version

* adapt develop paddle
Parent d0a1e2b6
@@ -241,9 +241,9 @@ class AutoCompression:
         ], f'Type of input_shapes should be in [dict, tuple or list] but got {type(input_shapes)}.'
         paddle.enable_static()
         exe = paddle.static.Executor(paddle.CPUPlace())
-        [inference_program, feed_target_names,
-         fetch_targets] = load_inference_model(model_dir, exe, model_filename,
-                                               params_filename)
+        [inference_program,
+         feed_target_names, fetch_targets] = load_inference_model(
+             model_dir, exe, model_filename, params_filename)
         if type(input_shapes) in [list, tuple]:
             assert len(
@@ -451,30 +451,29 @@ class AutoCompression:
             strategy.build_strategy = build_strategy
         if train_config.recompute_config is not None:
             strategy.recompute = True
-            strategy.recompute_configs = { ** train_config.recompute_config}
+            strategy.recompute_configs = {**train_config.recompute_config}
         if train_config.sharding_config is not None:
             strategy.sharding = True
-            strategy.sharding_configs = { ** train_config.sharding_config}
+            strategy.sharding_configs = {**train_config.sharding_config}
         if train_config.amp_config is not None:
             strategy.amp = True
-            strategy.amp_configs = { ** train_config.amp_config}
+            strategy.amp_configs = {**train_config.amp_config}
         if train_config.asp_config is not None:
             strategy.asp = True
         return strategy

     def _prepare_program(self, program, feed_target_names, fetch_targets,
                          patterns, strategy, config, train_config):
-        train_program = recover_inference_program(program)
         startup_program = paddle.static.Program()
+        train_program = recover_inference_program(program, startup_program)
         train_program_info = ProgramInfo(startup_program, train_program,
                                          feed_target_names, fetch_targets)

         config_dict = config.__dict__
-        if "prune_strategy" in config_dict and config_dict[
-                "prune_strategy"] == "gmp" and config_dict[
-                    'gmp_config'] is None:
+        if "prune_strategy" in config_dict and config_dict["prune_strategy"] == "gmp" and config_dict['gmp_config'] is None:
             _logger.info(
-                "Calculating the iterations per epoch……(It will take some time)")
+                "Calculating the iterations per epoch……(It will take some time)"
+            )
             # NOTE:XXX: This way of calculating the iters needs to be improved.
             if train_config.epochs:
                 iters_per_epoch = len(list(self.train_dataloader()))
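For context, a minimal sketch of how the dict-unpacked configs in the hunk above are consumed by paddle.distributed.fleet's DistributedStrategy; the amp values here are illustrative assumptions, not taken from this commit:

    # Minimal sketch, assuming illustrative amp settings.
    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    amp_config = {"init_loss_scaling": 32768, "use_dynamic_loss_scaling": True}
    strategy.amp = True
    # {**amp_config} builds a fresh dict, so later mutation of amp_config
    # cannot leak into the strategy object.
    strategy.amp_configs = {**amp_config}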
@@ -587,9 +586,8 @@ class AutoCompression:
         train_config = None
         strategy_idx = None
         self.final_metric = -1.0

-        for strategy_idx, (
-                strategy, config, train_config
-        ) in enumerate(zip(self._strategy, self._config, self.train_config)):
+        for strategy_idx, (strategy, config, train_config) in enumerate(
+                zip(self._strategy, self._config, self.train_config)):
             self.single_strategy_compress(strategy, config, strategy_idx,
                                           train_config)
@@ -815,7 +813,7 @@ class AutoCompression:
                         train_config.eval_iter) == 0 and total_train_iter != 0:
                     if self.eval_function is not None:

                         # GMP pruner step 3: update params before summrizing sparsity, saving model or evaluation.
                         if 'unstructure' in strategy:
                             self._pruner.update_params()
...
@@ -296,13 +296,14 @@ class TransformerPruner:
         head_num = -1
         tmp_mha_ops = patterns['MHA$0']
         for op in tmp_mha_ops:
-            if op.type() in ['matmul', 'matmul_v2'] and (
-                    not has_trainable_var(op)) and head_num == -1:
+            if op.type() in [
+                    'matmul', 'matmul_v2'
+            ] and (not has_trainable_var(op)) and head_num == -1:
                 inp_var = op.inputs("X")
                 head_num = inp_var[0].shape()[1]

-        mha_weight, ffn_weight = preprocess_transformer_patterns(patterns,
-                                                                 graph)
+        mha_weight, ffn_weight = preprocess_transformer_patterns(
+            patterns, graph)
         return input_mask_op, layer_num, head_num, mha_weight, ffn_weight

     def _program_add_mask(self, program, patterns, layer_num, head_num,
@@ -312,7 +313,7 @@ class TransformerPruner:
         for ft in fetch_targets:
             fetch_list.append(ft.name)
         program = recover_inference_program(program)
-        block = program.global_block()
+        block = program.current_block()
         head_mask = block.create_var(
             name='head_mask',
             shape=[layer_num, head_num],
@@ -325,11 +326,12 @@ class TransformerPruner:
             1.0,
             out=head_mask,
             stop_gradient=False)
-        head_mask = unsqueeze_op(
-            block, -1,
-            unsqueeze_op(block, -1,
-                         unsqueeze_op(block, 1, head_mask, feed_num + 1),
-                         feed_num + 2), feed_num + 3)
+        head_mask = unsqueeze_op(block, -1,
+                                 unsqueeze_op(block, -1,
+                                              unsqueeze_op(
+                                                  block, 1, head_mask,
+                                                  feed_num + 1), feed_num + 2),
+                                 feed_num + 3)

         for pattern_name, pattern in patterns.items():
             if 'MHA' in pattern_name:
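A numpy sketch of what the nested unsqueeze_op calls above compute: the [layer_num, head_num] mask gains an axis at position 1 and then twice at -1, ending up as [layer_num, 1, head_num, 1, 1], so a per-layer slice broadcasts against that layer's attention probabilities. All shapes below are illustrative assumptions:

    # Shapes are assumed for illustration (12 layers, 12 heads, seq len 128).
    import numpy as np

    layer_num, head_num = 12, 12
    head_mask = np.ones((layer_num, head_num), dtype='float32')
    head_mask = np.expand_dims(head_mask, 1)    # [12, 1, 12]
    head_mask = np.expand_dims(head_mask, -1)   # [12, 1, 12, 1]
    head_mask = np.expand_dims(head_mask, -1)   # [12, 1, 12, 1, 1]

    attn_probs = np.random.rand(4, head_num, 128, 128)  # [batch, head, seq, seq]
    masked = attn_probs * head_mask[0]          # broadcasts over batch and seq
    print(masked.shape)                         # (4, 12, 128, 128)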
@@ -432,8 +434,7 @@ class TransformerPruner:
         index = np.reshape(
             np.take(
                 np.reshape(
-                    np.arange(
-                        0, head_num * num_per_head, dtype='int64'),
+                    np.arange(0, head_num * num_per_head, dtype='int64'),
                     (head_num, num_per_head)),
                 idx,
                 axis=0), (-1))
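Worked example of the index computation above with small assumed sizes: rows of a [head_num, num_per_head] arange grid are gathered in the order given by idx and flattened, producing per-unit indices that keep whole heads together:

    import numpy as np

    head_num, num_per_head = 4, 3          # assumed sizes for illustration
    idx = np.array([2, 0, 3, 1])           # example head order

    index = np.reshape(
        np.take(
            np.reshape(
                np.arange(0, head_num * num_per_head, dtype='int64'),
                (head_num, num_per_head)),
            idx,
            axis=0), (-1))
    print(index)  # [ 6  7  8  0  1  2  9 10 11  3  4  5]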
@@ -455,13 +456,13 @@ class TransformerPruner:
             for w_idx, weight_name in enumerate(qkv):
                 if w_idx % 2 == 0:
                     ### reorder qkv weight
                     reorder_head_matrix(weight_name, qkv_index, dim=1)
                 else:
                     ### reorder qkv bias
                     reorder_head_matrix(weight_name, qkv_index, dim=0)

             ### reorder attention output weight
             reorder_head_matrix(attn_out[0], index, dim=0)

     def _reorder_neuron(self, scope, place, weight, idx):
@@ -528,8 +529,8 @@ class TransformerPruner:
         if _var is None:
             return
         param_t = _var.get_tensor()
-        pruned_ratio = [pruned_ratio[1]] if len(param_t.shape(
-        )) == 1 else pruned_ratio
+        pruned_ratio = [pruned_ratio[1]
+                        ] if len(param_t.shape()) == 1 else pruned_ratio
         origin_shape = param_t.shape()

         def process_qkv(qkv_param, pruned_ratio):
@@ -602,12 +603,12 @@ class TransformerPruner:
                 origin_shape = op.attr('shape')
                 pruned_shape = origin_shape
                 if len(origin_shape) == 3:
-                    pruned_shape[-1] = int(origin_shape[-1] *
-                                           self.width_mult)
+                    pruned_shape[-1] = int(
+                        origin_shape[-1] * self.width_mult)
                     op.set_attr('shape', pruned_shape)
                 elif len(origin_shape) == 4 or len(origin_shape) == 5:
-                    pruned_shape[-2] = int(origin_shape[-2] *
-                                           self.width_mult)
+                    pruned_shape[-2] = int(
+                        origin_shape[-2] * self.width_mult)
                     op.set_attr('shape', pruned_shape)
                 else:
                     raise IndexError
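A quick worked example of the reshape-attribute rewrite above, assuming width_mult = 0.75 and a hidden size of 768:

    width_mult = 0.75                     # assumed value
    origin_shape = [1, 128, 768]          # 3-D case: scale the last dim
    pruned_shape = list(origin_shape)
    pruned_shape[-1] = int(origin_shape[-1] * width_mult)
    print(pruned_shape)                   # [1, 128, 576]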
...
@@ -50,28 +50,28 @@ def load_inference_model(path_prefix,
         ), 'Please check {}, or fix params_filename parameter.'.format(
             os.path.join(path_prefix, model_name + '.pdiparams'))
         model_path_prefix = os.path.join(path_prefix, model_name)
-        [inference_program, feed_target_names, fetch_targets] = (
-            paddle.static.load_inference_model(
-                path_prefix=model_path_prefix, executor=executor))
+        [inference_program, feed_target_names,
+         fetch_targets] = (paddle.static.load_inference_model(
+             path_prefix=model_path_prefix, executor=executor))
     elif model_filename is not None and params_filename is not None:
-        [inference_program, feed_target_names, fetch_targets] = (
-            paddle.static.load_inference_model(
-                path_prefix=path_prefix,
-                executor=executor,
-                model_filename=model_filename,
-                params_filename=params_filename))
+        [inference_program, feed_target_names,
+         fetch_targets] = (paddle.static.load_inference_model(
+             path_prefix=path_prefix,
+             executor=executor,
+             model_filename=model_filename,
+             params_filename=params_filename))
     else:
         model_name = '.'.join(model_filename.split('.')
                               [:-1]) if model_filename is not None else 'model'
         if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
             model_path_prefix = os.path.join(path_prefix, model_name)
-            [inference_program, feed_target_names, fetch_targets] = (
-                paddle.static.load_inference_model(
-                    path_prefix=model_path_prefix, executor=executor))
+            [inference_program, feed_target_names,
+             fetch_targets] = (paddle.static.load_inference_model(
+                 path_prefix=model_path_prefix, executor=executor))
         else:
-            [inference_program, feed_target_names, fetch_targets] = (
-                paddle.static.load_inference_model(
-                    path_prefix=path_prefix, executor=executor))
+            [inference_program, feed_target_names,
+             fetch_targets] = (paddle.static.load_inference_model(
+                 path_prefix=path_prefix, executor=executor))

     return [inference_program, feed_target_names, fetch_targets]
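Hypothetical usage of this helper; the directory and file names are placeholders. Passing explicit model_filename/params_filename takes the elif branch above, while leaving both as None falls back to the '<path_prefix>/model' prefix resolution in the else branch:

    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())

    # Placeholder paths; adjust to a real exported inference model.
    program, feed_names, fetch_targets = load_inference_model(
        './infer_model', exe,
        model_filename='inference.pdmodel',
        params_filename='inference.pdiparams')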
@@ -125,13 +125,13 @@ def load_onnx_model(model_path,
         version = x2paddle.__version__
         v0, v1, v2 = version.split('.')
         version_sum = int(v0) * 100 + int(v1) * 10 + int(v2)
-        if version_sum < 139:
+        if version_sum != 139:
             _logger.warning(
-                "x2paddle>=1.3.9 is required, please use \"pip install x2paddle\"."
+                "x2paddle==1.3.9 is required, please use \"pip install x2paddle==1.3.9\"."
             )
-            os.system('python -m pip install -U x2paddle')
+            os.system('python -m pip install -U x2paddle==1.3.9')
     except:
-        os.system('python -m pip install -U x2paddle')
+        os.system('python -m pip install -U x2paddle==1.3.9')

     # check onnx installation and version
     try:
         pkg.require('onnx')

@@ -153,7 +153,8 @@ def load_onnx_model(model_path,
     time_info = int(time.time())
     if not disable_feedback:
         ConverterCheck(
-            task="ONNX", time_info=time_info, convert_state="Start").start()
+            task="ONNX", time_info=time_info,
+            convert_state="Start").start()
     # support distributed convert model
     model_idx = paddle.distributed.get_rank(
     ) if paddle.distributed.get_world_size() > 1 else 0
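Note that the digit-weighted version_sum above assumes single-digit version components: '1.3.10' and '1.4.0' both sum to 140. A sketch of an exact-pin check that avoids that collision, assuming a plain MAJOR.MINOR.PATCH version string:

    import x2paddle

    # Compare version components as a tuple instead of a weighted sum.
    if tuple(int(p) for p in x2paddle.__version__.split('.')[:3]) != (1, 3, 9):
        raise RuntimeError(
            "x2paddle==1.3.9 is required; run 'pip install x2paddle==1.3.9'.")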
...
@@ -41,9 +41,9 @@ def _recover_outputs_attr(program):
                 if "ReserveSpace" not in op.output_names or len(
                         op.output("ReserveSpace")) == 0:
                     reserve_space = block.create_var(
-                        name=paddle.fluid.unique_name.
-                        generate_with_ignorable_key(".".join(
-                            ["reserve_space", 'tmp'])),
+                        name=paddle.fluid.
+                        unique_name.generate_with_ignorable_key(
+                            ".".join(["reserve_space", 'tmp'])),
                         dtype=block.var(op.input("X")[0]).dtype,
                         type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
                         persistable=False,

@@ -52,9 +52,9 @@ def _recover_outputs_attr(program):
         if op.type == 'transpose2':
             if 'XShape' not in op.output_names:
                 xshape = block.create_var(
-                    name=paddle.fluid.unique_name.
-                    generate_with_ignorable_key(".".join(["xshape", 'tmp'
-                                                          ])),
+                    name=paddle.fluid.
+                    unique_name.generate_with_ignorable_key(
+                        ".".join(["xshape", 'tmp'])),
                     dtype=block.var(op.input("X")[0]).dtype,
                     type=paddle.framework.core.VarDesc.VarType.LOD_TENSOR,
                     shape=(0, ) + block.var(op.input("X")[0]).shape,

@@ -64,24 +64,24 @@ def _recover_outputs_attr(program):
     return program


-def _recover_param_attr(program):
+def _recover_param_attr(program, startup_program):
     """recover parameters attribute.
       Params in infermodel are stored in the form of variable, which can not be trained."""
     all_weights = [param for param in program.list_vars() \
         if param.persistable is True and param.name != 'feed' and param.name != 'fetch']
-    with paddle.static.program_guard(program):
+    with paddle.static.program_guard(program, startup_program):
         for w in all_weights:
             new_w = paddle.create_parameter(
                 shape=w.shape, dtype=w.dtype, name=w.name)
             new_w.set_value(w.get_value())
-            program.block(0).vars[w.name] = new_w
+            program.current_block().vars[w.name] = new_w
     return program


-def recover_inference_program(inference_program):
+def recover_inference_program(inference_program, startup_program=None):
     """ recover inference program to train program which can be trained. """
     _remove_fetch_node(inference_program)
-    inference_program = _recover_param_attr(inference_program)
+    inference_program = _recover_param_attr(inference_program, startup_program)
     inference_program = _recover_outputs_attr(inference_program)
     for var in inference_program.list_vars():
         var.stop_gradient = False
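A minimal usage sketch of the new signature, mirroring how _prepare_program calls it above: a fresh startup program is passed in so the parameter-creation ops recorded under program_guard land there rather than in the default startup program. The model path is a placeholder:

    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    inference_program, feed_names, fetch_targets = (
        paddle.static.load_inference_model(
            path_prefix='./infer_model', executor=exe))

    startup_program = paddle.static.Program()
    train_program = recover_inference_program(inference_program,
                                              startup_program)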
...