Unverified commit e95a22ca authored by ceci3 and committed by GitHub

fix save_inference_model (#1198)

* fix save_inference_model
Parent a10fc884
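In Paddle 2.x the static-graph saver `paddle.static.save_inference_model` takes a `path_prefix` and writes `<prefix>.pdmodel` / `<prefix>.pdiparams`; it does not take the legacy `model_filename` / `params_filename` arguments of the old `fluid.io` API, which appears to be what this commit is fixing. A minimal sketch of the saving convention used after the fix (a hypothetical helper whose names mirror the diff, not the actual AutoCompression code):

```python
import os
import paddle

paddle.enable_static()

def save_compressed_model(exe, program, feed_vars, fetch_vars,
                          model_dir, model_filename=None):
    """Sketch of the post-fix convention: build a path_prefix from the output
    directory plus the stem of the original model filename, then let
    paddle.static.save_inference_model append .pdmodel / .pdiparams."""
    # "inference.pdmodel" -> "inference"; fall back to "model" when unknown.
    model_name = ('.'.join(model_filename.split('.')[:-1])
                  if model_filename is not None else 'model')
    path_prefix = os.path.join(model_dir, model_name)
    paddle.static.save_inference_model(
        path_prefix=path_prefix,
        feed_vars=feed_vars,
        fetch_vars=fetch_vars,
        executor=exe,
        program=program)
    return path_prefix
```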
@@ -105,7 +105,7 @@ def create_strategy_config(strategy_str, model_type):
'prune_strategy':
'gmp', ### default unstruture prune strategy is gmp
'prune_mode': 'ratio',
'pruned_ratio': float(tmp_s[1]),
'ratio': float(tmp_s[1]),
'local_sparsity': True,
'prune_params_type': 'conv1x1_only'
}
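The hunk above renames the unstructured-prune key from `pruned_ratio` to `ratio`. A hedged sketch of how a strategy string could map onto that dict; the `'sparse_0.75'` format and the underscore split are assumptions inferred from `tmp_s[1]`, not confirmed by this excerpt:

```python
def parse_sparse_strategy(strategy_str):
    # Assumed format: "<strategy>_<ratio>", e.g. "sparse_0.75".
    tmp_s = strategy_str.split('_')
    return {
        'prune_strategy': 'gmp',   # default unstructured prune strategy is gmp
        'prune_mode': 'ratio',
        'ratio': float(tmp_s[1]),  # key renamed from 'pruned_ratio' in this commit
        'local_sparsity': True,
        'prune_params_type': 'conv1x1_only'
    }

print(parse_sparse_strategy('sparse_0.75')['ratio'])  # 0.75
```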
@@ -205,13 +205,14 @@ class AutoCompression:
train_configs = [train_config]
for idx in range(1, len(self._strategy)):
if 'qat' in self._strategy[idx]:
### if compress strategy more than one, the train config in the yaml set for prune
### the train config for quantization is extrapolate from the yaml
if 'qat' in self._strategy[idx] or 'ptq' in self._strategy[idx]:
### If there is more than one compress strategy, the TrainConfig in the yaml is only used for prune.
### The TrainConfig for quantization is extrapolated from the one above.
tmp_train_config = copy.deepcopy(train_config.__dict__)
### the epoch, train_iter, learning rate of quant is 10% of the prune compress
tmp_train_config['epochs'] = max(
int(train_config.epochs * 0.1), 1)
if self.model_type != 'transformer':
tmp_train_config['epochs'] = max(
int(train_config.epochs * 0.1), 1)
if train_config.train_iter is not None:
tmp_train_config['train_iter'] = int(
train_config.train_iter * 0.1)
@@ -228,8 +229,6 @@ class AutoCompression:
map(lambda x: x * 0.1, train_config.learning_rate[
'values']))
train_cfg = TrainConfig(**tmp_train_config)
elif 'ptq' in self._strategy[idx]:
train_cfg = None
else:
tmp_train_config = copy.deepcopy(train_config.__dict__)
train_cfg = TrainConfig(**tmp_train_config)
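The `elif 'ptq'` branch that set `train_cfg = None` is folded into the combined `'qat' or 'ptq'` branch above, so post-training quantization now also receives a derived TrainConfig, and the 10% epoch reduction is skipped for transformer models. A rough, self-contained sketch of that derivation, using a simplified stand-in for PaddleSlim's TrainConfig and only the scalar learning-rate case (the real code also scales a 'values' schedule):

```python
import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class TrainConfig:              # simplified stand-in, not the real class
    epochs: int = 10
    train_iter: Optional[int] = None
    learning_rate: float = 1e-3

def derive_quant_train_config(train_config, model_type):
    """Quant stage reuses roughly 10% of the prune stage's schedule (per the diff)."""
    cfg = copy.deepcopy(train_config.__dict__)
    if model_type != 'transformer':                 # transformers keep full epochs now
        cfg['epochs'] = max(int(train_config.epochs * 0.1), 1)
    if train_config.train_iter is not None:
        cfg['train_iter'] = int(train_config.train_iter * 0.1)
    if isinstance(train_config.learning_rate, float):
        cfg['learning_rate'] = train_config.learning_rate * 0.1
    return TrainConfig(**cfg)

print(derive_quant_train_config(TrainConfig(epochs=30), 'cv'))
```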
@@ -802,11 +801,12 @@ class AutoCompression:
for name in test_program_info.feed_target_names
]
model_name = '.'.join(self.model_filename.split(
'.')[:-1]) if self.model_filename is not None else 'model'
path_prefix = os.path.join(model_dir, model_name)
paddle.static.save_inference_model(
path_prefix=str(model_dir),
path_prefix=path_prefix,
feed_vars=feed_vars,
fetch_vars=test_program_info.fetch_targets,
executor=self._exe,
program=test_program,
model_filename=self.model_filename,
params_filename=self.params_filename)
program=test_program)
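For context: since the 2.x API appends the suffixes to the prefix, the old `path_prefix=str(model_dir)` call would produce `<model_dir>.pdmodel` next to, rather than inside, the output directory. A small illustration of the path handling only (the filename is hypothetical):

```python
import os

model_dir = './output'
model_filename = 'inference.pdmodel'   # hypothetical value

# Old behaviour: the prefix is the directory itself.
old_prefix = str(model_dir)
print(old_prefix + '.pdmodel')         # ./output.pdmodel  (beside the directory)

# New behaviour: the prefix is directory + model-file stem.
model_name = '.'.join(model_filename.split('.')[:-1]) if model_filename else 'model'
new_prefix = os.path.join(model_dir, model_name)
print(new_prefix + '.pdmodel')         # ./output/inference.pdmodel
```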
import os
import paddle
from paddle.fluid.framework import IrGraph
from paddle.framework import core
@@ -111,10 +112,11 @@ def post_quant_fake(executor,
_program = graph.to_program()
feed_vars = [_program.global_block().var(name) for name in _feed_list]
model_name = model_filename.split('.')[
0] if model_filename is not None else 'model'
save_model_path = os.path.join(save_model_path, model_name)
paddle.static.save_inference_model(
path_prefix=save_model_path,
model_filename=model_filename,
params_filename=params_filename,
feed_vars=feed_vars,
fetch_vars=_fetch_list,
executor=executor,
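One detail worth noting: `post_quant_fake` derives the stem with `model_filename.split('.')[0]`, while the other call sites in this commit use `'.'.join(model_filename.split('.')[:-1])`; the two agree for `inference.pdmodel` but diverge once the stem itself contains a dot, as the toy comparison below shows.

```python
name = 'my.model.pdmodel'              # hypothetical filename with an extra dot

print(name.split('.')[0])              # 'my'
print('.'.join(name.split('.')[:-1]))  # 'my.model'
```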
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle
__all__ = ['load_inference_model']
@@ -29,8 +30,16 @@ def load_inference_model(path_prefix,
model_filename=model_filename,
params_filename=params_filename))
else:
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(
path_prefix=path_prefix, executor=executor))
model_name = '.'.join(model_filename.split('.')
[:-1]) if model_filename is not None else 'model'
if os.path.exists(os.path.join(path_prefix, model_name + '.pdmodel')):
model_path_prefix = os.path.join(path_prefix, model_name)
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(
path_prefix=model_path_prefix, executor=executor))
else:
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(
path_prefix=path_prefix, executor=executor))
return [inference_program, feed_target_names, fetch_targets]
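The patched loader first checks for `<path_prefix>/<model_name>.pdmodel`, i.e. the layout produced by the new save code, and only then falls back to treating `path_prefix` itself as the prefix. A minimal standalone sketch of the same fallback, assuming `paddle.static.load_inference_model` resolves `<prefix>.pdmodel` / `<prefix>.pdiparams`:

```python
import os
import paddle

paddle.enable_static()

def load_model_with_fallback(path_prefix, executor, model_filename=None):
    """Mirror of the fallback in the patched load_inference_model (sketch)."""
    model_name = ('.'.join(model_filename.split('.')[:-1])
                  if model_filename is not None else 'model')
    candidate = os.path.join(path_prefix, model_name)
    if os.path.exists(candidate + '.pdmodel'):
        # New-style layout: <dir>/<stem>.pdmodel + <dir>/<stem>.pdiparams
        prefix = candidate
    else:
        # Fall back to treating path_prefix itself as the prefix.
        prefix = path_prefix
    return paddle.static.load_inference_model(path_prefix=prefix, executor=executor)
```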
@@ -86,14 +86,15 @@ def get_sparse_model(executor, places, model_file, param_file, ratio,
feed_vars = [
inference_program.global_block().var(name) for name in feed_target_names
]
model_name = '.'.join(model_name.split('.')
[:-1]) if model_name is not None else 'model'
save_path = os.path.join(save_path, model_name)
static.save_inference_model(
save_path,
feed_vars=feed_vars,
fetch_vars=fetch_targets,
executor=executor,
program=inference_program,
model_filename=model_name,
params_filename=param_name)
program=inference_program)
print("The pruned model is saved in: ", save_path)
@@ -160,11 +161,12 @@ def get_prune_model(executor, places, model_file, param_file, ratio, save_path):
feed_vars = [
main_program.global_block().var(name) for name in feed_target_names
]
model_name = '.'.join(model_name.split('.')
[:-1]) if model_name is not None else 'model'
save_path = os.path.join(save_path, model_name)
static.save_inference_model(
save_path,
feed_vars=feed_vars,
fetch_vars=fetch_targets,
executor=executor,
program=main_program,
model_filename=model_name,
params_filename=param_name)
program=main_program)
@@ -307,7 +307,7 @@ def quantize(cfg):
quant_scope = paddle.static.Scope()
with paddle.static.scope_guard(float_scope):
[float_inference_program, feed_target_names, fetch_targets]= fluid.io.load_inference_model( \
dirname=g_quant_config.model_filename, \
dirname=g_quant_config.float_infer_model_path, \
model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
executor=g_quant_config.executor)
float_metric = g_quant_config.eval_function(
@@ -320,8 +320,8 @@ def quantize(cfg):
model_filename=g_quant_config.model_filename, params_filename=g_quant_config.params_filename,
executor=g_quant_config.executor)
quant_metric = g_quant_config.eval_function(
g_quant_config.executor, inference_program, feed_target_names,
fetch_targets)
g_quant_config.executor, quant_inference_program,
feed_target_names, fetch_targets)
emd_loss = float(abs(float_metric - quant_metric)) / float_metric
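The final hunk loads the float baseline from `float_infer_model_path` instead of mistakenly passing the model filename as `dirname`, and evaluates the quantized metric on `quant_inference_program` rather than the float `inference_program`. The `emd_loss` it feeds back to the search is just the relative gap between the two metrics; a tiny helper to make that concrete:

```python
def relative_metric_gap(float_metric: float, quant_metric: float) -> float:
    """emd_loss as computed in quantize(): relative drop of the quant metric vs. the float one."""
    return float(abs(float_metric - quant_metric)) / float_metric

print(relative_metric_gap(0.765, 0.758))  # ~0.0092, i.e. <1% relative degradation
```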