未验证 提交 e9011d3e 编写于 作者: C Chang Xu 提交者: GitHub

Support Load scales from quant model (#1790)

上级 cafd9859
...@@ -153,8 +153,8 @@ class PieceWiseSearch(): ...@@ -153,8 +153,8 @@ class PieceWiseSearch():
else: else:
smooth_scale_out += final_smooth_scale smooth_scale_out += final_smooth_scale
if cur_loss < global_loss: if calibration_loss < global_loss:
global_loss = cur_loss global_loss = calibration_loss
best_scale = smooth_scale_out best_scale = smooth_scale_out
if self.search_piece: if self.search_piece:
print('Find Better K-Piece {}'.format(k_piece)) print('Find Better K-Piece {}'.format(k_piece))
......
...@@ -69,6 +69,9 @@ class AbsmaxObserverLayer(UniformObserver): ...@@ -69,6 +69,9 @@ class AbsmaxObserverLayer(UniformObserver):
def cal_thresholds(self): def cal_thresholds(self):
""" Compute thresholds for MAX function. """ Compute thresholds for MAX function.
""" """
if self._scale is not None:
self._zero_point = 0
return
self._scale, self._zero_point = self.cal_scales_zero_points() self._scale, self._zero_point = self.cal_scales_zero_points()
def min_value(self) -> float: def min_value(self) -> float:
......
...@@ -81,6 +81,7 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver): ...@@ -81,6 +81,7 @@ class AbsMaxChannelWiseWeightObserverLayer(ChannelWiseObserver):
def cal_thresholds(self): def cal_thresholds(self):
""" Compute thresholds for MAX function. """ Compute thresholds for MAX function.
""" """
if self._scale is None:
self._scale = self._max self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale) self._zero_point = paddle.zeros_like(self._scale)
......
...@@ -70,6 +70,9 @@ class AVGObserverLayer(UniformObserver): ...@@ -70,6 +70,9 @@ class AVGObserverLayer(UniformObserver):
def cal_thresholds(self): def cal_thresholds(self):
""" Compute thresholds for MAX function. """ Compute thresholds for MAX function.
""" """
if self._scale is not None:
self._zero_point = 0
return
self._min, self._max = self._avg_min, paddle.mean( self._min, self._max = self._avg_min, paddle.mean(
paddle.to_tensor(self._avg_list)) paddle.to_tensor(self._avg_list))
self._scale, self._zero_point = self.cal_scales_zero_points() self._scale, self._zero_point = self.cal_scales_zero_points()
......
...@@ -85,6 +85,9 @@ class EMDObserverLayer(UniformObserver): ...@@ -85,6 +85,9 @@ class EMDObserverLayer(UniformObserver):
def cal_thresholds(self): def cal_thresholds(self):
""" Compute thresholds for MAX function. """ Compute thresholds for MAX function.
""" """
if self._scale is not None:
self._zero_point = 0
return
self._min, self._max = self._emd_min, self._emd_max self._min, self._max = self._emd_min, self._emd_max
self._scale, self._zero_point = self.cal_scales_zero_points() self._scale, self._zero_point = self.cal_scales_zero_points()
......
...@@ -82,6 +82,9 @@ class MSEObserverLayer(UniformObserver): ...@@ -82,6 +82,9 @@ class MSEObserverLayer(UniformObserver):
def cal_thresholds(self): def cal_thresholds(self):
""" Compute thresholds for MAX function. """ Compute thresholds for MAX function.
""" """
if self._scale is not None:
self._zero_point = 0
return
self._min, self._max = self._mse_min, self._mse_max self._min, self._max = self._mse_min, self._mse_max
self._scale, self._zero_point = self.cal_scales_zero_points() self._scale, self._zero_point = self.cal_scales_zero_points()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册