From 69e99cc7c0fd0ab8803f27fc798bbb625438084d Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Thu, 23 Jun 2022 20:44:50 +0800 Subject: [PATCH] improve LayoutAutoTune for NCHW and NHWC (#43158) --- paddle/fluid/imperative/layout_autotune.cc | 151 ++++++++----- paddle/fluid/imperative/layout_autotune.h | 9 +- paddle/fluid/imperative/layout_transformer.h | 202 ++++++++++++------ paddle/fluid/imperative/tracer.cc | 20 +- paddle/fluid/imperative/var_helper.cc | 3 +- .../tests/unittests/test_layout_autotune.py | 26 +++ 6 files changed, 284 insertions(+), 127 deletions(-) diff --git a/paddle/fluid/imperative/layout_autotune.cc b/paddle/fluid/imperative/layout_autotune.cc index 7dfd860403..669a4af99f 100644 --- a/paddle/fluid/imperative/layout_autotune.cc +++ b/paddle/fluid/imperative/layout_autotune.cc @@ -26,6 +26,7 @@ namespace imperative { bool LayoutAutoTune::UseLayoutAutoTune() const { #if defined(PADDLE_WITH_CUDA) if (!phi::backends::gpu::TensorCoreAvailable()) { + LayoutAutoTune::Instance().DisableLayoutAutoTune(); return false; } else { return use_layout_autotune_; @@ -38,30 +39,23 @@ bool LayoutAutoTune::UseLayoutAutoTune() const { LayoutAutoTune::LayoutAutoTune() { const auto& op_info = paddle::framework::OpInfoMap::Instance().map(); for (auto it = op_info.begin(); it != op_info.end(); it++) { - // only record forwrd operators - if (it->first.find("_grad") != std::string::npos) { + // only when op was not in Lightly、Heavily or Agnostic Set + if (IsLightlyLayoutSensitive(it->first) || + IsHeavilyLayoutSensitive(it->first) || IsLayoutAgnostic(it->first)) { + VLOG(4) << "Already exists in Layout OP: " << it->first; continue; } - // some normalization operators such as instance_norm and layer_norm - // do not have data_format attr, but are layout sensitive. - if (it->first.find("norm") != std::string::npos) { - layout_agnostic_ops_.emplace(it->first); + // only record forwrd operators + if (it->first.find("_grad") != std::string::npos) { continue; } auto* attr_checker = it->second.Checker(); + bool layout_agnostic = true; if (attr_checker) { auto attrs = attr_checker->GetDefaultAttrMap(); - if (attrs.find("data_format") != attrs.end() || - attrs.find("data_layout") != attrs.end()) { - VLOG(4) << "Heavily layout sensitive OP: " << it->first; - heavily_layout_sensitive_ops_.emplace(it->first); - continue; - } - // Attribute name is fuzzy matched, such as start and start_axis. - bool layout_agnostic = true; for (auto& attr : attrs) { auto attr_name = attr.first; VLOG(6) << "OP: " << it->first << " Attr Name: " << attr_name; @@ -77,11 +71,27 @@ LayoutAutoTune::LayoutAutoTune() { } } - if (layout_agnostic) { - VLOG(4) << "Layout agnostic_ops: " << it->first; - layout_agnostic_ops_.emplace(it->first); + if ((attrs.find("data_format") != attrs.end() || + attrs.find("data_layout") != attrs.end()) && + layout_agnostic == true) { + VLOG(4) << "Heavily layout sensitive OP: " << it->first; + heavily_layout_sensitive_ops_.emplace(it->first); + layout_agnostic = false; + continue; } } + + // some normalization operators such as instance_norm and layer_norm + // do not have data_format attr, but are layout sensitive. + if (it->first.find("norm") != std::string::npos && layout_agnostic) { + lightly_layout_sensitive_ops_.emplace(it->first); + continue; + } + + if (layout_agnostic) { + VLOG(4) << "Layout agnostic_ops: " << it->first; + layout_agnostic_ops_.emplace(it->first); + } } VLOG(3) << "The number of layout agnostic OPs: " @@ -91,6 +101,48 @@ LayoutAutoTune::LayoutAutoTune() { << lightly_layout_sensitive_ops_.size(); } +template +paddle::imperative::NameVarMap DealHeavilyLayoutSensitive( + const std::string& op_type, + const paddle::imperative::NameVarMap& ins, + const paddle::imperative::NameVarMap& outs, + paddle::framework::AttributeMap* attrs, + const std::shared_ptr& tracer) { + std::shared_ptr> transposer = nullptr; + transposer = + std::make_shared>(op_type); + transposer->SetArguments( + {"Input", "X"}, {"Output", "Out", "Y"}, {"data_format", "data_layout"}); + + return transposer->Apply(ins, outs, attrs, tracer); +} + +template +paddle::imperative::NameVarMap DealLightlyLayoutSensitive( + const std::string& op_type, + const paddle::imperative::NameVarMap& ins, + const paddle::imperative::NameVarMap& outs, + paddle::framework::AttributeMap* attrs, + const std::shared_ptr& tracer) { + std::shared_ptr> transposer = nullptr; + if (op_type == "transpose2") { + transposer = std::make_shared>(op_type); + } else if (op_type == "flatten_contiguous_range") { + transposer = std::make_shared>(op_type); + } else if (op_type == "arg_max") { + transposer = std::make_shared>(op_type); + } else if (op_type.find("elementwise_") != std::string::npos) { + transposer = std::make_shared>(op_type); + } else { + VLOG(4) << op_type + << "'s LayoutTransformer is unimplemented. Use default " + "LightlyLayoutTransformer instead."; + transposer = + std::make_shared>(op_type); + } + return transposer->Apply(ins, outs, attrs, tracer); +} + template paddle::imperative::NameVarMap AutoTuneLayout( const std::string& op_type, @@ -101,7 +153,6 @@ paddle::imperative::NameVarMap AutoTuneLayout( if (!LayoutAutoTune::Instance().UseLayoutAutoTune()) { return ins; } - // When layout autotuning is enabled, the tuner will check the desired layout. // (1) If the desired layout is undefined, and there is no convolutional // layers, layout optimization is unnecessary. Otherwise, the desired layout @@ -115,51 +166,49 @@ paddle::imperative::NameVarMap AutoTuneLayout( if (op_type != "conv2d") { return ins; } else { - if (BOOST_GET_CONST(std::string, (*attrs)["data_format"]) == "NCHW") { + auto conv_in_type = framework::proto::VarType::FP32; + auto& in_vars = ins.at("Input")[0]; + if (GetDataType(in_vars) == framework::proto::VarType::FP16) { + conv_in_type = framework::proto::VarType::FP16; + } + bool is_tune_fp32 = + (BOOST_GET_CONST(std::string, (*attrs)["data_format"]) == "NHWC") && + (conv_in_type == framework::proto::VarType::FP32); + bool is_tune_fp16 = + (BOOST_GET_CONST(std::string, (*attrs)["data_format"]) == "NCHW") && + (conv_in_type == framework::proto::VarType::FP16); + if (is_tune_fp32) { + LayoutAutoTune::Instance().SetDesiredLayout(DataLayout::NCHW); + } else if (is_tune_fp16) { LayoutAutoTune::Instance().SetDesiredLayout(DataLayout::NHWC); - VLOG(3) << "Tune the layout from " - << BOOST_GET_CONST(std::string, (*attrs)["data_format"]) - << " to " - << paddle::framework::DataLayoutToString( - LayoutAutoTune::Instance().GetDesiredLayout()); } else { LayoutAutoTune::Instance().DisableLayoutAutoTune(); return ins; } + VLOG(3) << "Tune the layout from " + << BOOST_GET_CONST(std::string, (*attrs)["data_format"]) << " to " + << paddle::framework::DataLayoutToString( + LayoutAutoTune::Instance().GetDesiredLayout()); } } - std::shared_ptr> transposer = nullptr; - if (op_type == "conv2d") { - transposer = - std::make_shared>(op_type); - transposer->SetArguments({"Input"}, {"Output"}, {"data_format"}); - } else if (op_type == "batch_norm") { - transposer = - std::make_shared>(op_type); - transposer->SetArguments({"X"}, {"Y"}, {"data_layout"}); - } else if (op_type == "pool2d") { - transposer = - std::make_shared>(op_type); - transposer->SetArguments({"X"}, {"Out"}, {"data_format"}); - } else if (op_type == "transpose2") { - transposer = std::make_shared>(op_type); - } else if (op_type == "flatten_contiguous_range") { - transposer = std::make_shared>(op_type); - } else if (op_type.find("elementwise_") != std::string::npos) { - transposer = std::make_shared>(op_type); - } else if (LayoutAutoTune::Instance().IsLayoutAgnostic(op_type)) { - transposer = std::make_shared>(op_type); + if (LayoutAutoTune::Instance().IsHeavilyLayoutSensitive(op_type)) { + return DealHeavilyLayoutSensitive( + op_type, ins, outs, attrs, tracer); } else if (LayoutAutoTune::Instance().IsLightlyLayoutSensitive(op_type)) { - transposer = - std::make_shared>(op_type); + return DealLightlyLayoutSensitive( + op_type, ins, outs, attrs, tracer); } else { + std::shared_ptr> transposer = nullptr; + if (LayoutAutoTune::Instance().IsLayoutAgnostic(op_type)) { + transposer = std::make_shared>(op_type); + } PADDLE_ENFORCE_NOT_NULL( - transposer, phi::errors::Unimplemented( - "%s 's LayoutTransformer is unimplemented.", op_type)); + transposer, + phi::errors::Unimplemented("%s 's LayoutTransformer is unimplemented.", + op_type)); + return transposer->Apply(ins, outs, attrs, tracer); } - - return transposer->Apply(ins, outs, attrs, tracer); } template paddle::imperative::NameVarMap AutoTuneLayout( const std::string& op_type, diff --git a/paddle/fluid/imperative/layout_autotune.h b/paddle/fluid/imperative/layout_autotune.h index 2da368910e..2f3d9c38e9 100644 --- a/paddle/fluid/imperative/layout_autotune.h +++ b/paddle/fluid/imperative/layout_autotune.h @@ -41,6 +41,10 @@ class LayoutAutoTune { void DisableLayoutAutoTune() { use_layout_autotune_ = false; } + bool IsHeavilyLayoutSensitive(const std::string& op_type) const { + return heavily_layout_sensitive_ops_.count(op_type) != 0; + } + bool IsLightlyLayoutSensitive(const std::string& op_type) const { return lightly_layout_sensitive_ops_.count(op_type) != 0; } @@ -60,9 +64,10 @@ class LayoutAutoTune { std::unordered_set layout_agnostic_ops_{}; - std::unordered_set heavily_layout_sensitive_ops_{}; + std::unordered_set heavily_layout_sensitive_ops_{"batch_norm"}; - std::unordered_set lightly_layout_sensitive_ops_{}; + std::unordered_set lightly_layout_sensitive_ops_{ + "instance_norm", "softmax", "transpose", "transpose2", "reshape2"}; DataLayout layout_{DataLayout::UNDEFINED}; }; diff --git a/paddle/fluid/imperative/layout_transformer.h b/paddle/fluid/imperative/layout_transformer.h index 73e27d4b79..50d3e2b6ac 100644 --- a/paddle/fluid/imperative/layout_transformer.h +++ b/paddle/fluid/imperative/layout_transformer.h @@ -13,18 +13,19 @@ // limitations under the License. #pragma once +#include "paddle/fluid/framework/framework.pb.h" #include "paddle/fluid/imperative/layout_autotune.h" #include "paddle/fluid/imperative/tracer.h" #include "paddle/fluid/imperative/var_helper.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/errors.h" - namespace paddle { namespace imperative { template std::shared_ptr TraceTransposeOp( - const std::shared_ptr& var, const DataLayout layout, + const std::shared_ptr& var, + const DataLayout layout, const std::shared_ptr& tracer) { std::vector axis; if (layout == DataLayout::NHWC) { @@ -76,8 +77,8 @@ class LayoutTransformer { for (auto& var : pair.second) { // Once the any input is desired layout, we set in_layout is desired // layout. - if (paddle::imperative::GetDataLayout(var) == - LayoutAutoTune::Instance().GetDesiredLayout()) { + if (var != nullptr && (paddle::imperative::GetDataLayout(var) == + LayoutAutoTune::Instance().GetDesiredLayout())) { in_layout = LayoutAutoTune::Instance().GetDesiredLayout(); break; } @@ -103,17 +104,27 @@ class LayoutTransformer { // will be considered. Otherwise, it only set layout for the specified output. void SetVarsLayout(const paddle::imperative::NameVarMap& outs, DataLayout layout) const { - if (outs_.empty()) { - for (auto& pair : outs) { - for (auto& var : pair.second) { - paddle::imperative::SetDataLayout(var, layout); + bool not_in_out = true; + if (!outs_.empty()) { + for (auto& name : outs_) { + if (outs.find(name) != outs.end()) { + auto out_vars = outs.at(name); + for (auto& var : out_vars) { + if (var != nullptr) { + paddle::imperative::SetDataLayout(var, layout); + } + } + not_in_out = false; } } - } else { - for (auto& name : outs_) { - auto out_vars = outs.at(name); - for (auto& var : out_vars) { - paddle::imperative::SetDataLayout(var, layout); + } + + if (not_in_out) { + for (auto& pair : outs) { + for (auto& var : pair.second) { + if (var != nullptr) { + paddle::imperative::SetDataLayout(var, layout); + } } } } @@ -132,46 +143,6 @@ class LayoutTransformer { std::vector attrs_{}; }; -template -class ElementwiseOpTransformer : public LayoutTransformer { - public: - explicit ElementwiseOpTransformer(const std::string& type) - : LayoutTransformer(type) {} - - paddle::imperative::NameVarMap Apply( - const paddle::imperative::NameVarMap& ins, - const paddle::imperative::NameVarMap& outs, - paddle::framework::AttributeMap* attrs, - const std::shared_ptr& tracer) { - // [Why we need the this?] - // The Elementwise Ops has a axis attr, it is to support broadcast. - // When bias_attr of Conv is not false, the elementwise_add will be - // appended, and the axis will be set to the channel dimension. - - // If the axis is set to the channel dimension, the attr transformation - // is necessary. Otherwise, it will fall back to the - // LayoutTransformer::Apply. - auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout(); - if (attrs->find("axis") != attrs->end() && - BOOST_GET_CONST(int, (*attrs)["axis"]) != -1) { - VLOG(3) << "Optimze layout agnostic op " << this->Type(); - if (desired_layout == DataLayout::NHWC) { - (*attrs)["axis"] = 3; - } else if (desired_layout == DataLayout::NCHW) { - (*attrs)["axis"] = 1; - } else { - PADDLE_ENFORCE_EQ( - desired_layout, DataLayout::UNDEFINED, - phi::errors::PreconditionNotMet("DataLayout is unsupport.")); - } - this->SetVarsLayout(outs, desired_layout); - return ins; - } else { - return LayoutTransformer::Apply(ins, outs, attrs, tracer); - } - } -}; - /* * Both functionality and performance are affected by data layout. * Such as operators with data_format attribute. @@ -213,11 +184,13 @@ class HeavilyLayoutSensitiveOpTransformer : public LayoutTransformer { // Step 2: Transpose the specified input for Op and set the transposed var's // layout. for (auto& name : this->Inputs()) { - auto& in_vars = new_ins[name]; - for (auto& var : in_vars) { - auto var_layout = paddle::imperative::GetDataLayout(var); - if (var_layout != desired_layout) { - var = TraceTransposeOp(var, DataLayout::NHWC, tracer); + if (new_ins.find(name) != new_ins.end()) { + auto& in_vars = new_ins[name]; + for (auto& var : in_vars) { + if (var != nullptr && + paddle::imperative::GetDataLayout(var) != desired_layout) { + var = TraceTransposeOp(var, desired_layout, tracer); + } } } } @@ -252,10 +225,20 @@ class LightlyLayoutSensitiveOpTransformer : public LayoutTransformer { // operator output data layout. Currently only a few operators are // supported, and transposers need to be carefully designed to ensure that // they do not cause exceptions. + auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout(); for (auto& pair : new_ins) { for (auto& var : pair.second) { - auto var_layout = paddle::imperative::GetDataLayout(var); - if (var_layout == LayoutAutoTune::Instance().GetDesiredLayout()) { + if (var != nullptr) { + VLOG(3) << "Tune the layout from " + << paddle::framework::DataLayoutToString( + paddle::imperative::GetDataLayout(var)) + << " to " + << paddle::framework::DataLayoutToString( + LayoutAutoTune::Instance().GetDesiredLayout()); + } + if (var != nullptr && + paddle::imperative::GetDataLayout(var) == desired_layout && + desired_layout == DataLayout::NHWC) { // Set layout to UNDEFINED so that TransposeOpTransformer do // NHWC->NCHW transformation. var = TraceTransposeOp(var, DataLayout::UNDEFINED, tracer); @@ -266,6 +249,50 @@ class LightlyLayoutSensitiveOpTransformer : public LayoutTransformer { } }; +template +class ElementwiseOpTransformer + : public LightlyLayoutSensitiveOpTransformer { + public: + explicit ElementwiseOpTransformer(const std::string& type) + : LightlyLayoutSensitiveOpTransformer(type) {} + + paddle::imperative::NameVarMap Apply( + const paddle::imperative::NameVarMap& ins, + const paddle::imperative::NameVarMap& outs, + paddle::framework::AttributeMap* attrs, + const std::shared_ptr& tracer) { + // [Why we need the this?] + // The Elementwise Ops has a axis attr, it is to support broadcast. + // When bias_attr of Conv is not false, the elementwise_add will be + // appended, and the axis will be set to the channel dimension. + // If the axis is set to the channel dimension, the attr transformation + // is necessary. Otherwise, it will fall back to the + // LayoutTransformer::Apply. + auto& in1_vars = ins.at("X")[0]; + auto& in2_vars = ins.at("Y")[0]; + auto in_layout = paddle::imperative::GetDataLayout(in1_vars); + // for conv's bias + if (attrs->find("axis") != attrs->end() && + BOOST_GET_CONST(int, (*attrs)["axis"]) != -1) { + if (in_layout == DataLayout::NHWC) { + (*attrs)["axis"] = 3; + } else if (in_layout == DataLayout::NCHW) { + (*attrs)["axis"] = 1; + } + this->SetVarsLayout(outs, in_layout); + return ins; + } else { + auto in2_layout = paddle::imperative::GetDataLayout(in2_vars); + if (in_layout == in2_layout) { + this->SetVarsLayout(outs, in_layout); + return ins; + } + return LightlyLayoutSensitiveOpTransformer::Apply( + ins, outs, attrs, tracer); + } + } +}; + template class TransposeOpTransformer : public LightlyLayoutSensitiveOpTransformer { @@ -286,13 +313,14 @@ class TransposeOpTransformer // transpose Op with the current transpose Op by transforming 'axis' attr. auto& in_var = ins.at("X")[0]; auto var_layout = paddle::imperative::GetDataLayout(in_var); - if (var_layout == LayoutAutoTune::Instance().GetDesiredLayout()) { + auto desired_layout = LayoutAutoTune::Instance().GetDesiredLayout(); + if (var_layout == desired_layout && desired_layout == DataLayout::NHWC) { auto axis = BOOST_GET_CONST(std::vector, (*attrs)["axis"]); // NHWC->NCHW, permutaion will be set as follows. std::vector perm = {0, 3, 1, 2}; // fuse the transpose Ops by transforming axis. - std::vector fusion_axis = {perm[axis[0]], perm[axis[1]], - perm[axis[2]], perm[axis[3]]}; + std::vector fusion_axis = { + perm[axis[0]], perm[axis[1]], perm[axis[2]], perm[axis[3]]}; (*attrs)["axis"] = fusion_axis; } return ins; @@ -322,9 +350,53 @@ class FlattenOpTransformer start_axis == 1 && stop_axis == 3) { return ins; } else { - return LightlyLayoutSensitiveOpTransformer::Apply(ins, outs, - attrs, tracer); + return LightlyLayoutSensitiveOpTransformer::Apply( + ins, outs, attrs, tracer); + } + } +}; + +template +class ArgmaxOpTransformer + : public LightlyLayoutSensitiveOpTransformer { + public: + explicit ArgmaxOpTransformer(const std::string& type) + : LightlyLayoutSensitiveOpTransformer(type) {} + + paddle::imperative::NameVarMap Apply( + const paddle::imperative::NameVarMap& ins, + const paddle::imperative::NameVarMap& outs, + paddle::framework::AttributeMap* attrs, + const std::shared_ptr& tracer) { + VLOG(3) << "Optimze lightly layout sensitive op " << this->Type(); + auto& in_var = ins.at("X")[0]; + auto var_layout = paddle::imperative::GetDataLayout(in_var); + bool keep_dims = BOOST_GET_CONST(bool, (*attrs)["keepdims"]); + if (keep_dims) { + if (var_layout != DataLayout::UNDEFINED) { + std::vector perm_nhwc = {0, 2, 3, 1}; + std::vector perm_nchw = {0, 3, 1, 2}; + auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw; + switch (AttrTypeID((*attrs)["axis"])) { + case paddle::framework::proto::AttrType::INT: { + auto axis = BOOST_GET_CONST(int, (*attrs)["axis"]); + (*attrs)["axis"] = static_cast(perm[axis]); + } + case paddle::framework::proto::AttrType::LONG: { + auto axis = BOOST_GET_CONST(int64_t, (*attrs)["axis"]); + (*attrs)["axis"] = static_cast(perm[axis]); + } + default: + VLOG(4) << "The data_type of axis is Error, axis must be int or " + "int64, bug got " + << (AttrTypeID((*attrs)["axis"])); + } + } + this->SetVarsLayout(outs, var_layout); + return ins; } + return LightlyLayoutSensitiveOpTransformer::Apply( + ins, outs, attrs, tracer); } }; diff --git a/paddle/fluid/imperative/tracer.cc b/paddle/fluid/imperative/tracer.cc index 2295ea4bf6..b68f5b1d1d 100644 --- a/paddle/fluid/imperative/tracer.cc +++ b/paddle/fluid/imperative/tracer.cc @@ -225,11 +225,9 @@ void Tracer::TraceOpImpl(const std::string& type, std::unique_ptr> ins_amp = nullptr; if (amp_level_ == AmpLevel::O1) { if (amp_dtype_ == phi::DataType::FLOAT16) { - const auto& tracer = imperative::GetCurrentTracer(); VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type; ins_amp = std::make_unique>( - AutoCastInputs(type, imperative::AutoTuneLayout( - type, ins, outs, &attrs, tracer))); + AutoCastInputs(type, ins)); } else if (amp_dtype_ == phi::DataType::BFLOAT16) { VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type; ins_amp = std::make_unique>( @@ -237,18 +235,24 @@ void Tracer::TraceOpImpl(const std::string& type, } } else if (amp_level_ == AmpLevel::O2) { if (amp_dtype_ == phi::DataType::FLOAT16) { - const auto& tracer = imperative::GetCurrentTracer(); VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type; - ins_amp = - std::make_unique>(CastPureFp16Inputs( - type, imperative::AutoTuneLayout(type, ins, outs, &attrs, - tracer))); + ins_amp = std::make_unique>( + CastPureFp16Inputs(type, ins)); } else if (amp_dtype_ == phi::DataType::BFLOAT16) { VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type; ins_amp = std::make_unique>( CastPureBf16Inputs(type, ins)); } } + + if (platform::is_gpu_place(place)) { + const auto& new_tmp = ins_amp == nullptr ? ins : *ins_amp; + const auto& tracer = imperative::GetCurrentTracer(); + ins_amp = std::make_unique>( + imperative::AutoTuneLayout(type, new_tmp, outs, &attrs, + tracer)); + } + const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp; try { diff --git a/paddle/fluid/imperative/var_helper.cc b/paddle/fluid/imperative/var_helper.cc index f84606ba9a..ff22268105 100644 --- a/paddle/fluid/imperative/var_helper.cc +++ b/paddle/fluid/imperative/var_helper.cc @@ -307,7 +307,8 @@ void SetCachedValue( // is equal to self: " << key == key << " and res name is:" << res->Name(). } template void SetCachedValue( - std::shared_ptr var, const paddle::framework::OpKernelType &key, + std::shared_ptr var, + const paddle::framework::OpKernelType &key, std::shared_ptr res); template void SetCachedValue( std::shared_ptr var, diff --git a/python/paddle/fluid/tests/unittests/test_layout_autotune.py b/python/paddle/fluid/tests/unittests/test_layout_autotune.py index f17bffe3b8..6e25e3719d 100644 --- a/python/paddle/fluid/tests/unittests/test_layout_autotune.py +++ b/python/paddle/fluid/tests/unittests/test_layout_autotune.py @@ -135,6 +135,32 @@ class LayoutAutoTune(unittest.TestCase): self.assertEqual(conv_out.shape, [1, 14, 12, 8]) self.assertEqual(out.shape, [1, 112, 12]) + def test_argmax_op_transposer_keep_dims(self): + if not self.use_autoune(): + return + conv = paddle.nn.Conv2D(3, 8, (3, 3)) + data = paddle.rand([1, 3, 16, 14]) + with paddle.amp.auto_cast(level="O2"): + conv_out = conv(data) + # conv_out.shape = [1, 14, 12, 8] with NHWC + out = paddle.argmax(conv_out, axis=1, keepdim=True) + + self.assertEqual(conv_out.shape, [1, 14, 12, 8]) + self.assertEqual(out.shape, [1, 14, 1, 8]) + + def test_argmax_op_transposer(self): + if not self.use_autoune(): + return + conv = paddle.nn.Conv2D(3, 8, (3, 3)) + data = paddle.rand([1, 3, 16, 14]) + with paddle.amp.auto_cast(level="O2"): + conv_out = conv(data) + # conv_out.shape = [1, 14, 12, 8] with NHWC + out = paddle.argmax(conv_out) + + self.assertEqual(conv_out.shape, [1, 14, 12, 8]) + self.assertEqual(out.shape, [1]) + class TestAutoTuneAPI(unittest.TestCase): -- GitLab