未验证 提交 bb143052 编写于 作者: Z Zeng Jinle 提交者: GitHub

fix gc bug in conditional block (#16673)

test=develop
上级 229dc932
......@@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
......@@ -174,24 +177,41 @@ class ConditionalBlockGradOp : public ConditionalOp {
framework::Executor exec(dev_place);
auto *block = Attr<framework::BlockDesc *>("sub_block");
exec.Run(*block->Program(), &cur_scope, block->ID(), false);
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Input"),
Outputs(framework::GradVarName("Input")));
const auto &ins = Inputs("Input");
const auto &d_ins = Outputs(framework::GradVarName("Input"));
const auto &conds = Inputs("Cond");
const auto &d_conds = Outputs(framework::GradVarName("Cond"));
std::vector<std::string> ins_conds_grads;
ins_conds_grads.reserve(ins.size() + conds.size());
for (auto &in : ins) {
ins_conds_grads.emplace_back(framework::GradVarName(in));
}
for (auto &cond : conds) {
ins_conds_grads.emplace_back(framework::GradVarName(cond));
}
exec.Run(*block->Program(), &cur_scope, block->ID(), false, true,
ins_conds_grads);
AssignLocalGradientToGlobal(dev_place, cur_scope, ins_conds_grads.data(),
ins.size(), d_ins);
AssignLocalGradientToGlobal(dev_place, cur_scope, Inputs("Cond"),
Outputs(framework::GradVarName("Cond")));
AssignLocalGradientToGlobal(dev_place, cur_scope,
ins_conds_grads.data() + ins.size(),
conds.size(), d_conds);
}
}
private:
void AssignLocalGradientToGlobal(
const platform::Place &place, const framework::Scope &cur_scope,
const std::vector<std::string> &p_names,
const std::string *p_grad_names, size_t p_grad_names_num,
const std::vector<std::string> &pg_names) const {
for (size_t i = 0; i < p_names.size(); ++i) {
for (size_t i = 0; i < p_grad_names_num; ++i) {
auto out_grad_name = pg_names[i];
auto in_grad_name = framework::GradVarName(p_names[i]);
const auto &in_grad_name = p_grad_names[i];
auto *in_var = cur_scope.FindVar(in_grad_name);
if (in_var == nullptr) {
continue;
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.fluid as fluid
import unittest
fluid.core._set_eager_deletion_mode(0.0, 1.0, True)
from test_conditional_block import *
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册