From 9a8a7a1ddc286b6093ce955573c571d83891cd82 Mon Sep 17 00:00:00 2001 From: Zeng Jinle <32832641+sneaxiy@users.noreply.github.com> Date: Sun, 28 Jul 2019 22:09:15 +0800 Subject: [PATCH] fix affine_channel no_need buffer bug, test=develop (#18844) --- paddle/fluid/operators/affine_channel_op.cc | 14 +++++++------- python/paddle/fluid/tests/unittests/CMakeLists.txt | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc index da063541438..f9a79d93889 100644 --- a/paddle/fluid/operators/affine_channel_op.cc +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -295,10 +295,10 @@ class AffineChannelNoNeedBufferVarsInference using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference; private: - inline bool HasInput(const std::string& name) const { - auto& inputs = Inputs(); - auto iter = inputs.find(name); - if (iter == inputs.end() || iter->second.empty()) { + inline bool HasOutput(const std::string& name) const { + auto& outputs = Outputs(); + auto iter = outputs.find(name); + if (iter == outputs.end() || iter->second.empty()) { return false; } else { return iter->second[0] != framework::kEmptyVarName; @@ -306,9 +306,9 @@ class AffineChannelNoNeedBufferVarsInference } public: - std::unordered_set operator()() const { - if (!HasInput(framework::GradVarName("Scale")) && - !HasInput(framework::GradVarName("Bias"))) { + std::unordered_set operator()() const override { + if (!HasOutput(framework::GradVarName("Scale")) && + !HasOutput(framework::GradVarName("Bias"))) { return {"X"}; } else { return {}; diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index b84d4af8295..6fb38e95955 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -163,7 +163,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) # Some ops need to check results when gc is enabled # Currently, only ops that register NoNeedBufferVarsInference need to do this test -set(TEST_OPS_WITH_GC +set(TEST_OPS_WITH_GC + test_affine_channel_op test_concat_op test_elementwise_add_op test_elementwise_sub_op -- GitLab