diff --git a/paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e7a528c452b8dfc8c97cd3816687405cbf931c19
--- /dev/null
+++ b/paddle/fluid/operators/mkldnn/pad3d_mkldnn_op.cc
@@ -0,0 +1,235 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/utils.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+/*
+Pad3D is done by using up to 7 reorders. The following example uses 2D data
+for simplicity, but it is straightforward to extend it to the 3D case.
+
+Let us consider the following example:
+
+          N  C  H  W               L  R  T  B
+X_dims = (1, 1, 3, 3), paddings = (1, 2, 3, 4) in order Left, Right, Top, Bottom
+
+We have to copy the X tensor into the Out tensor, but apart from that we have
+to fill the rest of the memory with padding. To avoid looping through the
+whole Out memory two times, only those parts of Out memory that won't store
+X's memory are filled with the pad value. That behavior is achieved by using
+oneDNN's submemory descriptors, which allow us to set offsets for each
+dimension and skip some parts of the memory. For the 2D case up to 5 reorders
+will be used in the Pad3D kernel (a reorder is skipped if its padding is 0).
+In the following example the number i means that this part of memory was
+filled by the i'th reorder. The 4'th reorder copies X memory into Out memory.
+i&j means that both the i'th and the j'th reorder set the padding there:
+
+                INDEX
+      | 0   1   2   3   4   5
+      |_______________________
+    0 |0&2  2   2   2  1&2 1&2
+    1 |0&2  2   2   2  1&2 1&2
+I   2 |0&2  2   2   2  1&2 1&2
+N   3 | 0   4   4   4   1   1
+D   4 | 0   4   4   4   1   1
+E   5 | 0   4   4   4   1   1
+X   6 |0&3  3   3   3  1&3 1&3
+    7 |0&3  3   3   3  1&3 1&3
+    8 |0&3  3   3   3  1&3 1&3
+    9 |0&3  3   3   3  1&3 1&3
+
+Since oneDNN's reorder cannot set the pad value to the memory by itself, we have
+to prefill Out's memory and use it as a temporary buffer, which is later copied
+into the rest of Out's memory. At the end, the last reorder is done, which
+copies X memory into Out memory.
+
+*/
+template <typename T>
+class PadMKLDNNKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    this->RunKernel(ctx);
+  }
+
+  // Reads paddings (attribute, or the "Paddings" tensor when present), resizes
+  // Out, prefills the pad value and runs the reorders described above.
+  void RunKernel(const framework::ExecutionContext& ctx) const {
+    const auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& onednn_engine = dev_ctx.GetEngine();
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+
+    auto* x = ctx.Input<Tensor>("X");
+    auto* out = ctx.Output<Tensor>("Out");
+    auto* paddings_tensor = ctx.Input<Tensor>("Paddings");
+    std::vector<int> paddings(ctx.Attr<std::vector<int>>("paddings"));
+    if (paddings_tensor) {
+      // assign() resizes the vector to the tensor's length; copying into the
+      // existing storage would write out of bounds whenever the "paddings"
+      // attribute holds fewer elements than the tensor.
+      const int* paddings_data = paddings_tensor->data<int>();
+      paddings.assign(paddings_data,
+                      paddings_data + paddings_tensor->numel());
+    }
+    // pad2d has paddings in order top, bottom, left, right, so we need
+    // to swap some of them to unify paddings between pad2d and pad3d
+    if (ctx.Type() == "pad2d") {
+      std::swap(paddings[0], paddings[2]);
+      std::swap(paddings[1], paddings[3]);
+    }
+
+    const std::string pad_attr_name =
+        ctx.Type() == "pad3d" ? "value" : "pad_value";
+    const T pad_value = static_cast<T>(ctx.Attr<float>(pad_attr_name));
+
+    std::vector<int64_t> x_tz = phi::vectorize(x->dims());
+    // due to the need of supporting NDHWC, inferring out shape
+    // must be done inside the kernel
+    std::vector<int64_t> out_tz(x_tz);
+
+    for (size_t i = 0; i < paddings.size() / 2; ++i) {
+      out_tz[out_tz.size() - 1 - i] += paddings[2 * i] + paddings[2 * i + 1];
+    }
+    out->Resize(phi::make_ddim(out_tz));
+
+    auto paddle_dtype = framework::TransToProtoVarType(x->dtype());
+
+    platform::ReorderMKLDNNHandler reorder_handler(
+        x_tz,
+        paddle_dtype,
+        framework::ToMKLDNNDataType(paddle_dtype),
+        onednn_engine);
+
+    auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
+        x->mem_desc(), platform::to_void_cast(x->data<T>()));
+    auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
+        out,
+        out_tz,
+        platform::GetPlainMKLDNNFormat(out_tz.size()),
+        ctx.GetPlace());
+
+    // to avoid allocating new temporary memory, Out's memory is used as a tmp
+    // buffer for storing a contiguous memory consisting of pad_value, which
+    // later is used as a SRC for reorders that are filling Out with padding
+    T* out_ptr = out->data<T>();
+    std::fill(out_ptr,
+              out_ptr + CalculateNumOfPrefillElems(out_tz, paddings),
+              pad_value);
+
+    // paddings are in order: left, right, top, bottom, front, back
+    for (size_t i = 0; i < paddings.size(); ++i) {
+      if (paddings[i] != 0) {
+        std::vector<int64_t> offsets(out_tz.size(), 0);
+        std::vector<int64_t> chunk_tz(out_tz.begin(), out_tz.end());
+
+        chunk_tz[out_tz.size() - 1 - i / 2] = paddings[i];
+        if (i % 2 == 1) {
+          offsets[out_tz.size() - 1 - i / 2] =
+              paddings[i - 1] + x_tz[out_tz.size() - 1 - i / 2];
+        }
+
+        FillPartOfPadding(paddle_dtype,
+                          onednn_engine,
+                          out_ptr,
+                          reorder_dst_memory_p,
+                          chunk_tz,
+                          offsets);
+      }
+    }
+    astream.wait();
+
+    std::vector<int64_t> offsets(out_tz.size(), 0);
+    for (size_t i = 0; i < paddings.size() / 2; ++i) {
+      offsets[out_tz.size() - 1 - i] = paddings[2 * i];
+    }
+
+    auto slice_mem_p =
+        reorder_handler.AcquireSubmemory(x_tz, offsets, reorder_dst_memory_p);
+
+    auto reorder_p =
+        reorder_handler.AcquireReorder(slice_mem_p, reorder_src_memory_p);
+    reorder_p->execute(astream, *reorder_src_memory_p, *slice_mem_p);
+    astream.wait();
+
+    out->set_mem_desc(reorder_dst_memory_p->get_desc());
+  }
+
+  // Returns how many leading elements of Out have to be prefilled with the pad
+  // value so that the largest padding chunk can be read from that buffer:
+  // out_tz[0] * out_tz[1] times the biggest "padding * other padded dims"
+  // product over all padded dimensions.
+  int64_t CalculateNumOfPrefillElems(const std::vector<int64_t>& out_tz,
+                                     const std::vector<int>& paddings) const {
+    int64_t max_elems = 0;
+    int64_t independent_dims = out_tz[0] * out_tz[1];
+
+    for (size_t i = 0; i < paddings.size() / 2; ++i) {
+      int64_t elems = std::max(paddings[2 * i], paddings[2 * i + 1]);
+      for (size_t j = 0; j < paddings.size() / 2; ++j) {
+        if (j != i) {
+          elems *= out_tz[out_tz.size() - 1 - j];
+        }
+      }
+
+      if (max_elems < elems) {
+        max_elems = elems;
+      }
+    }
+    return independent_dims * max_elems;
+  }
+
+  // Fills one padding chunk of Out: a reorder copies a block of shape chunk_tz
+  // from the prefilled buffer into the submemory of out_mem_p at `offsets`.
+  // The stream is not waited on here; RunKernel calls wait() after the loop.
+  void FillPartOfPadding(framework::proto::VarType::Type paddle_dtype,
+                         const dnnl::engine& onednn_engine,
+                         T* prefilled_mem_ptr,
+                         const std::shared_ptr<dnnl::memory>& out_mem_p,
+                         const std::vector<int64_t>& chunk_tz,
+                         const std::vector<int64_t>& offsets) const {
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+
+    dnnl::memory::desc prefilled_mem_desc(
+        chunk_tz,
+        platform::MKLDNNGetDataType<T>(),
+        platform::GetPlainMKLDNNFormat(chunk_tz.size()));
+    dnnl::memory prefilled_mem(
+        prefilled_mem_desc, onednn_engine, prefilled_mem_ptr);
+
+    dnnl::memory::desc out_slice_md =
+        out_mem_p->get_desc().submemory_desc(chunk_tz, {offsets});
+    dnnl::memory out_slice_mem(
+        out_slice_md, onednn_engine, out_mem_p->get_data_handle());
+
+    auto reorder_p = dnnl::reorder(prefilled_mem, out_slice_mem);
+    reorder_p.execute(astream, prefilled_mem, out_slice_mem);
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP_KERNEL(pad3d,
+                   MKLDNN,
+                   paddle::platform::CPUPlace,
+                   ops::PadMKLDNNKernel<float>);
+
+REGISTER_OP_KERNEL(pad2d,
+                   MKLDNN,
+                   paddle::platform::CPUPlace,
+                   ops::PadMKLDNNKernel<float>);
diff --git 
a/paddle/fluid/operators/pad2d_op.cc b/paddle/fluid/operators/pad2d_op.cc
index c2dfb8e61e5eb99032c15d9d6dcb81b336297ba7..0af4261c279bfd3345b1be94e45cddfd9718bb3e 100644
--- a/paddle/fluid/operators/pad2d_op.cc
+++ b/paddle/fluid/operators/pad2d_op.cc
@@ -699,8 +699,44 @@ class Pad2dOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+#ifdef PADDLE_WITH_MKLDNN
+    // only constant mode and non-blocked layouts are supported for oneDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
+        ctx.Attr<std::string>("mode") == "constant" &&
+        ctx.Input<Tensor>("X")
+                ->mem_desc()
+                .data.format_desc.blocking.inner_nblks == 0) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+
+  // When the oneDNN kernel is expected but the incoming tensor is not yet in
+  // the kMKLDNN layout, report the layout named by the "data_format" attribute
+  // instead of the tensor's own layout.
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name,
+      const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+#ifdef PADDLE_WITH_MKLDNN
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+      const auto& attrs = Attrs();
+      auto ar = paddle::framework::AttrReader(attrs);
+      const std::string data_format = ar.Get<std::string>("data_format");
+      return framework::OpKernelType(
+          expected_kernel_type.data_type_,
+          tensor.place(),
+          framework::StringToDataLayout(data_format));
+    }
+#endif
     return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
   }
 };
 
@@ -740,6 +776,10 @@ class Pad2dOpMaker : public framework::OpProtoAndCheckerMaker {
         "An optional string from: \"NHWC\", \"NCHW\". "
         "Defaults to \"NHWC\". Specify the data format of the input data.")
         .SetDefault("NCHW");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 Pad2d Operator.
 Pad 2-d images according to 'paddings' and 'mode'. 
diff --git a/paddle/fluid/operators/pad3d_op.cc b/paddle/fluid/operators/pad3d_op.cc
index 21c24210afd458d68a3959ecdc1036dac774612c..e4b32b3d7a76ea9043bc63a2d19211c1c4295849 100644
--- a/paddle/fluid/operators/pad3d_op.cc
+++ b/paddle/fluid/operators/pad3d_op.cc
@@ -34,8 +34,44 @@ class Pad3dOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+#ifdef PADDLE_WITH_MKLDNN
+    // only constant mode and non-blocked layouts are supported for oneDNN
+    if (this->CanMKLDNNBeUsed(ctx, input_data_type) &&
+        ctx.Attr<std::string>("mode") == "constant" &&
+        ctx.Input<Tensor>("X")
+                ->mem_desc()
+                .data.format_desc.blocking.inner_nblks == 0) {
+      return framework::OpKernelType(input_data_type,
+                                     ctx.GetPlace(),
+                                     framework::DataLayout::kMKLDNN,
+                                     framework::LibraryType::kMKLDNN);
+    }
+#endif
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
+
+  // When the oneDNN kernel is expected but the incoming tensor is not yet in
+  // the kMKLDNN layout, report the layout named by the "data_format" attribute
+  // instead of the tensor's own layout.
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string& var_name,
+      const Tensor& tensor,
+      const framework::OpKernelType& expected_kernel_type) const override {
+#ifdef PADDLE_WITH_MKLDNN
+    if ((expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN) &&
+        (tensor.layout() != framework::DataLayout::kMKLDNN)) {
+      const auto& attrs = Attrs();
+      auto ar = paddle::framework::AttrReader(attrs);
+      const std::string data_format = ar.Get<std::string>("data_format");
+      return framework::OpKernelType(
+          expected_kernel_type.data_type_,
+          tensor.place(),
+          framework::StringToDataLayout(data_format));
+    }
+#endif
     return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
   }
 };
 
@@ -78,6 +114,10 @@ class Pad3dOpMaker : public framework::OpProtoAndCheckerMaker {
         "An optional string from: \"NDHWC\", \"NCDHW\". "
         "Defaults to \"NDHWC\". Specify the data format of the input data.")
         .SetDefault("NCDHW");
+    AddAttr<bool>("use_mkldnn",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false)
+        .AsExtra();
     AddComment(R"DOC(
 Pad3d Operator.
 Pad 3-d images according to 'paddings' and 'mode'. 
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_pad2d_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_pad2d_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a81451febf3997eaab31b8a35fee5f779c550e8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_pad2d_op.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import MkldnnAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+from hypothesis import given, reproduce_failure
+import hypothesis.strategies as st
+
+
+class TestOneDNNPad2DOp(MkldnnAutoScanTest):
+    """Auto-scan test of the pad2d op (constant mode) with oneDNN enabled."""
+
+    def sample_program_configs(self, *args, **kwargs):
+        def generate_input(*args, **kwargs):
+            return np.random.random(kwargs['in_shape']).astype(np.float32)
+
+        pad2d_op = OpConfig(type="pad2d",
+                            inputs={"X": ["input_data"]},
+                            outputs={"Out": ["output_data"]},
+                            attrs={
+                                "mode": "constant",
+                                "data_format": kwargs['data_format'],
+                                "paddings": kwargs['paddings'],
+                            })
+
+        program_config = ProgramConfig(
+            ops=[pad2d_op],
+            weights={},
+            inputs={
+                "input_data":
+                TensorConfig(data_gen=partial(generate_input, *args, **kwargs)),
+            },
+            outputs=["output_data"])
+
+        yield program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, (1e-5, 1e-5)
+
+    @given(data_format=st.sampled_from(['NCHW', 'NHWC']),
+           in_shape=st.sampled_from([[2, 3, 4, 5], [1, 4, 1, 3], [4, 3, 2, 1],
+                                     [1, 1, 1, 1]]),
+           paddings=st.sampled_from([[0, 0, 0, 0], [1, 2, 0, 1], [2, 5, 11, 3],
+                                     [0, 5, 0, 1]]))
+    def test(self, *args, **kwargs):
+        self.run_test(quant=False, *args, **kwargs)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_pad3d_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_pad3d_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..acc7fa1e30e2d2af9d408978c633aa30e8aab0f5
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_pad3d_op.py
@@ -0,0 +1,78 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from auto_scan_test import MkldnnAutoScanTest
+from program_config import TensorConfig, ProgramConfig, OpConfig
+import numpy as np
+from functools import partial
+import unittest
+from hypothesis import given, reproduce_failure
+import hypothesis.strategies as st
+
+
+class TestOneDNNPad3DOp(MkldnnAutoScanTest):
+    """Auto-scan test of the pad3d op (constant mode) with oneDNN enabled."""
+
+    def sample_program_configs(self, *args, **kwargs):
+
+        def generate_input(*args, **kwargs):
+            return np.random.random(kwargs['in_shape']).astype(np.float32)
+
+        def generate_paddings():
+            return np.random.randint(0, 4, size=(6)).astype(np.int32)
+
+        op_inputs = {"X": ["input_data"]}
+        program_inputs = {
+            "input_data":
+            TensorConfig(data_gen=partial(generate_input, *args, **kwargs)),
+        }
+        # Feed the "Paddings" tensor only when the sampled flag asks for it,
+        # so that the attribute-only path is exercised as well.
+        if kwargs['use_paddings_tensor']:
+            op_inputs["Paddings"] = ["paddings_data"]
+            program_inputs["paddings_data"] = TensorConfig(
+                data_gen=generate_paddings)
+
+        pad3d_op = OpConfig(type="pad3d",
+                            inputs=op_inputs,
+                            outputs={"Out": ["output_data"]},
+                            attrs={
+                                "mode": "constant",
+                                "data_format": kwargs['data_format'],
+                                "paddings": kwargs['paddings'],
+                            })
+
+        program_config = ProgramConfig(ops=[pad3d_op],
+                                       weights={},
+                                       inputs=program_inputs,
+                                       outputs=["output_data"])
+
+        yield program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, (1e-5, 1e-5)
+
+    @given(data_format=st.sampled_from(['NCDHW', 'NDHWC']),
+           use_paddings_tensor=st.sampled_from([True, False]),
+           in_shape=st.sampled_from([[2, 3, 4, 5, 6], [1, 4, 1, 3, 2],
+                                     [4, 3, 2, 1, 1], [1, 1, 1, 1, 1]]),
+           paddings=st.sampled_from([[0, 0, 0, 0, 0, 0], [1, 2, 0, 1, 2, 1],
+                                     [2, 5, 11, 3, 4, 3], [0, 5, 0, 1, 0, 2]]))
+    def test(self, *args, **kwargs):
+        self.run_test(quant=False, *args, **kwargs)
+
+
+if __name__ == "__main__":
+    unittest.main()