diff --git a/paddle/fluid/framework/details/op_handle_base.h b/paddle/fluid/framework/details/op_handle_base.h index fe1735d05dde5f09d5c72c68e5002d16f0083eb5..8f94206a87dbae8a81727ca48718886bbabbe25c 100644 --- a/paddle/fluid/framework/details/op_handle_base.h +++ b/paddle/fluid/framework/details/op_handle_base.h @@ -70,6 +70,14 @@ class OpHandleBase { const std::vector &Inputs() const { return inputs_; } + size_t NoDupInputSize() const { + std::unordered_set res; + for (auto *var : inputs_) { + res.emplace(var); + } + return res.size(); + } + const std::vector &Outputs() const { return outputs_; } protected: diff --git a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc index ef263d82c5ec93f0673eb0ac70e4fb02904bff13..815f739371e77d953a28be99b38ec1b8ff26506c 100644 --- a/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/threaded_ssa_graph_executor.cc @@ -174,7 +174,7 @@ void ThreadedSSAGraphExecutor::InsertFetchOps( void ThreadedSSAGraphExecutor::InsertPendingOp( std::unordered_map *pending_ops, OpHandleBase *op_instance) const { - pending_ops->insert({op_instance, op_instance->Inputs().size()}); + pending_ops->insert({op_instance, op_instance->NoDupInputSize()}); } void ThreadedSSAGraphExecutor::InsertPendingVar(