提交 79082c94 编写于 作者: Y Yancey1989

Fix PyReader failure

上级 2dda19f7
......@@ -27,34 +27,31 @@ namespace framework {
namespace details {
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
std::vector<std::vector<VariableInfo>> var_infos_list,
std::vector<platform::Place> places,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
: strategy_(std::move(strategy)),
underlying_executor_(std::move(underlying_executor)),
local_scopes_(std::move(local_scopes)),
var_infos_list_(std::move(var_infos_list)),
var_infos_(std::move(var_infos)),
places_(std::move(places)) {}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) {
// Create local scopes.
for (size_t i = 0; i < local_scopes_.size(); ++i) {
auto &scope = local_scopes_[i];
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &var_infos : var_infos_list_) {
for (auto &info : var_infos) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
......
......@@ -38,8 +38,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
public:
ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<std::vector<VariableInfo>> var_info_list,
std::vector<platform::Place> places,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
const ir::Graph& Graph() const override {
......@@ -54,7 +53,7 @@ class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_;
std::vector<std::vector<VariableInfo>> var_infos_list_;
std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_;
};
} // namespace details
......
......@@ -216,7 +216,6 @@ void ThreadedSSAGraphExecutor::RunOp(
if (LIKELY(!strategy_.dry_run_)) {
op->Run(strategy_.use_cuda_);
}
VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--;
ready_var_q->Extend(op->Outputs());
VLOG(10) << op << " " << op->Name() << "Signal posted";
......
......@@ -141,7 +141,6 @@ ParallelExecutor::ParallelExecutor(
std::vector<std::unique_ptr<ir::Graph>> graphs;
#if defined(PADDLE_WITH_CUDA) && !defined(_WIN32)
if (exec_strategy.type_ == ExecutionStrategy::kParallelGraph) {
VLOG(1) << "kParallelGraph mode!!";
for (size_t i = 0; i < member_->places_.size(); ++i) {
std::unique_ptr<ir::Graph> graph = build_strategy.Apply(
main_program, {member_->places_[i]}, loss_var_name, params,
......@@ -178,8 +177,8 @@ ParallelExecutor::ParallelExecutor(
ref_cnt_pass->SetNotOwned(details::kGlobalReferenceCount, &ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kCurReferenceCount, &cur_ref_cnts_);
ref_cnt_pass->SetNotOwned(details::kGarbageCollector, &gcs_);
graphs[0] = ref_cnt_pass->Apply(std::move(graphs[i]));
graphs[0]->SetNotOwned("garbage_collector", &gcs_);
graphs[i] = ref_cnt_pass->Apply(std::move(graphs[i]));
graphs[i]->SetNotOwned("garbage_collector", &gcs_);
}
}
}
......@@ -192,6 +191,18 @@ ParallelExecutor::ParallelExecutor(
// Step 3. Create vars in each scope. Passes may also create new vars.
// Skip control vars and empty vars.
std::vector<details::VariableInfo> var_infos;
for (auto &graph : graphs) {
for (auto &node : graph->Nodes()) {
if (node->IsVar() && !node->IsCtrlVar() && node->Var()) {
var_infos.emplace_back();
var_infos.back().name_ = node->Var()->Name();
var_infos.back().type_ = node->Var()->GetType();
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
}
/**
std::vector<std::vector<details::VariableInfo>> var_infos_list;
for (size_t i = 0; i < graphs.size(); ++i) {
std::vector<details::VariableInfo> var_infos;
......@@ -203,8 +214,9 @@ ParallelExecutor::ParallelExecutor(
var_infos.back().persistable_ = node->Var()->Persistable();
}
}
var_infos_list.emplace_back(std::move(var_infos));
var_infos_list.push_back(std::move(var_infos));
}
**/
// If loss_var_name is given, the number of graphs should be exactly one.
if (loss_var_name.size()) {
......@@ -236,7 +248,7 @@ ParallelExecutor::ParallelExecutor(
}
member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, std::move(var_infos_list),
exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->places_, std::move(member_->executor_)));
}
......
......@@ -58,9 +58,7 @@ void BufferedReader::ReadAsync(size_t i) {
TensorVec &gpu = gpu_buffer_[i];
gpu.resize(cpu.size());
for (size_t i = 0; i < cpu.size(); ++i) {
VLOG(1) << "launch tensor copy from cpu to cpu, idx: " << i;
framework::TensorCopySync(cpu[i], place_, &gpu[i]);
VLOG(1) << "done " << i;
gpu[i].set_lod(cpu[i].lod());
}
}
......
......@@ -28,10 +28,8 @@ class PyReader : public framework::FileReader {
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
VLOG(1) << "come in PyReader::ReadNext function, out: " << out;
bool success;
*out = queue_->Pop(&success);
VLOG(1) << "call PyReader::ReadNext " << success;
if (!success) out->clear();
}
......
......@@ -115,12 +115,10 @@ class PreemptiveReaderContainer : public IReaderContainer {
}
void ReadNext(std::vector<framework::LoDTensor>* out) override {
VLOG(1) << "flag";
if (!pending_.empty()) {
auto future_it = complete_queue_.Pop();
FutureItem item = future_it->get();
if (item.exception_) {
VLOG(1) << "item has exception!!!";
for (auto it = futures_.begin(); it != futures_.end(); ++it) {
if (it != future_it) {
it->wait(); // Wait all other threads complete.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册