提交 60ec764c 编写于 作者: S Suharsh Sivakumar 提交者: TensorFlower Gardener

Cache Graphs for Run calls and graphs for PartialRun calls separately.

name_to_node_ is only populated on a PartialRun call so if a Run for graph g is followed by a PartialRun to the same graph g, the PartialRun will fail since a cached ReffedClientGraph with no name_to_node map will be returned. We should consider a better solution for this if rebuilding the graph for new partial run calls proves to be too expensive.
Change: 139498839
上级 bfe96c9b
......@@ -970,7 +970,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
handle_(strings::FpToString(random::New64())),
stats_publisher_factory_(std::move(stats_publisher_factory)),
graph_version_(0),
runs_(5),
run_graphs_(5),
partial_run_graphs_(5),
cancellation_manager_(new CancellationManager) {
UpdateLastAccessTime();
......@@ -996,8 +997,8 @@ MasterSession::MasterSession(const SessionOptions& opt, const MasterEnv* env,
MasterSession::~MasterSession() {
delete cancellation_manager_;
for (const auto& iter : runs_) iter.second->Unref();
for (const auto& iter : obsolete_) iter.second->Unref();
for (const auto& iter : run_graphs_) iter.second->Unref();
for (const auto& iter : partial_run_graphs_) iter.second->Unref();
for (Device* dev : remote_devs_) delete dev;
}
......@@ -1065,23 +1066,23 @@ Status MasterSession::StartStep(const BuildGraphOptions& opts, int64* count,
// this session.
int64* c = &subgraph_execution_counts_[hash];
*count = (*c)++;
auto iter = runs_.find(hash);
if (iter == runs_.end()) {
// TODO(suharshs): We cache partial run graphs and run graphs separately
// because there is preprocessing that needs to only be run for partial
// run calls.
RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
auto iter = m->find(hash);
if (iter == m->end()) {
// We have not seen this subgraph before. Build the subgraph and
// cache it.
VLOG(1) << "Unseen hash " << hash << " for "
<< BuildGraphOptionsString(opts);
<< BuildGraphOptionsString(opts) << " is_partial = " << is_partial
<< "\n";
std::unique_ptr<SimpleClientGraph> client_graph;
TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
auto entry = new ReffedClientGraph(
handle_, opts, std::move(client_graph), session_opts_,
stats_publisher_factory_, execution_state_.get(), is_partial);
iter = runs_.insert({hash, entry}).first;
auto obs_iter = obsolete_.find(hash);
if (obs_iter != obsolete_.end()) {
to_unref = obs_iter->second;
obsolete_.erase(obs_iter);
}
iter = m->insert({hash, entry}).first;
VLOG(1) << "Preparing to execute new graph";
}
*rcg = iter->second;
......@@ -1383,8 +1384,8 @@ Status MasterSession::Close() {
while (num_running_ != 0) {
num_running_is_zero_.wait(l);
}
ClearRunsTable(&to_unref, &runs_);
ClearRunsTable(&to_unref, &obsolete_);
ClearRunsTable(&to_unref, &run_graphs_);
ClearRunsTable(&to_unref, &partial_run_graphs_);
}
for (ReffedClientGraph* rcg : to_unref) rcg->Unref();
delete this;
......
......@@ -119,8 +119,8 @@ class MasterSession {
// scope and lose their state.
class ReffedClientGraph;
typedef std::unordered_map<uint64, ReffedClientGraph*> RCGMap;
RCGMap runs_ GUARDED_BY(mu_);
RCGMap obsolete_ GUARDED_BY(mu_);
RCGMap run_graphs_ GUARDED_BY(mu_);
RCGMap partial_run_graphs_ GUARDED_BY(mu_);
struct PerStepState {
bool collect_costs = false;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册