Commit 79082c94 authored by Yancey1989

Fix PyReader failure

Parent 2dda19f7
@@ -27,34 +27,31 @@ namespace framework {
 namespace details {
 ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
     ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
-    std::vector<std::vector<VariableInfo>> var_infos_list,
-    std::vector<platform::Place> places,
+    std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
     std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
     : strategy_(std::move(strategy)),
       underlying_executor_(std::move(underlying_executor)),
       local_scopes_(std::move(local_scopes)),
-      var_infos_list_(std::move(var_infos_list)),
+      var_infos_(std::move(var_infos)),
       places_(std::move(places)) {}

 FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
     const std::vector<std::string> &fetch_tensors) {
   if (drop_scope_counter_ == 0) {
     // Create local scopes.
-    for (size_t i = 0; i < local_scopes_.size(); ++i) {
-      auto &scope = local_scopes_[i];
+    for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
+      auto &scope = *it;
       Scope &local_scope = scope->NewScope();
       *scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
           &local_scope;
-      for (auto &var_infos : var_infos_list_) {
-        for (auto &info : var_infos) {
-          if (scope->FindVar(info.name_) != nullptr) {
-            continue;
-          }
-          if (info.persistable_) {  // Persistable
-            InitializeVariable(scope->Var(info.name_), info.type_);
-          } else {
-            InitializeVariable(local_scope.Var(info.name_), info.type_);
-          }
-        }
+      for (auto &info : var_infos_) {
+        if (scope->FindVar(info.name_) != nullptr) {
+          continue;
+        }
+        if (info.persistable_) {  // Persistable
+          InitializeVariable(scope->Var(info.name_), info.type_);
+        } else {
+          InitializeVariable(local_scope.Var(info.name_), info.type_);
+        }
       }
     }
   }
...
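For context, the hunk above flattens the per-graph variable lists into a single `var_infos_` vector and initializes each variable exactly once per scope. Below is a minimal runnable sketch of that initialization policy; `Scope` and `VariableInfo` here are simplified stand-ins for Paddle's real types, not the library API:

// Sketch of the policy after this change: persistable variables go into
// the long-lived parent scope, everything else into the per-iteration
// local scope, and names that already exist are skipped.
#include <iostream>
#include <map>
#include <string>
#include <vector>

struct VariableInfo {
  std::string name_;
  bool persistable_;
};

struct Scope {
  std::map<std::string, int> vars;  // name -> dummy payload
  bool Has(const std::string &name) const { return vars.count(name) > 0; }
  void Var(const std::string &name) { vars.emplace(name, 0); }
};

void PrepareScopes(Scope *parent, Scope *local,
                   const std::vector<VariableInfo> &var_infos) {
  for (const auto &info : var_infos) {
    if (parent->Has(info.name_)) continue;  // already initialized: skip
    if (info.persistable_) {
      parent->Var(info.name_);  // outlives one iteration (e.g. parameters)
    } else {
      local->Var(info.name_);  // dropped together with the local scope
    }
  }
}

int main() {
  Scope parent, local;
  PrepareScopes(&parent, &local, {{"fc_0.w_0", true}, {"fc_0.tmp_0", false}});
  std::cout << "parent: " << parent.vars.size()          // 1 (the parameter)
            << ", local: " << local.vars.size() << "\n";  // 1 (the temp)
}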
@@ -38,8 +38,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
  public:
   ScopeBufferedSSAGraphExecutor(
       ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
-      std::vector<std::vector<VariableInfo>> var_info_list,
-      std::vector<platform::Place> places,
+      std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
       std::unique_ptr<SSAGraphExecutor>&& underlying_executor);

   const ir::Graph& Graph() const override {
@@ -54,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
   ExecutionStrategy strategy_;
   std::unique_ptr<SSAGraphExecutor> underlying_executor_;
   std::vector<Scope*> local_scopes_;
-  std::vector<std::vector<VariableInfo>> var_infos_list_;
+  std::vector<VariableInfo> var_infos_;
   std::vector<platform::Place> places_;
 };
 }  // namespace details
...
@@ -216,7 +216,6 @@ void ThreadedSSAGraphExecutor::RunOp(
     if (LIKELY(!strategy_.dry_run_)) {
       op->Run(strategy_.use_cuda_);
     }
-    VLOG(10) << op << " " << op->Name() << " Done ";
     running_ops_--;
     ready_var_q->Extend(op->Outputs());
     VLOG(10) << op << " " << op->Name() << "Signal posted";
...
@@ -141,7 +141,6 @@ ParallelExecutor::ParallelExecutor(
   std::vector<std::unique_ptr<ir::Graph>> graphs;
 #if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
   if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
-    VLOG(1) << "kParallelGraph mode!!";
     for (size_t i = 0; i < member_->places_.size(); ++i) {
       std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
           main_program, {member_->places_[i]}, loss_var_name, params,
@@ -178,8 +177,8 @@ ParallelExecutor::ParallelExecutor(
       ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
       ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
       ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
-      graphs[0] = ref_cnt_pass->Apply(std::move(graphs[i]));
-      graphs[0]->SetNotOwned("garbage_collector", &gcs_);
+      graphs[i] = ref_cnt_pass->Apply(std::move(graphs[i]));
+      graphs[i]->SetNotOwned("garbage_collector", &gcs_);
     }
   }
 }
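The last two changed lines above are the substantive fix: in kParallelGraph mode the loop builds one graph per place, but the pass result was always written back to `graphs[0]`, leaving every other graph moved-from and without the reference-counting pass. A contrived but runnable illustration of that bug pattern (toy types, not Paddle's `ir::Graph` or pass API):

// Applying a per-element transform inside a loop but always storing the
// result at index 0 leaves elements 1..n-1 untransformed (here: moved-from
// and null). The fix is the same as in the diff: graphs[0] -> graphs[i].
#include <cassert>
#include <memory>
#include <string>
#include <vector>

std::unique_ptr<std::string> Apply(std::unique_ptr<std::string> s) {
  *s += " +pass";
  return s;
}

int main() {
  std::vector<std::unique_ptr<std::string>> graphs;
  graphs.push_back(std::make_unique<std::string>("g0"));
  graphs.push_back(std::make_unique<std::string>("g1"));

  for (size_t i = 0; i < graphs.size(); ++i) {
    graphs[i] = Apply(std::move(graphs[i]));  // fixed: was graphs[0] = ...
  }
  assert(*graphs[1] == "g1 +pass");  // with graphs[0] = ..., graphs[1] is null
}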
@@ -192,6 +191,18 @@ ParallelExecutor::ParallelExecutor(
   // Step 3. Create vars in each scope. Passes may also create new vars.
   // skip control vars and empty vars
+  std::vector<details::VariableInfo> var_infos;
+  for (auto &graph : graphs) {
+    for (auto &node : graph->Nodes()) {
+      if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
+        var_infos.emplace_back();
+        var_infos.back().name_ = node->Var()->Name();
+        var_infos.back().type_ = node->Var()->GetType();
+        var_infos.back().persistable_ = node->Var()->Persistable();
+      }
+    }
+  }
+  /**
   std::vector<std::vector<details::VariableInfo>> var_infos_list;
   for (size_t i = 0; i < graphs.size(); ++i) {
     std::vector<details::VariableInfo> var_infos;
@@ -203,8 +214,9 @@ ParallelExecutor::ParallelExecutor(
         var_infos.back().persistable_ = node->Var()->Persistable();
       }
     }
-    var_infos_list.emplace_back(std::move(var_infos));
+    var_infos_list.push_back(std::move(var_infos));
   }
+  **/

   // If the loss_var_name is given, the number of graph should be only one.
   if (loss_var_name.size()) {
@@ -236,7 +248,7 @@ ParallelExecutor::ParallelExecutor(
   }

   member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
-      exec_strategy, member_->local_scopes_, std::move(var_infos_list),
+      exec_strategy, member_->local_scopes_, std::move(var_infos),
       member_->places_, std::move(member_->executor_)));
 }
...
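The added block above collects variable metadata from every per-place graph into one flat vector (the old nested `var_infos_list` path is left commented out with `/** ... **/`). The same variable name can appear in several graphs, so the flattened list may hold duplicates; that is tolerated because the `FindVar` guard in the first hunk skips names already present in the scope. A compact toy model of the collection step, with a simplified node type standing in for `ir::Node`:

// Flattening variable info across several graphs; duplicate names are kept
// here and filtered later at scope-creation time.
#include <iostream>
#include <string>
#include <vector>

struct Node {
  bool is_var;
  bool is_ctrl_var;
  std::string name;  // empty models node->Var() == nullptr
  bool persistable;
};

struct VariableInfo {
  std::string name_;
  bool persistable_;
};

int main() {
  std::vector<std::vector<Node>> graphs = {
      {{true, false, "w", true}, {false, false, "", false}},
      {{true, false, "w", true}, {true, true, "ctrl", false}}};

  std::vector<VariableInfo> var_infos;
  for (auto &graph : graphs) {
    for (auto &node : graph) {
      // Same guard as the diff: real, non-control variable nodes only.
      if (node.is_var && !node.is_ctrl_var && !node.name.empty()) {
        var_infos.push_back({node.name, node.persistable});
      }
    }
  }
  std::cout << var_infos.size() << "\n";  // 2: "w" appears once per graph
}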
@@ -58,9 +58,7 @@ void BufferedReader::ReadAsync(size_t i) {
     TensorVec &gpu = gpu_buffer_[i];
     gpu.resize(cpu.size());
     for (size_t i = 0; i < cpu.size(); ++i) {
-      VLOG(1) << "launch tensor copy from cpu to cpu, idx: " << i;
       framework::TensorCopySync(cpu[i], place_, &gpu[i]);
-      VLOG(1) << "done " << i;
       gpu[i].set_lod(cpu[i].lod());
     }
   }
...
@@ -28,10 +28,8 @@ class PyReader : public framework::FileReader {
   }

   void ReadNext(std::vector<framework::LoDTensor>* out) override {
-    VLOG(1) << "come in PyReader::ReadNext function, out: " << out;
     bool success;
     *out = queue_->Pop(&success);
-    VLOG(1) << "call PyReader::ReadNext " << success;
     if (!success) out->clear();
   }
...
@@ -115,12 +115,10 @@ class PreemptiveReaderContainer : public IReaderContainer {
   }

   void ReadNext(std::vector<framework::LoDTensor>* out) override {
-    VLOG(1) << "flag";
     if (!pending_.empty()) {
       auto future_it = complete_queue_.Pop();
       FutureItem item = future_it->get();
       if (item.exception_) {
-        VLOG(1) << "item has exception!!!";
         for (auto it = futures_.begin(); it != futures_.end(); ++it) {
           if (it != future_it) {
             it->wait();  // Wait all other threads complete.
...
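The reader hunks above only remove debugging VLOG lines; the contract of `PyReader::ReadNext` is unchanged: a blocking pop that reports whether the producer side is still open, with an empty output signalling "reader exhausted". A minimal self-contained sketch of that contract (the queue type and its methods here are illustrative, not Paddle's `LoDTensorBlockingQueue` API):

// Blocking Pop that reports closure; *ok mirrors the `success` flag in
// PyReader::ReadNext, and an empty result marks end-of-data for callers.
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <queue>
#include <vector>

class BlockingQueue {
 public:
  void Push(std::vector<int> batch) {
    { std::lock_guard<std::mutex> g(mu_); q_.push(std::move(batch)); }
    cv_.notify_one();
  }
  void Close() {
    { std::lock_guard<std::mutex> g(mu_); closed_ = true; }
    cv_.notify_all();
  }
  // Blocks until data arrives or the queue is closed.
  std::vector<int> Pop(bool* ok) {
    std::unique_lock<std::mutex> lk(mu_);
    cv_.wait(lk, [this] { return !q_.empty() || closed_; });
    if (q_.empty()) { *ok = false; return {}; }
    auto batch = std::move(q_.front());
    q_.pop();
    *ok = true;
    return batch;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::queue<std::vector<int>> q_;
  bool closed_ = false;
};

int main() {
  BlockingQueue queue;
  queue.Push({1, 2, 3});
  queue.Close();
  bool ok;
  auto out = queue.Pop(&ok);  // ok == true, out = {1, 2, 3}
  std::cout << ok << " " << out.size() << "\n";
  out = queue.Pop(&ok);  // queue closed: ok == false
  if (!ok) out.clear();  // same convention as ReadNext
  std::cout << ok << " " << out.size() << "\n";
}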