未验证 提交 fbb3a34f 编写于 作者: Chang Xu 提交者: GitHub

Make QDQ (quantize/dequantize) run in FP32 (#56059)

上级 73c70654
...@@ -59,22 +59,22 @@ class LinearQuanter(Layer): ...@@ -59,22 +59,22 @@ class LinearQuanter(Layer):
def forward(self, input): def forward(self, input):
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.quantize_linear( return _C_ops.quantize_linear(
input, input.cast('float32'),
self._scales.cast(input.dtype), self._scales,
self._zero_point.cast(input.dtype), self._zero_point,
"quant_axis", "quant_axis",
self._quant_axis, self._quant_axis,
"bit_length", "bit_length",
self._bit_length, self._bit_length,
) ).cast(input.dtype)
else: else:
out = self._helper.create_variable_for_type_inference(input.dtype) out = self._helper.create_variable_for_type_inference(input.dtype)
self._helper.append_op( self._helper.append_op(
type='quantize_linear', type='quantize_linear',
inputs={ inputs={
'X': input, 'X': input,
'Scale': self._scales.cast(input.dtype), 'Scale': self._scales,
'ZeroPoint': self._zero_point.cast(input.dtype), 'ZeroPoint': self._zero_point,
}, },
outputs={'Y': out}, outputs={'Y': out},
attrs={ attrs={
...@@ -109,22 +109,22 @@ class LinearDequanter(Layer): ...@@ -109,22 +109,22 @@ class LinearDequanter(Layer):
def forward(self, input): def forward(self, input):
if in_dynamic_mode(): if in_dynamic_mode():
return _C_ops.dequantize_linear( return _C_ops.dequantize_linear(
input, input.cast('float32'),
self._scales.cast(input.dtype), self._scales,
self._zero_point.cast(input.dtype), self._zero_point,
"quant_axis", "quant_axis",
self._quant_axis, self._quant_axis,
"bit_length", "bit_length",
self._bit_length, self._bit_length,
) ).cast(input.dtype)
else: else:
out = self._helper.create_variable_for_type_inference(input.dtype) out = self._helper.create_variable_for_type_inference(input.dtype)
self._helper.append_op( self._helper.append_op(
type='dequantize_linear', type='dequantize_linear',
inputs={ inputs={
'X': input, 'X': input,
'Scale': self._scales.cast(input.dtype), 'Scale': self._scales,
'ZeroPoint': self._zero_point.cast(input.dtype), 'ZeroPoint': self._zero_point,
}, },
outputs={'Y': out}, outputs={'Y': out},
attrs={ attrs={
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册