未验证 提交 aa18ae11 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Set FC input data format to ANY (#44023)

* Fc add any to input format

* Pre-commit changes
上级 3fd6f09f
......@@ -209,7 +209,7 @@ class FCPrimitiveFactory {
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto src_desc = CreateMemDescriptor<T_in>(input, input->format());
auto src_desc = CreateMemDescriptor<T_in>(input, MKLDNNMemoryFormat::any);
auto weight_dims = Get2DWeightDimsForDNNL(weights);
auto weights_desc =
CreateMemDescriptor<T_w>(weight_dims, MKLDNNMemoryFormat::any);
......@@ -236,7 +236,8 @@ class FCPrimitiveFactory {
auto input_dims = phi::vectorize(input->dims());
std::vector<int64_t> new_input_dims = {
input_dims[0] * input_dims[1], input_dims[2], 1};
auto src_desc = CreateMemDescriptor<T_in>(new_input_dims, input->format());
auto src_desc =
CreateMemDescriptor<T_in>(new_input_dims, MKLDNNMemoryFormat::any);
auto weight_dims = Get3DWeightDimsForDNNL(weights);
auto weights_desc =
......@@ -267,7 +268,7 @@ class FCPrimitiveFactory {
const Tensor* bias,
LoDTensor* output,
const ExecutionContext& ctx) {
auto src_desc = CreateMemDescriptor<T_in>(input, input->format());
auto src_desc = CreateMemDescriptor<T_in>(input, MKLDNNMemoryFormat::any);
// Since MKL-DNN doesn't support 4D column-major data formats in
// inner_product primitive, transpose the weights to be in
// row-major format
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册