未验证 提交 52d6b306 编写于 作者: W wenbin 提交者: GitHub

mkldnn layout issue fix (#39422)

* mkldnn conv fix

* definetion
上级 c47d6729
...@@ -429,6 +429,33 @@ std::vector<int> Tensor::shape() const { ...@@ -429,6 +429,33 @@ std::vector<int> Tensor::shape() const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
tensor_, paddle::platform::errors::PreconditionNotMet( tensor_, paddle::platform::errors::PreconditionNotMet(
"Not found tensor called %s in the scope", name_)); "Not found tensor called %s in the scope", name_));
// mkldnn may does layout transform internally, so need to reorder before
// return
#ifdef PADDLE_WITH_MKLDNN
if (tensor->layout() == paddle::framework::DataLayout::kMKLDNN) {
paddle::framework::DataLayout out_layout =
paddle::platform::MKLDNNDeviceContext::tls()
.get_cur_paddle_data_layout();
// Set default as NCHW in case not specified
out_layout = out_layout == paddle::framework::DataLayout::kAnyLayout
? paddle::framework::DataLayout::kNCHW
: out_layout;
// In these data layouts, channel dimension is either on 2nd position: nChw
// or
// at last nhwC, so for dim==2 these layouts are the same and nothing should
// be done. Similarly for dim==1 when you have just one possible
// combination.
if (tensor->dims().size() < 3)
return paddle::framework::vectorize<int>(tensor->dims());
if (out_layout == paddle::framework::DataLayout::kNHWC) {
auto dims = paddle::framework::vectorize<int>(tensor->dims());
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
return dims;
} else {
return paddle::framework::vectorize<int>(tensor->dims());
}
}
#endif
return paddle::framework::vectorize<int>(tensor->dims()); return paddle::framework::vectorize<int>(tensor->dims());
} }
......
...@@ -183,7 +183,9 @@ class TestConvBnFusePass(PassAutoScanTest): ...@@ -183,7 +183,9 @@ class TestConvBnFusePass(PassAutoScanTest):
def add_ignore_pass_case(self): def add_ignore_pass_case(self):
def teller1(program_config, predictor_config): def teller1(program_config, predictor_config):
if program_config.ops[0].attrs['data_format'] == "NHWC": if program_config.ops[0].attrs[
'data_format'] == "NHWC" and not predictor_config.mkldnn_enabled(
):
return True return True
return False return False
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册