提交 18a5d307 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] Conv2d and Conv2d transpose MKL-DNN NHWC support (#21466)

上级 96a446f6
...@@ -113,7 +113,6 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) { ...@@ -113,7 +113,6 @@ void* GetDataFromTensor(const Tensor& tensor, mkldnn::memory::data_type type) {
PADDLE_THROW("wrong mkldnn type provided"); PADDLE_THROW("wrong mkldnn type provided");
} }
} }
#endif
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const OpKernelType& expected_kernel_type,
...@@ -127,14 +126,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, ...@@ -127,14 +126,11 @@ void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
"TransDataLayoutFromMKLDNN only supports transform from MKLDNN to " "TransDataLayoutFromMKLDNN only supports transform from MKLDNN to "
"non-MKLDNN"); "non-MKLDNN");
#ifdef PADDLE_WITH_MKLDNN
innerTransDataLayoutFromMKLDNN(in_layout, innerTransDataLayoutFromMKLDNN(in_layout,
paddle::platform::get_cur_paddle_data_layout(), paddle::platform::get_cur_paddle_data_layout(),
in, out, place); in, out, place);
#endif
} }
#ifdef PADDLE_WITH_MKLDNN
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out, const Tensor& in, Tensor* out,
platform::Place place) { platform::Place place) {
......
...@@ -69,11 +69,11 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) { ...@@ -69,11 +69,11 @@ inline MKLDNNDataType ToMKLDNNDataType(proto::VarType::Type type) {
void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
const Tensor& in, Tensor* out, const Tensor& in, Tensor* out,
platform::Place place); platform::Place place);
#endif
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var, void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
const OpKernelType& expected_kernel_type, const OpKernelType& expected_kernel_type,
const Tensor& in, Tensor* out); const Tensor& in, Tensor* out);
#endif
std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to); std::vector<int> GetAxis(const DataLayout& from, const DataLayout& to);
......
...@@ -43,13 +43,13 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -43,13 +43,13 @@ void TransformData(const OpKernelType &expected_kernel_type,
// do layout transform // do layout transform
if (NeedTransformLayout(lout, lin)) { if (NeedTransformLayout(lout, lin)) {
#ifdef PADDLE_WITH_MKLDNN
if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) { if (lin == DataLayout::kMKLDNN || lout == DataLayout::kMKLDNN) {
PADDLE_ENFORCE( PADDLE_ENFORCE(
!(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN), !(lin == DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN),
"No layout transform needed between two MKLDNN OPKernels"); "No layout transform needed between two MKLDNN OPKernels");
if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) { if (lin != DataLayout::kMKLDNN && lout == DataLayout::kMKLDNN) {
#ifdef PADDLE_WITH_MKLDNN
// Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel // Case1 - transform from Non-MKLDNN OPKernel to MKLDNN OPKernel
// Just set layout/format. No real transform occur // Just set layout/format. No real transform occur
...@@ -67,7 +67,6 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -67,7 +67,6 @@ void TransformData(const OpKernelType &expected_kernel_type,
} }
out.set_layout(DataLayout::kMKLDNN); out.set_layout(DataLayout::kMKLDNN);
out.set_format(out_format); out.set_format(out_format);
#endif
} else { } else {
// Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel // Case2 - transfrom from MKLDNN OPKernel to Non-MKLDNN OPKernel
// Do transform via MKLDNN lib // Do transform via MKLDNN lib
...@@ -78,6 +77,10 @@ void TransformData(const OpKernelType &expected_kernel_type, ...@@ -78,6 +77,10 @@ void TransformData(const OpKernelType &expected_kernel_type,
// Case3 - transfrom between Non-MKLDNN OPKernels // Case3 - transfrom between Non-MKLDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out); TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
} }
#else
// Case3 - transfrom between Non-MKLDNN OPKernels
TransDataLayout(kernel_type_for_var, expected_kernel_type, in, &out);
#endif
transformed = true; transformed = true;
PassTensorData(&out, &in); PassTensorData(&out, &in);
} }
......
...@@ -48,7 +48,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -48,7 +48,11 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
int groups = ctx->Attrs().Get<int>("groups"); int groups = ctx->Attrs().Get<int>("groups");
std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations"); std::vector<int> dilations = ctx->Attrs().Get<std::vector<int>>("dilations");
const std::string data_format = ctx->Attrs().Get<std::string>("data_format"); const std::string data_format = ctx->Attrs().Get<std::string>("data_format");
const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
// MKL-DNN Kernels are using NCHW order of dims description
// so we ignore data_format consideration for MKL-DNN kernel
const bool channel_last = (this->IsMKLDNNType() == false) &&
(data_format == "NHWC" || data_format == "NDHWC");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in_dims.size() == 4 || in_dims.size() == 5, true, in_dims.size() == 4 || in_dims.size() == 5, true,
...@@ -151,15 +155,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -151,15 +155,6 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain && if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(data_format, "NHWC",
platform::errors::Unimplemented(
"Conv MKLDNN does not support NHWC data format yet"));
PADDLE_ENFORCE_NE(
data_format, "NDHWC",
platform::errors::Unimplemented(
"Conv MKLDNN does not support NDHWC data format yet"));
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
customized_type_value = customized_type_value =
...@@ -197,6 +192,32 @@ framework::OpKernelType ConvOp::GetExpectedKernelType( ...@@ -197,6 +192,32 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
return type; return type;
} }
framework::OpKernelType ConvOp::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(),
framework::StringToDataLayout(data_format));
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
void Conv2DOpMaker::Make() { void Conv2DOpMaker::Make() {
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
......
...@@ -258,6 +258,10 @@ class ConvOp : public framework::OperatorWithKernel { ...@@ -258,6 +258,10 @@ class ConvOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
}; };
class ConvOpGrad : public framework::OperatorWithKernel { class ConvOpGrad : public framework::OperatorWithKernel {
......
...@@ -48,8 +48,9 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -48,8 +48,9 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
ctx->Attrs().Get<std::string>("padding_algorithm"); ctx->Attrs().Get<std::string>("padding_algorithm");
const std::string data_layout_str = const std::string data_layout_str =
ctx->Attrs().Get<std::string>("data_format"); ctx->Attrs().Get<std::string>("data_format");
const framework::DataLayout data_layout = const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str); this->IsMKLDNNType() ? DataLayout::kNCHW
: framework::StringToDataLayout(data_layout_str);
PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true, PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
"ShapeError: input of Op(conv_transpose) should be 4-D or " "ShapeError: input of Op(conv_transpose) should be 4-D or "
...@@ -145,11 +146,6 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -145,11 +146,6 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain && if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) { platform::CanMKLDNNBeUsed(ctx)) {
// TODO(jczaja): Add support for NHWC
const std::string data_format = ctx.Attr<std::string>("data_format");
PADDLE_ENFORCE_NE(
data_format, "NHWC",
"Conv Transpose MKLDNN does not support NHWC data format yet");
library_ = framework::LibraryType::kMKLDNN; library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
...@@ -160,6 +156,32 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType( ...@@ -160,6 +156,32 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
layout_, library_); layout_, library_);
} }
framework::OpKernelType ConvTransposeOp::GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const {
#ifdef PADDLE_WITH_MKLDNN
// Only input require reshaping, weights and
// bias are having shape in NCHW order
if ((var_name == "Input") &&
(expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
(tensor.layout() != framework::DataLayout::kMKLDNN)) {
auto attrs = Attrs();
auto ar = paddle::framework::AttrReader(attrs);
const std::string data_format = ar.Get<std::string>("data_format");
auto dl = framework::StringToDataLayout(data_format);
// Some models may have intentionally set "AnyLayout" for pool
// op. Treat this as NCHW (default data_format value)
if (dl != framework::DataLayout::kAnyLayout) {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(),
framework::StringToDataLayout(data_format));
}
}
#endif
return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout());
}
void Conv2DTransposeOpMaker::Make() { void Conv2DTransposeOpMaker::Make() {
AddAttr<bool>("is_test", AddAttr<bool>("is_test",
"(bool, default false) Set to true for inference only, false " "(bool, default false) Set to true for inference only, false "
......
...@@ -98,6 +98,10 @@ class ConvTransposeOp : public framework::OperatorWithKernel { ...@@ -98,6 +98,10 @@ class ConvTransposeOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
}; };
class ConvTransposeOpGrad : public framework::OperatorWithKernel { class ConvTransposeOpGrad : public framework::OperatorWithKernel {
......
...@@ -220,9 +220,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -220,9 +220,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
// TODO(jczaja): This is workaround to make grad op UT's numerical
// gradient computation proper as this op is called directly without
// fetch op following it , so numercial grad is computed (in python)
// using block formats which will give wrong results
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format = auto chosen_memory_format =
platform::data_format_to_memory_format(data_format); is_test ? MKLDNNMemoryFormat::any
: platform::data_format_to_memory_format(data_format);
weights_format = MKLDNNMemoryFormat::any; weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output // Check the format for user's special output
...@@ -519,9 +524,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -519,9 +524,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
std::string data_format = ctx.Attr<std::string>("data_format"); auto chosen_memory_format = MKLDNNMemoryFormat::any;
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
std::vector<int> bias_tz; std::vector<int> bias_tz;
...@@ -772,18 +775,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -772,18 +775,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (conv backward in this case) choose * ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
std::string data_format = ctx.Attr<std::string>("data_format"); auto chosen_memory_format = MKLDNNMemoryFormat::any;
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
weights_format = MKLDNNMemoryFormat::any; weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
}
}
auto src_md = platform::MKLDNNMemDesc( auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format); src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......
...@@ -156,9 +156,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -156,9 +156,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
* ('any') which lets a primitive (convolution in this case) choose * ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
std::string data_format = ctx.Attr<std::string>("data_format"); auto chosen_memory_format = MKLDNNMemoryFormat::any;
auto chosen_memory_format =
platform::data_format_to_memory_format(data_format);
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation"); std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
float fuse_alpha = ctx.Attr<float>("fuse_alpha"); float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta"); float fuse_beta = ctx.Attr<float>("fuse_beta");
......
...@@ -35,7 +35,7 @@ class TestConv2dInt8Op(TestConv2dOp): ...@@ -35,7 +35,7 @@ class TestConv2dInt8Op(TestConv2dOp):
self.exhaustive_search = False self.exhaustive_search = False
self.use_cuda = False self.use_cuda = False
self.use_mkldnn = False self.use_mkldnn = False
self.data_format = "AnyLayout" self.data_format = "NCHW"
self.weighttype = np.float32 self.weighttype = np.float32
self.use_mkldnn = True self.use_mkldnn = True
self.init_group() self.init_group()
......
...@@ -197,5 +197,38 @@ class TestConv2dOp_Valid_MKLDNN(TestConv2dOp_AsyPadding_MKLDNN): ...@@ -197,5 +197,38 @@ class TestConv2dOp_Valid_MKLDNN(TestConv2dOp_AsyPadding_MKLDNN):
self.padding_algorithm = "VALID" self.padding_algorithm = "VALID"
class TestConv2dOp_Valid_NHWC_MKLDNN(TestConv2dOp_Valid_MKLDNN):
def init_data_format(self):
self.data_format = "NHWC"
def init_test_case_2(self):
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
#TODO(jczaja): Enable once GRAD op is adjusted
def test_check_grad(self):
pass
#TODO(jczaja): Enable once GRAD op is adjusted
def test_check_grad_no_filter(self):
pass
#TODO(jczaja): Enable once GRAD op is adjusted
def test_check_grad_no_input(self):
pass
class TestConv2dOp_Same_NHWC_MKLDNN(TestConv2dOp_Valid_NHWC_MKLDNN):
def init_paddings(self):
self.pad = [0, 0]
self.padding_algorithm = "SAME"
class TestConv2dOp_AsyPadding_NHWC_MKLDNN(TestConv2dOp_Valid_NHWC_MKLDNN):
def init_paddings(self):
self.pad = [0, 0, 1, 2]
self.padding_algorithm = "EXPLICIT"
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -126,3 +126,11 @@ class TestMKLDNNWithValidPad(TestConv2dTransposeMKLDNNOp): ...@@ -126,3 +126,11 @@ class TestMKLDNNWithValidPad(TestConv2dTransposeMKLDNNOp):
TestConv2dTransposeMKLDNNOp.init_test_case(self) TestConv2dTransposeMKLDNNOp.init_test_case(self)
self.pad = [1, 1] self.pad = [1, 1]
self.padding_algorithm = "VALID" self.padding_algorithm = "VALID"
class TestMKLDNNWithValidPad_NHWC(TestMKLDNNWithValidPad):
def init_test_case(self):
super(TestMKLDNNWithValidPad, self).init_test_case()
self.data_format = "NHWC"
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册