Unverified commit e95a22ca, authored by ceci3, committed by GitHub

fix save_inference_model (#1198)

* fix save_inference_model
Parent a10fc884
@@ -105,7 +105,7 @@ def create_strategy_config(strategy_str, model_type):
             'prune_strategy':
             'gmp',  ### default unstruture prune strategy is gmp
             'prune_mode': 'ratio',
-            'pruned_ratio': float(tmp_s[1]),
+            'ratio': float(tmp_s[1]),
             'local_sparsity': True,
             'prune_params_type': 'conv1x1_only'
         }
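For context, the unstructured-pruning entry that create_strategy_config builds now carries the sparsity under the key 'ratio' instead of 'pruned_ratio'. A sketch of the resulting entry, with an illustrative sparsity value:

# Illustrative only: the sparsity now lives under 'ratio' (was 'pruned_ratio').
unstructure_prune_config = {
    'prune_strategy': 'gmp',            # default unstructured prune strategy
    'prune_mode': 'ratio',
    'ratio': 0.75,                      # e.g. float(tmp_s[1]) parsed from the strategy string
    'local_sparsity': True,
    'prune_params_type': 'conv1x1_only'
}
print(unstructure_prune_config['ratio'])  # 0.75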
@@ -205,13 +205,14 @@ class AutoCompression:
         train_configs = [train_config]
         for idx in range(1, len(self._strategy)):
-            if 'qat' in self._strategy[idx]:
-                ### if compress strategy more than one, the train config in the yaml set for prune
-                ### the train config for quantization is extrapolate from the yaml
+            if 'qat' in self._strategy[idx] or 'ptq' in self._strategy[idx]:
+                ### If compress strategy more than one, the TrainConfig in the yaml only used in prune.
+                ### The TrainConfig for quantization is extrapolate from above.
                 tmp_train_config = copy.deepcopy(train_config.__dict__)
                 ### the epoch, train_iter, learning rate of quant is 10% of the prune compress
-                tmp_train_config['epochs'] = max(
-                    int(train_config.epochs * 0.1), 1)
+                if self.model_type != 'transformer':
+                    tmp_train_config['epochs'] = max(
+                        int(train_config.epochs * 0.1), 1)
                 if train_config.train_iter is not None:
                     tmp_train_config['train_iter'] = int(
                         train_config.train_iter * 0.1)
@@ -228,8 +229,6 @@ class AutoCompression:
                         map(lambda x: x * 0.1, train_config.learning_rate[
                             'values']))
                 train_cfg = TrainConfig(**tmp_train_config)
-            elif 'ptq' in self._strategy[idx]:
-                train_cfg = None
             else:
                 tmp_train_config = copy.deepcopy(train_config.__dict__)
                 train_cfg = TrainConfig(**tmp_train_config)
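A minimal, self-contained sketch of what these hunks do (helper name and example numbers are hypothetical): when a QAT or PTQ stage follows pruning, its training schedule is derived from the pruning TrainConfig at roughly 10%, except that epochs are left unchanged for transformer models:

import copy
from types import SimpleNamespace

def derive_quant_train_config(train_config, model_type):
    """Sketch of the 10%-schedule derivation in the hunk above (names hypothetical)."""
    cfg = copy.deepcopy(train_config.__dict__)
    if model_type != 'transformer':
        # quantization finetuning runs for ~10% of the pruning epochs, at least 1
        cfg['epochs'] = max(int(train_config.epochs * 0.1), 1)
    if train_config.train_iter is not None:
        cfg['train_iter'] = int(train_config.train_iter * 0.1)
    # the real code also scales the learning-rate schedule by 0.1 (elided here)
    return cfg

# Example: a prune schedule of 20 epochs / 5000 iters becomes 2 epochs / 500 iters.
prune_cfg = SimpleNamespace(epochs=20, train_iter=5000)
print(derive_quant_train_config(prune_cfg, model_type='detection'))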
@@ -802,11 +801,12 @@ class AutoCompression:
                 for name in test_program_info.feed_target_names
             ]
+            model_name = '.'.join(self.model_filename.split(
+                '.')[:-1]) if self.model_filename is not None else 'model'
+            path_prefix = os.path.join(model_dir, model_name)
             paddle.static.save_inference_model(
-                path_prefix=str(model_dir),
+                path_prefix=path_prefix,
                 feed_vars=feed_vars,
                 fetch_vars=test_program_info.fetch_targets,
                 executor=self._exe,
-                program=test_program,
-                model_filename=self.model_filename,
-                params_filename=self.params_filename)
+                program=test_program)
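The effect of this change, sketched with illustrative paths: the export prefix is now built from model_filename, and paddle.static.save_inference_model derives the .pdmodel/.pdiparams file names from that prefix, so the explicit model_filename/params_filename arguments are dropped:

import os

model_dir = './output'                 # illustrative output directory
model_filename = 'model.pdmodel'       # illustrative value of self.model_filename

model_name = '.'.join(model_filename.split('.')[:-1]) if model_filename is not None else 'model'
path_prefix = os.path.join(model_dir, model_name)
print(path_prefix)  # ./output/model
# paddle.static.save_inference_model(path_prefix, ...) in Paddle 2.x then writes
# ./output/model.pdmodel and ./output/model.pdiparams.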
+import os
 import paddle
 from paddle.fluid.framework import IrGraph
 from paddle.framework import core
@@ -111,10 +112,11 @@ def post_quant_fake(executor,
     _program = graph.to_program()
     feed_vars = [_program.global_block().var(name) for name in _feed_list]
+    model_name = model_filename.split('.')[
+        0] if model_filename is not None else 'model'
+    save_model_path = os.path.join(save_model_path, model_name)
     paddle.static.save_inference_model(
         path_prefix=save_model_path,
-        model_filename=model_filename,
-        params_filename=params_filename,
         feed_vars=feed_vars,
         fetch_vars=_fetch_list,
         executor=executor,
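Note that post_quant_fake derives model_name with split('.')[0], while the AutoCompression hunk above uses '.'.join(split('.')[:-1]); the two only differ for filenames that contain more than one dot, as this small check (hypothetical filename) shows:

name = 'quantized.model.pdmodel'        # hypothetical filename with two dots
print(name.split('.')[0])               # quantized
print('.'.join(name.split('.')[:-1]))   # quantized.model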
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
 import paddle

 __all__ = ['load_inference_model']
@@ -29,8 +30,16 @@ def load_inference_model(path_prefix,
                 model_filename=model_filename,
                 params_filename=params_filename))
     else:
-        [inference_program, feed_target_names, fetch_targets] = (
-            paddle.static.load_inference_model(
-                path_prefix=path_prefix, executor=executor))
+        model_name = '.'.join(model_filename.split('.')
+                              [:-1]) if model_filename is not None else 'model'
+        if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
+            model_path_prefix = os.path.join(path_prefix, model_name)
+            [inference_program, feed_target_names, fetch_targets] = (
+                paddle.static.load_inference_model(
+                    path_prefix=model_path_prefix, executor=executor))
+        else:
+            [inference_program, feed_target_names, fetch_targets] = (
+                paddle.static.load_inference_model(
+                    path_prefix=path_prefix, executor=executor))
     return [inference_program, feed_target_names, fetch_targets]
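A standalone sketch of the fallback added here (helper name is hypothetical): if a prefix derived from model_filename exists inside path_prefix, i.e. <path_prefix>/<model_name>.pdmodel is present, load from that prefix, otherwise fall back to path_prefix itself:

import os

def resolve_load_prefix(path_prefix, model_filename):
    # Hypothetical helper mirroring the hunk above.
    model_name = ('.'.join(model_filename.split('.')[:-1])
                  if model_filename is not None else 'model')
    candidate = os.path.join(path_prefix, model_name)
    # Prefer the per-model prefix when its .pdmodel exists; otherwise keep the
    # old behaviour and treat path_prefix itself as the prefix.
    return candidate if os.path.exists(candidate + '.pdmodel') else path_prefix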
@@ -86,14 +86,15 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
     feed_vars = [
         inference_program.global_block().var(name) for name in feed_target_names
     ]
+    model_name = '.'.join(model_name.split('.')
+                          [:-1]) if model_name is not None else 'model'
+    save_path = os.path.join(save_path, model_name)
     static.save_inference_model(
         save_path,
         feed_vars=feed_vars,
         fetch_vars=fetch_targets,
         executor=executor,
-        program=inference_program,
-        model_filename=model_name,
-        params_filename=param_name)
+        program=inference_program)
     print("The pruned model is saved in: ", save_path)
@@ -160,11 +161,12 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
     feed_vars = [
         main_program.global_block().var(name) for name in feed_target_names
     ]
+    model_name = '.'.join(model_name.split('.')
+                          [:-1]) if model_name is not None else 'model'
+    save_path = os.path.join(save_path, model_name)
     static.save_inference_model(
         save_path,
         feed_vars=feed_vars,
         fetch_vars=fetch_targets,
         executor=executor,
-        program=main_program,
-        model_filename=model_name,
-        params_filename=param_name)
+        program=main_program)
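Both get_sparse_model and get_prune_model get the same treatment: save_path is rebound to a file prefix (directory joined with the model name) and save_inference_model derives the file names from it. A sketch with hypothetical inputs:

import os

save_path = './sparse_out'         # hypothetical output directory
model_name = 'model.pdmodel'       # hypothetical model filename

model_name = '.'.join(model_name.split('.')[:-1]) if model_name is not None else 'model'
save_path = os.path.join(save_path, model_name)
print("The pruned model is saved in: ", save_path)  # ./sparse_out/model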
@@ -307,7 +307,7 @@ def quantize(cfg):
     quant_scope = paddle.static.Scope()
     with paddle.static.scope_guard(float_scope):
         [float_inference_program, feed_target_names, fetch_targets]= fluid.io.load_inference_model( \
-            dirname=g_quant_config.model_filename, \
+            dirname=g_quant_config.float_infer_model_path, \
             model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
             executor=g_quant_config.executor)
         float_metric = g_quant_config.eval_function(
@@ -320,8 +320,8 @@ def quantize(cfg):
             model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
             executor=g_quant_config.executor)
         quant_metric = g_quant_config.eval_function(
-            g_quant_config.executor, inference_program, feed_target_names,
-            fetch_targets)
+            g_quant_config.executor, quant_inference_program,
+            feed_target_names, fetch_targets)
         emd_loss = float(abs(float_metric - quant_metric)) / float_metric
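With the second hunk the evaluation now runs on quant_inference_program rather than the float program. For reference, the relative error computed at the end of this block, with purely illustrative metric values:

float_metric = 0.760   # illustrative FP32 eval metric
quant_metric = 0.741   # illustrative quantized eval metric
emd_loss = float(abs(float_metric - quant_metric)) / float_metric
print(round(emd_loss, 4))  # 0.025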