未验证 提交 9a8a7a1d 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix affine_channel no_need buffer bug, test=develop (#18844)

上级 829ef262
......@@ -295,10 +295,10 @@ class AffineChannelNoNeedBufferVarsInference
using framework::NoNeedBufferVarsInference::NoNeedBufferVarsInference;
private:
inline bool HasInput(const std::string& name) const {
auto& inputs = Inputs();
auto iter = inputs.find(name);
if (iter == inputs.end() || iter->second.empty()) {
inline bool HasOutput(const std::string& name) const {
auto& outputs = Outputs();
auto iter = outputs.find(name);
if (iter == outputs.end() || iter->second.empty()) {
return false;
} else {
return iter->second[0] != framework::kEmptyVarName;
......@@ -306,9 +306,9 @@ class AffineChannelNoNeedBufferVarsInference
}
public:
std::unordered_set<std::string> operator()() const {
if (!HasInput(framework::GradVarName("Scale")) &&
!HasInput(framework::GradVarName("Bias"))) {
std::unordered_set<std::string> operator()() const override {
if (!HasOutput(framework::GradVarName("Scale")) &&
!HasOutput(framework::GradVarName("Bias"))) {
return {"X"};
} else {
return {};
......
......@@ -163,7 +163,8 @@ list(REMOVE_ITEM TEST_OPS test_basic_lstm_unit_op)
# Some ops need to check results when gc is enabled
# Currently, only ops that register NoNeedBufferVarsInference need to do this test
set(TEST_OPS_WITH_GC
set(TEST_OPS_WITH_GC
test_affine_channel_op
test_concat_op
test_elementwise_add_op
test_elementwise_sub_op
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册