Unverified commit d0f3296b authored by Jacek Czaja, committed by GitHub

Fix to #38693 (minimal UT) (#41026)

* Add UT

- Added missing data_layout

- Added missing conversions

- NDHWC added

- NDHWC support in data_transform

- another fix

- conditional change

- fix

- fix

- fix

- fix

- fix

- fix

- fix to hack

- compilation fix

- fix to automatic merge

* - reduced UT

* - fix

* - lint

* - fix to lint
Parent 92d8d0bc
@@ -59,6 +59,10 @@ inline MKLDNNMemoryFormat ToMKLDNNFormat(const DataLayout& layout) {
       return MKLDNNMemoryFormat::nhwc;
     case DataLayout::kNCHW:
       return MKLDNNMemoryFormat::nchw;
+    case DataLayout::kNCDHW:
+      return MKLDNNMemoryFormat::ncdhw;
+    case DataLayout::kNDHWC:
+      return MKLDNNMemoryFormat::ndhwc;
     default:
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Fail to convert layout %s to MKLDNN format.",
@@ -72,6 +76,10 @@ inline DataLayout ToPaddleLayout(const MKLDNNMemoryFormat& format) {
       return DataLayout::kNHWC;
     case MKLDNNMemoryFormat::nchw:
       return DataLayout::kNCHW;
+    case MKLDNNMemoryFormat::ncdhw:
+      return DataLayout::kNCDHW;
+    case MKLDNNMemoryFormat::ndhwc:
+      return DataLayout::kNDHWC;
     default:
       PADDLE_THROW(platform::errors::InvalidArgument(
           "Fail to convert MKLDNN format to paddle layout."));
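With the two new cases above, the layout helpers round-trip the 5-D layouts in both directions. A minimal sketch of the expected behavior (assuming the helpers live in the paddle::platform namespace, as the platform::errors calls suggest):

    // Hedged sketch: round-trip a 5-D layout through both converters.
    using paddle::framework::DataLayout;
    auto fmt = paddle::platform::ToMKLDNNFormat(DataLayout::kNDHWC);
    // fmt == MKLDNNMemoryFormat::ndhwc
    auto layout = paddle::platform::ToPaddleLayout(fmt);
    // layout == DataLayout::kNDHWC, so no information is lost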
@@ -63,7 +63,7 @@ void TransformData(const OpKernelType &expected_kernel_type,
       out.ShareDataWith(input_tensor);
       // For NHWC data we need reshape of tensors as MKL-DNN
       // is expecting NHWC dims description order
-      if (lin == DataLayout::kNHWC) {
+      if (lin == DataLayout::kNHWC || lin == DataLayout::kNDHWC) {
         platform::MatchShapeToLayout(&out, lin, lout);
         // We register only NHWC assuming that model is consistent e.g. either
         // NHWC or NCHW
@@ -579,7 +579,8 @@ std::vector<int> Tensor::shape() const {
     // be done. Similarly for dim==1 when you have just one possible
     // combination.
     if (tensor->dims().size() < 3) return phi::vectorize<int>(tensor->dims());
-    if (out_layout == paddle::framework::DataLayout::kNHWC) {
+    if (out_layout == paddle::framework::DataLayout::kNHWC ||
+        out_layout == paddle::framework::DataLayout::kNDHWC) {
       auto dims = phi::vectorize<int>(tensor->dims());
       std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
       return dims;
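The std::rotate in Tensor::shape() moves the channel dimension from slot 1 to the back, so a canonical [N, C, D, H, W] dims vector is reported to the user as [N, D, H, W, C]. A self-contained sketch of just that call:

    #include <algorithm>
    #include <cassert>
    #include <vector>

    int main() {
      std::vector<int> dims = {2, 3, 8, 16, 16};  // N, C, D, H, W
      // The element at index 2 becomes the new index 1; channel goes last.
      std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
      assert((dims == std::vector<int>{2, 8, 16, 16, 3}));  // N, D, H, W, C
      return 0;
    }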
@@ -20,6 +20,7 @@ limitations under the License. */
 #include <string>
 #include <utility>
 #include <vector>
+#include "dnnl.hpp"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/place.h"
@@ -102,20 +103,22 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
   switch (from) {
     case framework::DataLayout::kMKLDNN:
-      if (to == framework::DataLayout::kNHWC) {
+      if ((to == framework::DataLayout::kNHWC) ||
+          (to == framework::DataLayout::kNDHWC)) {
         auto dims = phi::vectorize<int>(tensor_in->dims());
         std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
         tensor_in->Resize(phi::make_ddim(dims));
-        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC output_shape"
+        VLOG(3) << "Rotating Shape from: kMKLDNN to: kNHWC/kNDHWC output_shape"
                 << print_dims(dims);
       }
       break;
     case framework::DataLayout::kNHWC:
+    case framework::DataLayout::kNDHWC:
       if (to == framework::DataLayout::kMKLDNN) {
         auto dims = phi::vectorize<int>(tensor_in->dims());
         std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
         tensor_in->Resize(phi::make_ddim(dims));
-        VLOG(3) << "Rotating Shape from: kNHWC to: kMKLDNN output_shape"
+        VLOG(3) << "Rotating Shape from: kNHWC/kNDHWC to: kMKLDNN output_shape"
                 << print_dims(dims);
       }
       break;
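The kNHWC/kNDHWC -> kMKLDNN branch applies the inverse rotation: the trailing channel dimension moves back to slot 1, restoring the canonical order that the oneDNN descriptors describe. A sketch of that step in isolation:

    std::vector<int> dims = {2, 8, 16, 16, 3};  // N, D, H, W, C (user order)
    // The last element becomes the new index 1; the middle dims shift right.
    std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
    // dims is now {2, 3, 8, 16, 16}: N, C, D, H, W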
@@ -279,7 +282,12 @@ inline dnnl::memory::format_tag GetMKLDNNFormat(dnnl::memory::desc mem_desc) {
         return dnnl::memory::format_tag::acdeb;
       }
     } else if (inner_nblks == 1) {
-      if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
+      if (inner_blks[0] == 4 && inner_idxs[0] == 1) {
+        if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
+            strides[2] >= strides[3] && strides[3] >= strides[4]) {
+          return dnnl::memory::format_tag::aBcde4b;
+        }
+      } else if (inner_blks[0] == 8 && inner_idxs[0] == 0) {
         if (strides[0] >= strides[2] && strides[2] >= strides[3] &&
             strides[3] >= strides[4] && strides[4] >= strides[1]) {
           return dnnl::memory::format_tag::Acdeb8a;
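The new branch recognizes aBcde4b, a 5-D blocked layout whose second dimension (oneDNN letter "b") is split into inner blocks of 4. A hedged sketch of how such a descriptor reports the fields this function matches on, assuming the oneDNN 2.x API in which dnnl::memory::desc exposes its C struct as the public .data member:

    #include <iostream>
    #include "dnnl.hpp"

    int main() {
      // A 5-D descriptor explicitly tagged as aBcde4b.
      dnnl::memory::desc md({2, 8, 4, 4, 4}, dnnl::memory::data_type::f32,
                            dnnl::memory::format_tag::aBcde4b);
      const auto& blk = md.data.format_desc.blocking;
      std::cout << blk.inner_nblks << '\n';    // 1: one blocked dimension
      std::cout << blk.inner_blks[0] << '\n';  // 4: block size
      std::cout << blk.inner_idxs[0] << '\n';  // 1: dimension "b" is blocked
      return 0;
    }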
@@ -26,13 +26,13 @@ enum class DataLayout {
   ANY = UNDEFINED,
   NHWC,
   NCHW,
+  NCDHW,
+  NDHWC,
   MKLDNN,
   SPARSE_COO,
   SPARSE_CSR,
   PSTRING_UNION,
   NUM_DATA_LAYOUTS,
-  NDHWC,
-  NCDHW,
   // See Note [ Why we need ALL in basic kernel key member? ]
   ALL_LAYOUT = UNDEFINED,
   // Note: Unify phi DataLayout and fluid::framework::DataLayout,
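The substance of this hunk is the reordering: NUM_DATA_LAYOUTS reads as a counting sentinel, so the 5-D entries had to move above it to be included in the count (previously they sat after it and were effectively not counted). A minimal sketch of the idiom, with hypothetical names:

    // Every real layout must precede the sentinel to be counted.
    enum class Layout { NHWC, NCHW, NCDHW, NDHWC, NUM_LAYOUTS };
    static_assert(static_cast<int>(Layout::NUM_LAYOUTS) == 4,
                  "the sentinel equals the number of layouts before it");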
@@ -128,6 +128,7 @@ if (WITH_MKLDNN)
     set_tests_properties(test_mkldnn_depthwise_conv_pass PROPERTIES TIMEOUT 120)
     set_tests_properties(test_mkldnn_reshape_transpose_matmul_fuse_pass PROPERTIES TIMEOUT 100)
     set_tests_properties(test_mkldnn_mish_op PROPERTIES TIMEOUT 300)
+    set_tests_properties(test_mkldnn_conv3d_op PROPERTIES TIMEOUT 300)
     set_tests_properties(test_mkldnn_prelu_op PROPERTIES TIMEOUT 300)
     set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
     set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 250)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import MkldnnAutoScanTest, SkipReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume
import hypothesis.strategies as st

class TestMkldnnConv3dOp(MkldnnAutoScanTest):
    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        return True

    def sample_program_configs(self, *args, **kwargs):
        def generate_input(*args, **kwargs):
            if kwargs["data_format"] == "NCDHW":
                return np.random.random(
                    [kwargs["batch_size"], 48, 64, 32, 64]).astype(np.float32)
            else:
                return np.random.random(
                    [kwargs["batch_size"], 64, 32, 64, 48]).astype(np.float32)

        def generate_weight(*args, **kwargs):
            return np.random.random(
                [16, int(48 / kwargs["groups"]), 3, 3, 3]).astype(np.float32)

        conv3d_op = OpConfig(
            type="conv3d",
            inputs={"Input": ["input_data"],
                    "Filter": ["conv_weight"]},
            outputs={"Output": ["conv_output"]},
            attrs={
                "data_format": kwargs["data_format"],
                "dilations": kwargs["dilations"],
                "padding_algorithm": kwargs["padding_algorithm"],
                "groups": kwargs["groups"],
                "paddings": kwargs["paddings"],
                "strides": kwargs["strides"],
                "is_test": True
            })

        program_config = ProgramConfig(
            ops=[conv3d_op],
            weights={
                "conv_weight":
                TensorConfig(data_gen=partial(generate_weight, *args, **kwargs))
            },
            inputs={
                "input_data":
                TensorConfig(data_gen=partial(generate_input, *args, **kwargs))
            },
            outputs=["conv_output"])

        yield program_config

    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_mkldnn=True)
        yield config, (1e-5, 1e-5)

    @given(
        data_format=st.sampled_from(["NCDHW", "NDHWC"]),
        dilations=st.sampled_from([[1, 2, 1]]),
        padding_algorithm=st.sampled_from(["EXPLICIT"]),
        groups=st.sampled_from([2]),
        paddings=st.sampled_from([[0, 3, 2]]),
        strides=st.sampled_from([[1, 2, 1]]),
        batch_size=st.integers(
            min_value=1, max_value=4), )
    def test(self, *args, **kwargs):
        self.run_test(*args, **kwargs)


if __name__ == "__main__":
    unittest.main()
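Once Paddle is built with WITH_MKLDNN=ON, the test registered in the CMake change above should be runnable either as ctest -R test_mkldnn_conv3d_op or by executing the file directly with Python.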