diff --git a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
index 4819bb300dce7a4b5eeda0d6f872749c0d68ae6d..71b9cc19398663048a59a4dfd2e2b04e1cae10b9 100644
--- a/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc
@@ -207,6 +207,14 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory {
     int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
     auto scale_y = ctx.Attr<std::vector<float>>("scale_y");
 
+    // TODO(intel-minghui) : Remove the restriction that only supports Input(Y)
+    // as weights
+    bool enforce = std::is_same<float, YT>::value;
+    PADDLE_ENFORCE(
+        enforce == true,
+        "Input(Y) supposed to be fp32 data type since only fp32 data type is "
+        "supported in the current design of MKLDNN INT8.");
+
     auto x_matrix =
         this->template UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx);
     auto y_matrix =
diff --git a/paddle/fluid/operators/mul_op.cc b/paddle/fluid/operators/mul_op.cc
index ebb88fe28290681febb7389415b4b08a46352edc..0823ea8f4d3a95be6637e71bff818dfe9490ed1b 100644
--- a/paddle/fluid/operators/mul_op.cc
+++ b/paddle/fluid/operators/mul_op.cc
@@ -144,13 +144,17 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
 )DOC")
         .SetDefault(1)
         .EqualGreaterThan(1);
-    AddAttr<float>("scale_x",
-                   "scale_x to used for int8 input data x."
-                   "Only used with MKL-DNN INT8")
+    AddAttr<float>(
+        "scale_x",
+        "scale_x to be used for int8 mul input data x. scale_x has the"
+        "same purpose as scale_in in OPs that support quantization."
+        "Only to be used with MKL-DNN INT8")
         .SetDefault(1.0f);
-    AddAttr<std::vector<float>>("scale_y",
-                                "scale_y to used for int8 input data y."
-                                "Only used with MKL-DNN INT8")
+    AddAttr<std::vector<float>>(
+        "scale_y",
+        "scale_y to be used for int8 mul input data y. scale_y has the"
+        "same purpose as scale_weights in OPs that support quantization."
+        "Only to be used with MKL-DNN INT8")
         .SetDefault({1.0f});
     AddAttr<float>("scale_out",
                    "scale_out to be used for int8 output data."