diff --git a/paddle/fluid/imperative/layout_autotune.cc b/paddle/fluid/imperative/layout_autotune.cc
index 669a4af99f31f8354c798c66669f0060dca51991..10a4a2e69d5402d8e309af97de6c556bad08100c 100644
--- a/paddle/fluid/imperative/layout_autotune.cc
+++ b/paddle/fluid/imperative/layout_autotune.cc
@@ -131,6 +131,8 @@ paddle::imperative::NameVarMap<VarType> DealLightlyLayoutSensitive(
     transposer = std::make_shared<FlattenOpTransformer<VarType>>(op_type);
   } else if (op_type == "arg_max") {
     transposer = std::make_shared<ArgmaxOpTransformer<VarType>>(op_type);
+  } else if (op_type == "concat") {
+    transposer = std::make_shared<ConcatOpTransformer<VarType>>(op_type);
   } else if (op_type.find("elementwise_") != std::string::npos) {
     transposer = std::make_shared<ElementwiseOpTransformer<VarType>>(op_type);
   } else {
diff --git a/paddle/fluid/imperative/layout_transformer.h b/paddle/fluid/imperative/layout_transformer.h
index ab7619dedb2e9df8c8299e5856cddc15bcbc1c1e..fa7261b6d52b6244a73b81e260ec824baf9832f1 100644
--- a/paddle/fluid/imperative/layout_transformer.h
+++ b/paddle/fluid/imperative/layout_transformer.h
@@ -401,5 +401,49 @@ class ArgmaxOpTransformer
   }
 };
 
+template <typename VarType>
+class ConcatOpTransformer
+    : public LightlyLayoutSensitiveOpTransformer<VarType> {
+ public:
+  explicit ConcatOpTransformer(const std::string& type)
+      : LightlyLayoutSensitiveOpTransformer<VarType>(type) {}
+
+  paddle::imperative::NameVarMap<VarType> Apply(
+      const paddle::imperative::NameVarMap<VarType>& ins,
+      const paddle::imperative::NameVarMap<VarType>& outs,
+      paddle::framework::AttributeMap* attrs,
+      const std::shared_ptr<Tracer>& tracer) {
+    VLOG(3) << "Optimize lightly layout sensitive op " << this->Type();
+    auto& in_var = ins.at("X")[0];
+    auto var_layout = paddle::imperative::GetDataLayout(in_var);
+    bool need_transpose = false;
+    for (auto& pair : ins) {
+      for (auto& var : pair.second) {
+        if (var != nullptr &&
+            (paddle::imperative::GetDataLayout(var) != var_layout)) {
+          need_transpose = true;
+          break;
+        }
+      }
+    }
+
+    if (need_transpose) {
+      return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(
+          ins, outs, attrs, tracer);
+    }
+
+    if (var_layout != DataLayout::UNDEFINED) {
+      std::vector<int> perm_nhwc = {0, 3, 1, 2};
+      std::vector<int> perm_nchw = {0, 2, 3, 1};
+      auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
+      auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]);
+      (*attrs)["axis"] = static_cast<int>(perm[axis]);
+    }
+
+    this->SetVarsLayout(outs, var_layout);
+    return ins;
+  }
+};
+
 }  // namespace imperative
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_layout_autotune.py b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
index fc9b51c5fc0402ed21b9830e5e49d8b57e7516e4..5cb53437fe9cda78df836456dd3e2ce72365164d 100644
--- a/python/paddle/fluid/tests/unittests/test_layout_autotune.py
+++ b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
@@ -161,6 +161,35 @@ class LayoutAutoTune(unittest.TestCase):
         self.assertEqual(conv_out.shape, [1, 14, 12, 8])
         self.assertEqual(out.shape, [1])
 
+    def test_concat_op_transposer(self):
+        if not self.use_autoune():
+            return
+        in1 = paddle.rand([1, 8, 14, 12])
+        conv = paddle.nn.Conv2D(3, 8, (3, 3))
+        data = paddle.rand([1, 3, 16, 14])
+        with paddle.amp.auto_cast(level="O2"):
+            conv_out = conv(data)
+            # conv_out.shape = [1, 14, 12, 8] with NHWC
+            out = paddle.concat(x=[conv_out, in1], axis=0)
+
+        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
+        self.assertEqual(out.shape, [2, 8, 14, 12])
+
+    def test_concat_op_no_transposer(self):
+        if not self.use_autoune():
+            return
+        conv = paddle.nn.Conv2D(3, 8, (3, 3))
+        data1 = paddle.rand([1, 3, 16, 14])
+        data2 = paddle.rand([1, 3, 16, 14])
+        with paddle.amp.auto_cast(level="O2"):
+            conv_out1 = conv(data1)
+            conv_out2 = conv(data2)
+            # conv_out1.shape = conv_out2.shape = [1, 14, 12, 8] with NHWC
+            out = paddle.concat(x=[conv_out1, conv_out2], axis=0)
+
+        self.assertEqual(conv_out1.shape, [1, 14, 12, 8])
+        self.assertEqual(out.shape, [2, 14, 12, 8])
+
 
 class TestAutoTuneAPI(unittest.TestCase):
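For context on the axis remapping in the C++ patch (this note is not part of the diff): when every input to `concat` already carries the same tuned layout, the transformer keeps the tensors as they are and only rewrites the `axis` attribute, so an axis given in the model's original NCHW order points at the right dimension of the NHWC-stored inputs. Below is a minimal standalone sketch of that mapping; `remap_concat_axis` is a hypothetical helper, not a Paddle API, and it simply mirrors the `perm_nhwc`/`perm_nchw` tables from the patch.

```python
# Illustrative sketch only: how ConcatOpTransformer rewrites the concat axis
# when all inputs already share a tuned layout (mirrors perm_nhwc / perm_nchw).

def remap_concat_axis(axis, var_layout):
    perm_nhwc = [0, 3, 1, 2]  # NCHW axis index -> its position in NHWC storage
    perm_nchw = [0, 2, 3, 1]  # NHWC axis index -> its position in NCHW storage
    perm = perm_nhwc if var_layout == "NHWC" else perm_nchw
    return perm[axis]

# Concatenating along the channel axis (1 in NCHW) must happen along axis 3
# once the tensors are stored as NHWC; the batch axis (0) is unchanged, which
# is why the tests above can concat NHWC conv outputs along axis=0 directly.
assert remap_concat_axis(1, "NHWC") == 3
assert remap_concat_axis(0, "NHWC") == 0
```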