Unverified commit 61a1f688, authored by Guanghua Yu, committed by GitHub

Support matmul in QAT and loading quantized models in PTQ (#47892)

Parent commit: cb812f40
......@@ -63,6 +63,7 @@ fake_quant_output_layers = [
paddle.nn.quant.subtract,
paddle.nn.quant.multiply,
paddle.nn.quant.divide,
paddle.nn.quant.matmul,
]
fake_quant_leaf_layers = [
......
......@@ -1939,6 +1939,15 @@ class AddQuantDequantPass:
op_node.op()._set_attr("activation_bits", self._quant_bits)
op_node.op()._set_attr("with_quant_attr", True)
arg_names = utils._get_op_input_var_names(op_node)
# If already quanted, skip it.
skip_quant = False
for arg_name in arg_names:
if "quantized.dequantized" in arg_name:
skip_quant = True
break
if skip_quant:
continue
for arg_name in arg_names:
in_node = graph._find_node_by_name(
op_node.inputs, arg_name
......@@ -2797,6 +2806,15 @@ class AddQuantDequantPassV2:
continue
arg_names = utils._get_op_input_var_names(op_node)
# If already quanted, skip it.
skip_quant = False
for arg_name in arg_names:
if "quantized.dequantized" in arg_name:
skip_quant = True
break
if skip_quant:
continue
for arg_name in arg_names:
in_node = graph._find_node_by_name(
op_node.inputs, arg_name
......
......@@ -21,6 +21,7 @@ from .functional_layers import reshape # noqa: F401
from .functional_layers import transpose # noqa: F401
from .functional_layers import concat # noqa: F401
from .functional_layers import flatten # noqa: F401
from .functional_layers import matmul # noqa: F401
from .quant_layers import QuantStub # noqa: F401
__all__ = []
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ...tensor import manipulation, math
from ...tensor import linalg, manipulation, math
from .. import Layer
__all__ = []
......@@ -85,3 +85,11 @@ class flatten(FloatFunctionalLayer):
def forward(self, x, start_axis=0, stop_axis=-1, name=None):
    """Flatten ``x`` between ``start_axis`` and ``stop_axis``.

    Thin delegation to the tensor-level ``manipulation.flatten``;
    keyword arguments make the pass-through explicit but otherwise
    identical to positional forwarding.
    """
    return manipulation.flatten(
        x, start_axis=start_axis, stop_axis=stop_axis, name=name
    )
class matmul(FloatFunctionalLayer):
    """Layer-style wrapper around the functional ``linalg.matmul``.

    Wrapping the op as a ``FloatFunctionalLayer`` lets the quantization
    tooling treat matrix multiplication as a layer (presumably so it can
    be registered among the fake-quantized output layers — confirm
    against the quant config that imports this module).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y, transpose_x=False, transpose_y=False, name=None):
        # Forward every argument by keyword; behavior is identical to
        # positional forwarding since the signatures line up one-to-one.
        return linalg.matmul(
            x, y, transpose_x=transpose_x, transpose_y=transpose_y, name=name
        )
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register to comment.