未验证 提交 1bea8e18 编写于 作者: C ceci3 提交者: GitHub

fix ptq-hpo on the windows platform (#1032)

* fix install

* update
Co-authored-by: NGuanghua Yu <742925032@qq.com>
上级 dd85e83a
...@@ -18,9 +18,11 @@ import sys ...@@ -18,9 +18,11 @@ import sys
import numpy as np import numpy as np
import inspect import inspect
from collections import namedtuple, Iterable from collections import namedtuple, Iterable
import platform
import paddle import paddle
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
from ..quant.quant_post_hpo import quant_post_hpo if platform.system().lower() == 'linux':
from ..quant.quant_post_hpo import quant_post_hpo
from ..quant.quanter import convert from ..quant.quanter import convert
from ..common.recover_program import recover_inference_program from ..common.recover_program import recover_inference_program
from ..common import get_logger from ..common import get_logger
...@@ -231,6 +233,10 @@ class AutoCompression: ...@@ -231,6 +233,10 @@ class AutoCompression:
def compress(self): def compress(self):
### start compress, including train/eval model ### start compress, including train/eval model
if self._strategy == 'ptq_hpo': if self._strategy == 'ptq_hpo':
if platform.system().lower() != 'linux':
raise NotImplementedError(
"post-quant-hpo is not support in system other than linux")
quant_post_hpo( quant_post_hpo(
self._exe, self._exe,
self._places, self._places,
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import platform
import logging import logging
import paddle import paddle
...@@ -32,7 +33,12 @@ try: ...@@ -32,7 +33,12 @@ try:
from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic from .quanter import quant_aware, convert, quant_post_static, quant_post_dynamic
from .quanter import quant_post, quant_post_only_weight from .quanter import quant_post, quant_post_only_weight
from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel from .quant_aware_with_infermodel import quant_aware_with_infermodel, export_quant_infermodel
if platform.system().lower() == 'linux':
from .quant_post_hpo import quant_post_hpo from .quant_post_hpo import quant_post_hpo
else:
_logger.warning(
"post-quant-hpo is not support in system other than linux")
except Exception as e: except Exception as e:
_logger.warning(e) _logger.warning(e)
_logger.warning( _logger.warning(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册