提交 35a25784 编写于 作者: S sneaxiy

fix bug

test=develop
上级 64ad051b
...@@ -31,8 +31,6 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope, ...@@ -31,8 +31,6 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() { void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_); WaitInputVarGenerated(place_);
VLOG(10) << "Run Op" << Name();
auto run_func = [this]() { auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_); op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
}; };
......
...@@ -29,7 +29,7 @@ namespace paddle { ...@@ -29,7 +29,7 @@ namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
struct OpConnectionDetector { class OpConnectionDetector {
public: public:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 }; enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
...@@ -37,8 +37,8 @@ struct OpConnectionDetector { ...@@ -37,8 +37,8 @@ struct OpConnectionDetector {
: graph_(all_ops) {} : graph_(all_ops) {}
template <typename OpSet> template <typename OpSet>
std::unordered_set<typename OpSet::key_type> MaxNoDepOps( OpSet MaxNoDepOps(const OpSet &op_set) {
const OpSet &op_set) { if (op_set.size() <= 1) return op_set;
using KeyType = typename OpSet::key_type; using KeyType = typename OpSet::key_type;
static_assert( static_assert(
std::is_base_of<OpHandleBase, std::is_base_of<OpHandleBase,
...@@ -46,7 +46,7 @@ struct OpConnectionDetector { ...@@ -46,7 +46,7 @@ struct OpConnectionDetector {
"Key type of OpSet must be or derived of OpHandleBase"); "Key type of OpSet must be or derived of OpHandleBase");
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end()); std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::unordered_set<KeyType> ret; OpSet ret;
auto rels = GetRelations(ops); auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; }; auto not_before = [](RelationShip r) { return r != kBefore; };
for (size_t i = 0; i < rels.size(); ++i) { for (size_t i = 0; i < rels.size(); ++i) {
...@@ -79,7 +79,7 @@ struct OpConnectionDetector { ...@@ -79,7 +79,7 @@ struct OpConnectionDetector {
auto it = op_to_idx.find(op); auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) { if (it != op_to_idx.end()) {
size_t j = it->second; size_t j = it->second;
if (ret[i][j] != kSame) { if (i != j && ret[i][j] == kSame) {
ret[i][j] = kBefore; ret[i][j] = kBefore;
ret[j][i] = kAfter; ret[j][i] = kAfter;
found_num += 2; found_num += 2;
...@@ -208,6 +208,10 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl( ...@@ -208,6 +208,10 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
VLOG(10) << "Shrink last living op number of " << var_name << " from " VLOG(10) << "Shrink last living op number of " << var_name << " from "
<< original_size << " to " << last_live_op.size(); << original_size << " to " << last_live_op.size();
} }
PADDLE_ENFORCE(!last_live_op.empty(),
"Last living ops of %s cannot be empty", var_name);
ref_cnts[i].emplace(var_name, last_live_op.size()); ref_cnts[i].emplace(var_name, last_live_op.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op)); last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
} }
......
...@@ -49,8 +49,6 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const { ...@@ -49,8 +49,6 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
#endif #endif
} }
StreamCallbackManager::~StreamCallbackManager() { Wait(); }
void StreamCallbackManager::Wait() const { void StreamCallbackManager::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
{ {
......
...@@ -33,7 +33,7 @@ class StreamCallbackManager { ...@@ -33,7 +33,7 @@ class StreamCallbackManager {
public: public:
explicit StreamCallbackManager(const cudaStream_t stream); explicit StreamCallbackManager(const cudaStream_t stream);
~StreamCallbackManager(); ~StreamCallbackManager() = default;
void AddCallback(std::function<void()> callback) const; void AddCallback(std::function<void()> callback) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册