未验证 提交 947c5fa7 编写于 作者: C Chang Xu 提交者: GitHub

Fit Quant Dtype (#56019)

上级 7472057c
...@@ -60,8 +60,8 @@ class LinearQuanter(Layer): ...@@ -60,8 +60,8 @@ class LinearQuanter(Layer):
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.quantize_linear( return _C_ops.quantize_linear(
input, input,
self._scales, self._scales.cast(input.dtype),
self._zero_point, self._zero_point.cast(input.dtype),
"quant_axis", "quant_axis",
self._quant_axis, self._quant_axis,
"bit_length", "bit_length",
...@@ -73,8 +73,8 @@ class LinearQuanter(Layer): ...@@ -73,8 +73,8 @@ class LinearQuanter(Layer):
type='quantize_linear', type='quantize_linear',
inputs={ inputs={
'X': input, 'X': input,
'Scale': self._scales, 'Scale': self._scales.cast(input.dtype),
'ZeroPoint': self._zero_point, 'ZeroPoint': self._zero_point.cast(input.dtype),
}, },
outputs={'Y': out}, outputs={'Y': out},
attrs={ attrs={
...@@ -110,8 +110,8 @@ class LinearDequanter(Layer): ...@@ -110,8 +110,8 @@ class LinearDequanter(Layer):
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.dequantize_linear( return _C_ops.dequantize_linear(
input, input,
self._scales, self._scales.cast(input.dtype),
self._zero_point, self._zero_point.cast(input.dtype),
"quant_axis", "quant_axis",
self._quant_axis, self._quant_axis,
"bit_length", "bit_length",
...@@ -123,8 +123,8 @@ class LinearDequanter(Layer): ...@@ -123,8 +123,8 @@ class LinearDequanter(Layer):
type='dequantize_linear', type='dequantize_linear',
inputs={ inputs={
'X': input, 'X': input,
'Scale': self._scales, 'Scale': self._scales.cast(input.dtype),
'ZeroPoint': self._zero_point, 'ZeroPoint': self._zero_point.cast(input.dtype),
}, },
outputs={'Y': out}, outputs={'Y': out},
attrs={ attrs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册