未验证 提交 8358d614 编写于 作者: W wanghuancoder 提交者: GitHub

fix 3 bug of new_executor (#37142)

* fix 3 bug, test=develop

* refine, test=develop
上级 b628c316
......@@ -98,6 +98,9 @@ void InterpreterCore::Convert() {
for (auto& item : op_func_node.input_index) {
for (auto id : item.second) {
if (id == kEmptyVarIndex) {
continue;
}
input_var2op_info_.at(id).push_back(op_idx);
// var can be gc-ed
if (!info.IsBuilt()) {
......
......@@ -60,6 +60,10 @@ void InterpreterCoreGarbageCollector::Add(
void InterpreterCoreGarbageCollector::Add(paddle::framework::Variable* var,
paddle::platform::DeviceEvent& event,
const platform::DeviceContext* ctx) {
if (!var) {
return;
}
if (var->IsType<LoDTensor>()) {
Add(var->GetMutable<LoDTensor>()->MoveMemoryHolder(), event, ctx);
} else if (var->IsType<
......
......@@ -446,7 +446,13 @@ void build_op_func_list(const platform::Place& place,
VariableValueMap ins_map;
VariableIdMap ins_name2id;
bool enforce_exist = true;
if (op->Type() == "recurrent_grad") enforce_exist = false;
if (op->Type() == "recurrent_grad" || op->Type() == "rnn_memory_helper" ||
op->Type() == "rnn_memory_helper_grad" ||
op->Type() == "conditional_block" ||
op->Type() == "conditional_block_grad" || op->Type() == "while" ||
op->Type() == "while_grad") {
enforce_exist = false;
}
std::tie(ins_map, ins_name2id) =
build_variable_map(inputs_names, var_scope, enforce_exist);
......
......@@ -480,7 +480,7 @@ const std::vector<Variable*>& InterpretercoreInferShapeContext::OutputVars(
VariableScope::VariableScope(Scope* scope) {
// for @EMPTY@ variable
var_list_.push_back(nullptr);
name2id_[kEmptyVarName] = 0;
name2id_[kEmptyVarName] = kEmptyVarIndex;
vec_meta_info_.emplace_back(0, nullptr);
scope_ = scope;
PADDLE_ENFORCE_NE(
......
......@@ -43,6 +43,8 @@ using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using OpKernelMap =
std::unordered_map<OpKernelType, OpKernelComputeFunc, OpKernelType::Hash>;
constexpr int kEmptyVarIndex = 0;
class InterpretercoreInferShapeContext : public InferShapeContext {
public:
InterpretercoreInferShapeContext(const OperatorBase& op,
......
......@@ -598,13 +598,13 @@ class _ExecutorCache(object):
assert isinstance(
program, Program), "Required type(Program), but received {}".format(
type(program).__name__)
if program not in self._cached_executors:
if str(program) not in self._cached_executors:
new_program = program.clone()
_prune_feed_ops(new_program)
new_exe = _StandaloneExecutor(self._place, new_program, scope)
self._cached_executors[program] = new_exe
self._cached_executors[str(program)] = new_exe
return self._cached_executors[program]
return self._cached_executors[str(program)]
class Executor(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册