未验证 提交 01c4ad80 编写于 作者: P piotrekobi 提交者: GitHub

Fix for ernie3.0 int8 (#43992)

* Fix for ernie3.0 int8

* Move changes above comment
上级 e0d7d790
......@@ -170,6 +170,9 @@ class FCPrimitiveFactory {
// In case of 2 dims, we set the only possible format, nc
if (dim_num == 2) {
out->set_format(MKLDNNMemoryFormat::nc);
out->set_mem_desc({phi::vectorize(out->dims()),
platform::MKLDNNGetDataType<T_out>(),
out->format()});
// In case of 3 dims, we generate a format that is based on number
// of output dims and the layout of input format (nchw or nhwc).
} else if (dim_num == 3) {
......@@ -185,9 +188,6 @@ class FCPrimitiveFactory {
} else {
out->set_format(in_format);
}
out->set_mem_desc({phi::vectorize(out->dims()),
platform::MKLDNNGetDataType<T_out>(),
out->format()});
}
void UpdateDataPointers(const ExecutionContext& ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册