From 02e4f1f897b3b62a31923cb554c9674769da8e90 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Mon, 11 Jul 2022 19:36:42 +0800
Subject: [PATCH] Add Concat transformer for layout autotune (#42003)

* Add Concat transformer for layout autotune
---
 paddle/fluid/imperative/layout_autotune.cc   |  2 +
 paddle/fluid/imperative/layout_transformer.h | 46 +++++++++++++++++++
 .../tests/unittests/test_layout_autotune.py  | 29 ++++++++++++
 3 files changed, 77 insertions(+)

diff --git a/paddle/fluid/imperative/layout_autotune.cc b/paddle/fluid/imperative/layout_autotune.cc
index 669a4af99f3..10a4a2e69d5 100644
--- a/paddle/fluid/imperative/layout_autotune.cc
+++ b/paddle/fluid/imperative/layout_autotune.cc
@@ -131,6 +131,8 @@ paddle::imperative::NameVarMap<VarType> DealLightlyLayoutSensitive(
     transposer = std::make_shared<FlattenOpTransformer<VarType>>(op_type);
   } else if (op_type == "arg_max") {
     transposer = std::make_shared<ArgmaxOpTransformer<VarType>>(op_type);
+  } else if (op_type == "concat") {
+    transposer = std::make_shared<ConcatOpTransformer<VarType>>(op_type);
   } else if (op_type.find("elementwise_") != std::string::npos) {
     transposer = std::make_shared<ElementwiseOpTransformer<VarType>>(op_type);
   } else {
diff --git a/paddle/fluid/imperative/layout_transformer.h b/paddle/fluid/imperative/layout_transformer.h
index ab7619dedb2..fa7261b6d52 100644
--- a/paddle/fluid/imperative/layout_transformer.h
+++ b/paddle/fluid/imperative/layout_transformer.h
@@ -401,5 +401,51 @@ class ArgmaxOpTransformer
   }
 };
 
+template <typename VarType>
+class ConcatOpTransformer
+    : public LightlyLayoutSensitiveOpTransformer<VarType> {
+ public:
+  explicit ConcatOpTransformer(const std::string& type)
+      : LightlyLayoutSensitiveOpTransformer<VarType>(type) {}
+
+  paddle::imperative::NameVarMap<VarType> Apply(
+      const paddle::imperative::NameVarMap<VarType>& ins,
+      const paddle::imperative::NameVarMap<VarType>& outs,
+      paddle::framework::AttributeMap* attrs,
+      const std::shared_ptr<Tracer>& tracer) {
+    VLOG(3) << "Optimize lightly layout sensitive op " << this->Type();
+    auto& in_var = ins.at("X")[0];
+    auto var_layout = paddle::imperative::GetDataLayout(in_var);
+    bool need_transpose = false;
+    for (auto& pair : ins) {
+      for (auto& var : pair.second) {
+        if (var != nullptr &&
+            (paddle::imperative::GetDataLayout(var) != var_layout)) {
+          need_transpose = true;
+          break;
+        }
+      }
+    }
+
+    if (need_transpose) {
+      return LightlyLayoutSensitiveOpTransformer<VarType>::Apply(
+          ins, outs, attrs, tracer);
+    }
+
+    if (var_layout != DataLayout::UNDEFINED) {
+      std::vector<int> perm_nhwc = {0, 3, 1, 2};
+      std::vector<int> perm_nchw = {0, 2, 3, 1};
+      auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
+      auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]);
+      (*attrs)["axis"] = static_cast<int>(perm[axis]);
+    }
+    auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]);
+    VLOG(3) << "Optimized concat axis: " << axis;
+
+    this->SetVarsLayout(outs, var_layout);
+    return ins;
+  }
+};
+
 }  // namespace imperative
 }  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_layout_autotune.py b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
index fc9b51c5fc0..5cb53437fe9 100644
--- a/python/paddle/fluid/tests/unittests/test_layout_autotune.py
+++ b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
@@ -161,6 +161,35 @@ class LayoutAutoTune(unittest.TestCase):
         self.assertEqual(conv_out.shape, [1, 14, 12, 8])
         self.assertEqual(out.shape, [1])
 
+    def test_concat_op_transposer(self):
+        if not self.use_autoune():
+            return
+        in1 = paddle.rand([1, 8, 14, 12])
+        conv = paddle.nn.Conv2D(3, 8, (3, 3))
+        data = paddle.rand([1, 3, 16, 14])
+        with paddle.amp.auto_cast(level="O2"):
+            conv_out = conv(data)
+            # conv_out.shape = [1, 14, 12, 8] with NHWC
+            out = paddle.concat(x=[conv_out, in1], axis=0)
+
+        self.assertEqual(conv_out.shape, [1, 14, 12, 8])
+        self.assertEqual(out.shape, [2, 8, 14, 12])
+
+    def test_concat_op_no_transposer(self):
+        if not self.use_autoune():
+            return
+        conv = paddle.nn.Conv2D(3, 8, (3, 3))
+        data1 = paddle.rand([1, 3, 16, 14])
+        data2 = paddle.rand([1, 3, 16, 14])
+        with paddle.amp.auto_cast(level="O2"):
+            conv_out1 = conv(data1)
+            conv_out2 = conv(data2)
+            # conv_out1.shape = conv_out2.shape = [1, 14, 12, 8] with NHWC
+            out = paddle.concat(x=[conv_out1, conv_out2], axis=0)
+
+        self.assertEqual(conv_out1.shape, [1, 14, 12, 8])
+        self.assertEqual(out.shape, [2, 14, 12, 8])
+
 
 class TestAutoTuneAPI(unittest.TestCase):
--
GitLab
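
The core of the new transformer is the axis remapping it applies when every
concat input already shares the tuned layout: only the "axis" attribute is
rewritten and no transpose is inserted. Below is a minimal, framework-free
Python sketch of that mapping; the helper name remap_concat_axis is
hypothetical (not part of Paddle), and the permutation tables mirror
perm_nhwc / perm_nchw from the patch.

# Sketch of the concat axis remapping, assuming all inputs share the
# tuned layout. remap_concat_axis is a made-up illustration helper.
PERM_NHWC = [0, 3, 1, 2]  # NCHW axis index -> position of that dim in NHWC
PERM_NCHW = [0, 2, 3, 1]  # NHWC axis index -> position of that dim in NCHW


def remap_concat_axis(axis, tuned_layout):
    """Translate a concat axis written against the model's original layout
    into the equivalent axis for the layout chosen by autotune."""
    perm = PERM_NHWC if tuned_layout == "NHWC" else PERM_NCHW
    return perm[axis]


# Concatenating along channels (axis 1 in NCHW) becomes axis 3 once the
# tensors are actually held as NHWC, so no extra transpose is needed.
assert remap_concat_axis(1, "NHWC") == 3
# The batch axis stays at position 0 in either layout.
assert remap_concat_axis(0, "NHWC") == 0

When the inputs disagree on layout, as in test_concat_op_transposer where an
NHWC conv output is concatenated with a plain NCHW-shaped tensor, the
transformer instead falls back to LightlyLayoutSensitiveOpTransformer::Apply,
which transposes the tuned inputs back to the original layout before the
concat runs.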