提交 a5c98630 编写于 作者: P Physher 提交者: Tao Luo

clarify MKLDNN INT8 Mul Op attributes (#18685)

上级 cff5e2c1
...@@ -207,6 +207,14 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> { ...@@ -207,6 +207,14 @@ class QuantMulPrimitiveFactory : public MulPrimitiveFactory<XT, YT, OT> {
int y_num_col_dims = ctx.Attr<int>("y_num_col_dims"); int y_num_col_dims = ctx.Attr<int>("y_num_col_dims");
auto scale_y = ctx.Attr<std::vector<float>>("scale_y"); auto scale_y = ctx.Attr<std::vector<float>>("scale_y");
// TODO(intel-minghui) : Remove the restriction that only supports Input(Y)
// as weights
bool enforce = std::is_same<YT, float>::value;
PADDLE_ENFORCE(
enforce == true,
"Input(Y) supposed to be fp32 data type since only fp32 data type is "
"supported in the current design of MKLDNN INT8.");
auto x_matrix = auto x_matrix =
this->template UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx); this->template UpdateDataFormat<XT>(x_input, x_num_col_dims, ctx);
auto y_matrix = auto y_matrix =
......
...@@ -144,13 +144,17 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -144,13 +144,17 @@ class MulOpMaker : public framework::OpProtoAndCheckerMaker {
)DOC") )DOC")
.SetDefault(1) .SetDefault(1)
.EqualGreaterThan(1); .EqualGreaterThan(1);
AddAttr<float>("scale_x", AddAttr<float>(
"scale_x to used for int8 input data x." "scale_x",
"Only used with MKL-DNN INT8") "scale_x to be used for int8 mul input data x. scale_x has the"
"same purpose as scale_in in OPs that support quantization."
"Only to be used with MKL-DNN INT8")
.SetDefault(1.0f); .SetDefault(1.0f);
AddAttr<std::vector<float>>("scale_y", AddAttr<std::vector<float>>(
"scale_y to used for int8 input data y." "scale_y",
"Only used with MKL-DNN INT8") "scale_y to be used for int8 mul input data y. scale_y has the"
"same purpose as scale_weights in OPs that support quantization."
"Only to be used with MKL-DNN INT8")
.SetDefault({1.0f}); .SetDefault({1.0f});
AddAttr<float>("scale_out", AddAttr<float>("scale_out",
"scale_out to be used for int8 output data." "scale_out to be used for int8 output data."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册