未验证 提交 a4e43aff 编写于 作者: W whs 提交者: GitHub

Fix issues when hardware is gpu or gpu in ACT (#1180)

上级 6fa4ff18
......@@ -152,18 +152,13 @@ def prepare_strategy(executor,
""" prepare compression config automatically """
final_strategy = None
if with_variable_shape(
### use hardware latency tabel if support
if not with_variable_shape(
model_dir,
model_filename=model_filename,
params_filename=params_filename):
deploy_hardware = None
_logger.warning(
"The model's inputs have variable shape. "
"And the latency predictor doesn't support variable shape. "
"So auto tuning will be skipped and a default strategy will be chosen."
)
### use hardware latency tabel if support
if deploy_hardware is not None:
params_filename=params_filename) and (
deploy_hardware in TableLatencyPredictor.hardware_list):
compressed_time_dict = predict_compressed_model(
executor,
places,
......@@ -216,7 +211,6 @@ def prepare_strategy(executor,
if final_strategy is None:
final_strategy = candidate_s[0]
### if deploy_hardware is not None
else:
### default speedup ratio of quantization is 70% compare to fp32
### TODO(ceci3): full quant or skip some layer later
......
......@@ -221,11 +221,15 @@ class AutoCompression:
@deploy_hardware.setter
def deploy_hardware(self, value):
supported_hardware = TableLatencyPredictor.hardware_list + [
'gpu', # nvidia gpu
"cpu", # intel cpu
]
if value is not None:
# Fail-fast when deploy hardware is set explicitly
assert (
value in TableLatencyPredictor.hardware_list
), f"Hardware should be in supported list {TableLatencyPredictor.hardware_list} but got {value}. Or you can set deploy_hardware None."
value in supported_hardware
), f"Hardware should be in supported list {supported_hardware} but got {value}. Or you can set deploy_hardware None."
self._deploy_hardware = value
def _get_eval_dataloader(self, train_dataloader):
......@@ -469,7 +473,7 @@ class AutoCompression:
return tmp_dir
def compress(self):
self.tmp_dir = create_tmp_dir(self.final_dir)
self.tmp_dir = self.create_tmp_dir(self.final_dir)
for strategy_idx, (
strategy,
config) in enumerate(zip(self._strategy, self._config)):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册