提交 6b7314de 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Consolidating the code to fill the partition's function library

into one place. Previously, Partition() and MasterSession::RegisterPartition()
both fills in the partitioned graph's function library.

PiperOrigin-RevId: 163400992
上级 28373cfe
......@@ -1342,6 +1342,7 @@ Status DirectSession::CreateGraphs(
// Just return '1'.
return 1;
};
popts.flib_def = &client_graph->graph.flib_def();
popts.control_flow_added = false;
std::unordered_map<string, GraphDef> partitions;
......
......@@ -155,6 +155,7 @@ Status GraphMgr::InitItem(const string& session, const GraphDef& gdef,
return PartitionOptions::kIllegalIncarnation;
}
};
popts.flib_def = &graph.flib_def();
popts.control_flow_added = true;
popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
......
......@@ -76,7 +76,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
debug_opts_(bopts.debug_options),
worker_cache_(worker_cache) {
VLOG(1) << "Created ReffedClientGraph for node with "
<< client_graph_->graph.num_node_ids();
<< client_graph()->graph.num_node_ids();
stats_publisher_ = stats_publisher_factory(handle, bopts, session_opts);
......@@ -166,8 +166,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
// Partitions the graph into subgraphs and registers them on
// workers.
Status RegisterPartitions(const PartitionOptions& popts,
const FunctionLibraryDefinition& flib_def);
Status RegisterPartitions(const PartitionOptions& popts);
// Runs one step of all partitions.
Status RunPartitions(const MasterEnv* env, int64 step_id,
......@@ -263,7 +262,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
PartitionOptions pots,
std::unordered_map<string, GraphDef>* out_partitions);
Status DoRegisterPartitions(
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib,
const PartitionOptions& popts,
std::unordered_map<string, GraphDef> graph_partitions);
// Deregisters the partitions on the workers. Called in the
......@@ -274,7 +273,7 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
};
Status MasterSession::ReffedClientGraph::RegisterPartitions(
const PartitionOptions& popts, const FunctionLibraryDefinition& flib_def) {
const PartitionOptions& popts) {
{ // Ensure register once.
mu_.lock();
if (!init_started_) {
......@@ -293,8 +292,7 @@ Status MasterSession::ReffedClientGraph::RegisterPartitions(
graph_defs_for_publishing.push_back(&name_def.second);
}
stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
s = DoRegisterPartitions(popts, flib_def.ToProto(),
std::move(graph_defs));
s = DoRegisterPartitions(popts, std::move(graph_defs));
}
mu_.lock();
init_result_ = s;
......@@ -374,7 +372,7 @@ Status MasterSession::ReffedClientGraph::DoBuildPartitions(
}
Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
const PartitionOptions& popts, const FunctionDefLibrary& func_def_lib,
const PartitionOptions& popts,
std::unordered_map<string, GraphDef> graph_partitions) {
partitions_.reserve(graph_partitions.size());
Status s;
......@@ -408,8 +406,6 @@ Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
Call* c = &calls[i];
c->req.set_session_handle(session_handle_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
// For simplicity, we ship the library completely to every worker.
*c->req.mutable_graph_def()->mutable_library() = func_def_lib;
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() = debug_opts_;
VLOG(2) << "Register " << c->req.graph_def().DebugString();
......@@ -1305,6 +1301,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
mutex_lock l(mu_);
return strings::StrCat(prefix, "_S", next_node_id_++);
};
popts.flib_def = rcg->client_graph()->flib_def.get();
popts.get_incarnation = [this](const string& name) -> int64 {
Device* d = devices_->FindDeviceByName(name);
if (d == nullptr) {
......@@ -1332,8 +1329,7 @@ Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
popts.need_to_record_start_times = true;
}
TF_RETURN_IF_ERROR(
rcg->RegisterPartitions(popts, *rcg->client_graph()->flib_def));
TF_RETURN_IF_ERROR(rcg->RegisterPartitions(popts));
return Status::OK();
}
......
......@@ -1165,11 +1165,16 @@ Status Partition(const PartitionOptions& opts, Graph* g,
}
}
const FunctionLibraryDefinition* flib_def = opts.flib_def;
if (flib_def == nullptr) {
flib_def = &g->flib_def();
}
// Set versions, function library and send/recv incarnation.
for (auto& it : *partitions) {
GraphDef* gdef = &it.second;
*gdef->mutable_versions() = g->versions();
*gdef->mutable_library() = g->flib_def().ToProto();
*gdef->mutable_library() = flib_def->ToProto();
// Traverse the graph to fill every send/recv op's incarnation
// information.
......
......@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
......@@ -45,6 +46,10 @@ struct PartitionOptions {
typedef std::function<uint64(const string&)> GetIncarnationFunc;
GetIncarnationFunc get_incarnation = nullptr;
// If specified, flib_def defines a function library that should be
// partitioned and replicated into each resulting partition graphs.
const FunctionLibraryDefinition* flib_def = nullptr;
// True if all the control flow "code" has already been added. The
// control flow code needs to be added when we still have the entire
// graph before any partitioning. So this flag should be false for
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册