diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 834a2c953eab833d9957ddf5b0770178d922015a..48ccadd037363caf44e3eb190b913e7717e6c0f9 100755 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -128,6 +128,7 @@ if(WITH_MKLDNN) pass_library(fc_mkldnn_pass inference DIR mkldnn) pass_library(interpolate_mkldnn_pass inference DIR mkldnn) pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn) + pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn) pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn) diff --git a/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.cc b/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..bf603dc4bbcb9ddf6bfcff9326fa0cc05682050b --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.cc @@ -0,0 +1,237 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.h" +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); +#define GET_NODES \ + GET_IR_NODE(reshape1_op); \ + GET_IR_NODE(reshape1_out); \ + GET_IR_NODE(transpose_op); \ + GET_IR_NODE(transpose_out); \ + GET_IR_NODE(reshape2_op); \ + GET_IR_NODE(reshape2_out); + +ShuffleChannelMKLDNNDetectPass::ShuffleChannelMKLDNNDetectPass() { + AddOpCompat(OpCompat("reshape2")) + .AddInput("X") + .IsTensor() + .End() + .AddInput("Shape") + .IsOptional() + .IsTensor() + .End() + .AddInput("ShapeTensor") + .IsOptional() + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("shape") + .IsType>() + .End(); + + AddOpCompat(OpCompat("transpose2")) + .AddInput("X") + .IsTensor() + .End() + .AddOutput("XShape") + .IsTensor() + .End() + .AddOutput("Out") + .IsTensor() + .End() + .AddAttr("axis") + .IsType>() + .End(); +} + +void ShuffleChannelMKLDNNDetectPass::ApplyImpl(ir::Graph* graph) const { + const std::string pattern_name = "shufflechannel_pattern"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + auto* x = gpd.mutable_pattern() + ->NewNode("x") + ->assert_is_op_input("reshape2", "X") + ->AsInput(); + + patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name); + pattern(x); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + GET_NODES; + if (!IsCompat(subgraph, g)) { + LOG(WARNING) << "The Pass in op compat failed."; + return; + } + PADDLE_ENFORCE_GT( + subgraph.count(x), 0, + platform::errors::NotFound("Detector did not find input X.")); + auto* input_node = subgraph.at(x); + auto reshape1_desc = reshape1_op->Op(); + auto reshape2_desc = reshape2_op->Op(); + auto trans_desc = transpose_op->Op(); + std::string input_name = input_node->Name(); + std::string output_name = reshape2_out->Name(); + + auto reshape1_shape = + BOOST_GET_CONST(std::vector, reshape1_desc->GetAttr("shape")); + auto reshape2_shape = + BOOST_GET_CONST(std::vector, reshape2_desc->GetAttr("shape")); + auto trans_axis = + BOOST_GET_CONST(std::vector, trans_desc->GetAttr("axis")); + auto* block1 = reshape1_desc->Block(); + auto* block2 = reshape2_desc->Block(); + if (block1 && block2) { + auto x_var_name = reshape1_desc->Input("X")[0]; + auto* x_var_desc = block1->FindVar(x_var_name); + auto x_shape1 = x_var_desc->GetShape(); + x_var_name = reshape2_desc->Input("X")[0]; + x_var_desc = block2->FindVar(x_var_name); + auto x_shape2 = x_var_desc->GetShape(); + // now shuffle_channel is 4D(NCHW) only. + if (x_shape1.size() != 4 || reshape1_shape.size() != 5 || + reshape2_shape.size() != 4 || trans_axis.size() != 5) { + return; + } + + // process 0 and -1 in reshape. + constexpr int64_t copy_dim_val = 0; + for (size_t i = 0; i < reshape1_shape.size(); i++) { + if (reshape1_shape[i] == copy_dim_val) { + reshape1_shape[i] = x_shape1[i]; + } + } + for (size_t i = 0; i < reshape2_shape.size(); i++) { + if (reshape2_shape[i] == copy_dim_val) { + reshape2_shape[i] = x_shape2[i]; + } + } + constexpr int64_t unk_dim_idx = -1; + bool all_positive = std::all_of(x_shape1.cbegin(), x_shape1.cend(), + [](int64_t i) { return i > 0; }); + for (size_t i = 0; i < reshape1_shape.size(); ++i) { + // if -1 is not in batch dim, try to calculate number + if ((reshape1_shape[i] == unk_dim_idx) && (i != 0)) { + // there is no sufficient info + if (!all_positive) return; + reshape1_shape[i] = + std::accumulate(x_shape1.begin(), x_shape1.end(), + static_cast(1), + std::multiplies()) / + std::accumulate(reshape1_shape.begin(), reshape1_shape.end(), + static_cast(-1), + std::multiplies()); + break; + } + } + + all_positive = std::all_of(x_shape2.cbegin(), x_shape2.cend(), + [](int64_t i) { return i > 0; }); + for (size_t i = 0; i < reshape2_shape.size(); ++i) { + // if -1 is not in batch dim, try to calculate number + if ((reshape2_shape[i] == unk_dim_idx) && (i != 0)) { + // there is no sufficient info + if (!all_positive) return; + reshape2_shape[i] = + std::accumulate(x_shape2.begin(), x_shape2.end(), + static_cast(1), + std::multiplies()) / + std::accumulate(reshape2_shape.begin(), reshape2_shape.end(), + static_cast(-1), + std::multiplies()); + break; + } + } + + // shuffle_channel dosen't change shape + if ((reshape2_shape[0] != -1) && (x_shape1[0] != reshape2_shape[0])) { + return; + } + for (size_t i = 1; i < x_shape1.size(); i++) { + if (x_shape1[i] != reshape2_shape[i]) { + return; + } + } + if ((reshape2_shape[3] != reshape1_shape[4]) || + (reshape2_shape[2] != reshape1_shape[3])) { + return; + } + } else { + return; // conservative judgement + } + + int i_c = reshape1_shape[2]; + int o_c = reshape2_shape[1]; + int group = o_c / i_c; + // should split on channel dim + if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return; + // trans on channel dim + if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return; + if (group != 1 && i_c != 1) { + if (trans_axis[1] != 2 && trans_axis[2] != 1) { + return; + } + } + + framework::OpDesc new_op_desc; + new_op_desc.SetType("shuffle_channel"); + new_op_desc.SetInput("X", {input_name}); + new_op_desc.SetOutput("Out", {output_name}); + + new_op_desc.SetAttr("group", group); + new_op_desc.SetAttr("use_mkldnn", true); + new_op_desc.Flush(); + + // Create a new node for the fused op. + auto* new_op = graph->CreateOpNode(&new_op_desc); + + IR_NODE_LINK_TO(input_node, new_op); + IR_NODE_LINK_TO(new_op, reshape2_out); + + // Delete the unneeded nodes. + GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op, + transpose_out, reshape2_op}); + LOG_FIRST_N(WARNING, 1) + << "There is fluid.layers.shuffle_channel API already, maybe you can " + "use it instead of (reshape + transpose + reshape)"; + }; + + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(shuffle_channel_mkldnn_detect_pass, + paddle::framework::ir::ShuffleChannelMKLDNNDetectPass); +REGISTER_PASS_CAPABILITY(shuffle_channel_mkldnn_detect_pass) + .AddCombination( + paddle::framework::compatible::OpVersionComparatorCombination() + .EQ("reshape2", 0) + .EQ("transpose2", 0)); diff --git a/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.h b/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..231b63c3b6a0024c79e843f0b3359b382683a330 --- /dev/null +++ b/paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" + +namespace paddle { +namespace framework { +namespace ir { + +class Graph; + +class ShuffleChannelMKLDNNDetectPass : public FusePassBase { + public: + ShuffleChannelMKLDNNDetectPass(); + virtual ~ShuffleChannelMKLDNNDetectPass() {} + + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc index 20418e37a7b94c38a2cfa76d0db6cc63ae5b3d52..d0fe3953d00d6b2b043f6f08b92789422f927225 100644 --- a/paddle/fluid/inference/api/paddle_pass_builder.cc +++ b/paddle/fluid/inference/api/paddle_pass_builder.cc @@ -298,6 +298,7 @@ void CpuPassStrategy::EnableMKLDNN() { // "fc_act_mkldnn_fuse_pass", "batch_norm_act_fuse_pass", // "softplus_activation_mkldnn_fuse_pass", // + "shuffle_channel_mkldnn_detect_pass", // "elt_act_mkldnn_fuse_pass", // // TODO(intel): Please fix the bug on windows. // https://github.com/PaddlePaddle/Paddle/issues/29710 diff --git a/paddle/fluid/operators/mkldnn/shuffle_channel_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/shuffle_channel_mkldnn_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..408de57bf946d22b1a8912161303b541fe257f40 --- /dev/null +++ b/paddle/fluid/operators/mkldnn/shuffle_channel_mkldnn_op.cc @@ -0,0 +1,77 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/platform/mkldnn_reuse.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using platform::MKLDNNGetDataType; +template +class ShuffleChannelMKLDNNHandler + : public platform::MKLDNNHandlerNoCachingT { + public: + ShuffleChannelMKLDNNHandler(const Tensor* x, const int group, + const dnnl::engine engine, + platform::Place cpu_place) + : platform::MKLDNNHandlerNoCachingT(engine, + cpu_place) { + static constexpr int channel_axis = 1; + const auto md = dnnl::memory::desc(phi::vectorize(x->dims()), + MKLDNNGetDataType(), x->format()); + + this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training, + md, channel_axis, group); + } +}; + +template +class ShuffleChannelMKLDNNKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + const auto& dev_ctx = + ctx.template device_context(); + const auto& mkldnn_engine = dev_ctx.GetEngine(); + + const auto* x = ctx.Input("X"); + auto* out = ctx.Output("Out"); + + // oneDNN handles group using C/g instead of g + const int group = x->dims()[1] / ctx.Attr("group"); + + ShuffleChannelMKLDNNHandler handler(x, group, mkldnn_engine, + ctx.GetPlace()); + + auto src_memory_p = handler.AcquireSrcMemory(x); + auto dst_memory_p = handler.AcquireDstMemory(out); + + auto shuffle_p = handler.AcquireForwardPrimitive(); + + auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); + shuffle_p->execute(astream, {{DNNL_ARG_SRC, *src_memory_p}, + {DNNL_ARG_DST, *dst_memory_p}}); + astream.wait(); + + out->set_layout(framework::DataLayout::kMKLDNN); + out->set_format(x->format()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_KERNEL(shuffle_channel, MKLDNN, paddle::platform::CPUPlace, + ops::ShuffleChannelMKLDNNKernel, + ops::ShuffleChannelMKLDNNKernel); diff --git a/paddle/fluid/operators/shuffle_channel_op.cc b/paddle/fluid/operators/shuffle_channel_op.cc index 119d2e7236946e7243ef53c791f4bb7f48d21c91..70fddc9b04712d53af79651bb2c164846268608e 100644 --- a/paddle/fluid/operators/shuffle_channel_op.cc +++ b/paddle/fluid/operators/shuffle_channel_op.cc @@ -35,9 +35,17 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -56,6 +64,10 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker { PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument( "group should be larger than 0.")); }); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false) + .AsExtra(); AddComment(R"DOC( Shuffle Channel operator diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_op.py new file mode 100644 index 0000000000000000000000000000000000000000..26655970290cdb15be3b517de146276a2cd6b809 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_shuffle_channel_op.py @@ -0,0 +1,61 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from auto_scan_test import MkldnnAutoScanTest +from program_config import TensorConfig, ProgramConfig, OpConfig +import numpy as np +from functools import partial +import unittest +from hypothesis import given +import hypothesis.strategies as st + + +class TestMKLDNNShuffleChannelOp(MkldnnAutoScanTest): + def is_program_valid(self, program_config: ProgramConfig) -> bool: + return True + + def sample_program_configs(self, *args, **kwargs): + def generate_input(*args, **kwargs): + return np.random.random(kwargs['in_shape']).astype(np.float32) + + shuffle_channel_op = OpConfig( + type="shuffle_channel", + inputs={"X": ["input_data"]}, + outputs={"Out": ["output_data"]}, + attrs={"group": kwargs['group']}) + + program_config = ProgramConfig( + ops=[shuffle_channel_op], + weights={}, + inputs={ + "input_data": TensorConfig(data_gen=partial(generate_input, + *args, **kwargs)), + }, + outputs=["output_data"]) + + yield program_config + + def sample_predictor_configs(self, program_config): + config = self.create_inference_config(use_mkldnn=True) + yield config, (1e-5, 1e-5) + + @given( + group=st.sampled_from([1, 2, 8, 32, 128]), + in_shape=st.sampled_from([[5, 512, 2, 3], [2, 256, 5, 4]])) + def test(self, *args, **kwargs): + self.run_test(quant=False, *args, **kwargs) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_shuffle_channel_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_shuffle_channel_mkldnn_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1d657817503deb8debaddfeaf524f166a4b9e177 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/mkldnn/test_shuffle_channel_mkldnn_op.py @@ -0,0 +1,62 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + + +@OpTestTool.skip_if_not_cpu_bf16() +class TestShuffleChannelOneDNNOp(OpTest): + def setUp(self): + self.op_type = "shuffle_channel" + self.set_dtype() + self.set_group() + self.inputs = {'X': np.random.random((5, 64, 2, 3)).astype(self.dtype)} + self.attrs = {'use_mkldnn': True, 'group': self.group} + + _, c, h, w = self.inputs['X'].shape + input_reshaped = np.reshape(self.inputs['X'], + (-1, self.group, c // self.group, h, w)) + input_transposed = np.transpose(input_reshaped, (0, 2, 1, 3, 4)) + self.outputs = {'Out': np.reshape(input_transposed, (-1, c, h, w))} + + def set_dtype(self): + self.dtype = np.float32 + + def set_group(self): + self.group = 4 + + def test_check_output(self): + self.check_output_with_place(core.CPUPlace()) + + +class TestShuffleChannelSingleGroupOneDNNOp(TestShuffleChannelOneDNNOp): + def set_group(self): + self.group = 1 + + +class TestShuffleChannelBF16OneDNNOp(TestShuffleChannelOneDNNOp): + def set_dtype(self): + self.dtype = np.uint16 + + +if __name__ == "__main__": + paddle.enable_static() + unittest.main()