Unverified commit 2952d4ce authored by whs, committed by GitHub

Skip auto tuning when the shape of the model's input is variable (#1155)

Parent 418ef571
......@@ -12,6 +12,7 @@ repos:
    sha: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
    hooks:
    -   id: yapf
        language_version: python3.9
        files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
-   repo: https://github.com/pre-commit/pre-commit-hooks
    sha: 5bf6c09bfa1297d3692cadd621ef95f1284e33c0
......
......@@ -50,7 +50,6 @@ def get_features_from_paramkey(param_key, op_type, data_type):
"""Get op's parameters according to the key of latency table
"""
features = None
if 'conv2d' in op_type:
if data_type == 'fp16':
quant_bits = 'bit_length=16'
......
......@@ -16,7 +16,7 @@ import os
import logging
import platform
from ..common import get_logger
-from .utils.predict import predict_compressed_model
+from .utils.predict import predict_compressed_model, with_variable_shape
from .strategy_config import *
_logger = get_logger(__name__, level=logging.INFO)
......@@ -151,6 +151,16 @@ def prepare_strategy(executor,
""" prepare compression config automatically """
final_strategy = None
if with_variable_shape(
model_dir,
model_filename=model_filename,
params_filename=params_filename):
deploy_hardware = None
_logger.warning(
"The model's inputs have variable shape. "
"And the latency predictor doesn't support variable shape. "
"So auto tuning will be skipped and a default strategy will be chosen."
)
### use hardware latency tabel if support
if deploy_hardware is not None:
compressed_time_dict = predict_compressed_model(
......
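To make the control flow of this hunk concrete, here is a minimal standalone sketch of the gating it introduces. All names except `deploy_hardware` are illustrative, not from the commit: a variable-shape model clears the hardware target, so the latency-table branch is never entered and the default strategy is used instead.

```python
import logging

logging.basicConfig(level=logging.INFO)
_logger = logging.getLogger(__name__)


def choose_strategy(inputs_have_variable_shape, deploy_hardware):
    """Illustrative sketch that mirrors the hunk's control flow."""
    if inputs_have_variable_shape:
        # Clearing the hardware target guarantees the latency-table
        # branch below is skipped.
        deploy_hardware = None
        _logger.warning("Variable input shapes: skipping auto tuning.")
    if deploy_hardware is not None:
        return "latency-table tuning"  # would call predict_compressed_model
    return "default strategy"


print(choose_strategy(True, "SD710"))   # -> default strategy
print(choose_strategy(False, "SD710"))  # -> latency-table tuning
```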
......@@ -6,6 +6,30 @@ from .prune_model import get_sparse_model, get_prune_model
from .fake_ptq import post_quant_fake
def with_variable_shape(model_dir, model_filename=None, params_filename=None):
    """
    Whether the shape of the model's input is variable.
    Args:
        model_dir(str): Directory path of the saved inference model.
        model_filename(str): Specify model_filename if you don't want to use the default name. Default: None.
        params_filename(str): Specify params_filename if you don't want to use the default name. Default: None.
    Returns:
        bool: Whether the shape of the model's input is variable.
    """
    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())
    [inference_program, feed_target_names, fetch_targets] = (
        paddle.static.load_inference_model(
            model_dir,
            exe,
            model_filename=model_filename,
            params_filename=params_filename))
    # A single -1 (the batch dimension) is expected in a feed shape;
    # a second -1 means other dimensions are variable too.
    for var_ in inference_program.list_vars():
        if var_.name in feed_target_names:
            if var_.shape.count(-1) > 1:
                return True
    return False
def predict_compressed_model(executor,
                             places,
                             model_dir,
......
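A hedged usage sketch of the new helper. The model directory and file names below are placeholders, and the import path assumes the helper lives in the auto-compression utils module shown in the hunk above:

```python
import paddle
from paddleslim.auto_compression.utils.predict import with_variable_shape

# Placeholder paths: point these at an exported static-graph inference model.
if with_variable_shape(
        "./inference_model",
        model_filename="model.pdmodel",
        params_filename="model.pdiparams"):
    print("Variable input shape detected; auto tuning will be skipped.")
```

Note the rule in the helper: one -1 (the batch dimension) is expected in a static-graph feed shape, so only a second -1, e.g. `[-1, 3, -1, -1]` for variable height and width, marks the input as variable.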