未验证 提交 9a8a7a1d 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix affine_channel no_need buffer bug, test=develop (#18844)

上级 829ef262
...@@ -295,10 +295,10 @@ class AffineChannelNoNeedBufferVarsInference ...@@ -295,10 +295,10 @@ class AffineChannelNoNeedBufferVarsInference
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference; using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
private: private:
inline bool HasInput(const std::string& name) const { inline bool HasOutput(const std::string& name) const {
auto& inputs = Inputs(); auto& outputs = Outputs();
auto iter = inputs.find(name); auto iter = outputs.find(name);
if (iter == inputs.end() || iter->second.empty()) { if (iter == outputs.end() || iter->second.empty()) {
return false; return false;
} else { } else {
return iter->second[0] != framework::kEmptyVarName; return iter->second[0] != framework::kEmptyVarName;
...@@ -306,9 +306,9 @@ class AffineChannelNoNeedBufferVarsInference ...@@ -306,9 +306,9 @@ class AffineChannelNoNeedBufferVarsInference
} }
public: public:
std::unordered_set<std::string> operator()() const { std::unordered_set<std::string> operator()() const override {
if (!HasInput(framework::GradVarName("Scale")) && if (!HasOutput(framework::GradVarName("Scale")) &&
!HasInput(framework::GradVarName("Bias"))) { !HasOutput(framework::GradVarName("Bias"))) {
return {"X"}; return {"X"};
} else { } else {
return {}; return {};
......
...@@ -163,7 +163,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op) ...@@ -163,7 +163,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
# Some ops need to check results when gc is enabled # Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test # Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC set(TEST_OPS_WITH_GC
test_affine_channel_op
test_concat_op test_concat_op
test_elementwise_add_op test_elementwise_add_op
test_elementwise_sub_op test_elementwise_sub_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册