提交 35a25784 编写于 作者: S sneaxiy

fix bug

test=develop
上级 64ad051b
......@@ -31,8 +31,6 @@ ComputationOpHandle::ComputationOpHandle(ir::Node *node, Scope *scope,
void ComputationOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
VLOG(10) << "Run Op" << Name();
auto run_func = [this]() {
op_->Run(*scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
......
......@@ -29,7 +29,7 @@ namespace paddle {
namespace framework {
namespace details {
struct OpConnectionDetector {
class OpConnectionDetector {
public:
enum RelationShip { kSame = 0, kNoDeps = 1, kBefore = 2, kAfter = 3 };
......@@ -37,8 +37,8 @@ struct OpConnectionDetector {
: graph_(all_ops) {}
template <typename OpSet>
std::unordered_set<typename OpSet::key_type> MaxNoDepOps(
const OpSet &op_set) {
OpSet MaxNoDepOps(const OpSet &op_set) {
if (op_set.size() <= 1) return op_set;
using KeyType = typename OpSet::key_type;
static_assert(
std::is_base_of<OpHandleBase,
......@@ -46,7 +46,7 @@ struct OpConnectionDetector {
"Key type of OpSet must be or derived of OpHandleBase");
std::vector<OpHandleBase *> ops(op_set.begin(), op_set.end());
std::unordered_set<KeyType> ret;
OpSet ret;
auto rels = GetRelations(ops);
auto not_before = [](RelationShip r) { return r != kBefore; };
for (size_t i = 0; i < rels.size(); ++i) {
......@@ -79,7 +79,7 @@ struct OpConnectionDetector {
auto it = op_to_idx.find(op);
if (it != op_to_idx.end()) {
size_t j = it->second;
if (ret[i][j] != kSame) {
if (i != j && ret[i][j] == kSame) {
ret[i][j] = kBefore;
ret[j][i] = kAfter;
found_num += 2;
......@@ -208,6 +208,10 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
VLOG(10) << "Shrink last living op number of " << var_name << " from "
<< original_size << " to " << last_live_op.size();
}
PADDLE_ENFORCE(!last_live_op.empty(),
"Last living ops of %s cannot be empty", var_name);
ref_cnts[i].emplace(var_name, last_live_op.size());
last_live_ops_of_vars[i].emplace(var_name, std::move(last_live_op));
}
......
......@@ -49,8 +49,6 @@ void StreamCallbackManager::AddCallback(std::function<void()> callback) const {
#endif
}
StreamCallbackManager::~StreamCallbackManager() { Wait(); }
void StreamCallbackManager::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
{
......
......@@ -33,7 +33,7 @@ class StreamCallbackManager {
public:
explicit StreamCallbackManager(const cudaStream_t stream);
~StreamCallbackManager();
~StreamCallbackManager() = default;
void AddCallback(std::function<void()> callback) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册