Commit 22cfbd1c authored by Craig Citro, committed by TensorFlower Gardener

Switch Docker instructions to always `--pull` on build.

This fixes situations like the one Vincent hit, where a stale base image would lead to
new packages being built on top of old base packages.
Change: 118071412
Parent 2e9f0fb8
@@ -176,6 +176,9 @@ DirectSession::~DirectSession() {
if (options_.config.use_per_session_threads()) {
delete thread_pool_;
}
for (auto it : cost_models_) {
delete it.second;
}
}
Status DirectSession::Create(const GraphDef& graph) {
@@ -314,9 +317,11 @@ Status DirectSession::Run(const RunOptions& run_options,
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
if (run_options.trace_level() == RunOptions::FULL_TRACE) {
args.stats_collector =
new StepStatsCollector(run_metadata->mutable_step_stats());
if (run_options.trace_level() == RunOptions::FULL_TRACE ||
options_.config.graph_options().build_cost_model()) {
args.stats_collector = new StepStatsCollector(
run_metadata->mutable_step_stats(), &cost_models_);
run_state.collector = args.stats_collector;
}
for (const auto& item : executors_and_keys->items) {
@@ -327,10 +332,6 @@ Status DirectSession::Run(const RunOptions& run_options,
? run_options.timeout_in_ms()
: operation_timeout_in_ms_);
if (run_options.trace_level() == RunOptions::FULL_TRACE) {
delete args.stats_collector;
}
{
mutex_lock l(run_state.mu_);
TF_RETURN_IF_ERROR(run_state.status);
@@ -400,6 +401,11 @@ Status DirectSession::PRunSetup(const std::vector<string>& input_names,
LogMemory::RecordStep(args.step_id, run_state_args.handle);
}
if (options_.config.graph_options().build_cost_model()) {
run_state->collector = new StepStatsCollector(nullptr, &cost_models_);
args.stats_collector = run_state->collector;
}
for (auto& item : executors_and_keys->items) {
Executor* exec = item.executor;
exec->RunAsync(args, barrier->Get());
@@ -912,6 +918,19 @@ Status DirectSession::CreateGraphs(gtl::ArraySlice<string> feeds,
return ::tensorflow::Status::OK();
}
DirectSession::RunState::~RunState() {
if (rendez != nullptr) {
if (!executors_done.HasBeenNotified()) {
rendez->StartAbort(errors::Cancelled("PRun cancellation"));
executors_done.WaitForNotification();
}
rendez->Unref();
}
if (collector != nullptr) {
delete collector;
}
}
void DirectSession::WaitForNotification(RunState* run_state,
int64 timeout_in_ms) {
if (timeout_in_ms > 0) {
......
@@ -40,6 +40,7 @@ limitations under the License.
namespace tensorflow {
class CostModel;
class Device;
class ThreadPool;
@@ -79,6 +80,12 @@ class DirectSession : public Session {
std::vector<Tensor>* outputs) override;
::tensorflow::Status Close() override;
// NOTE: This is a temporary api that is only meant to enable testing.
// This api will be replaced with better ones soon, so DO NOT USE
const std::unordered_map<const Graph*, CostModel*>& GetCostModels() const {
return cost_models_;
}
private:
typedef DirectSession ME;
@@ -124,6 +131,7 @@ class DirectSession : public Session {
mutex mu_;
Status status GUARDED_BY(mu_);
IntraProcessRendezvous* rendez = nullptr;
StepStatsCollector* collector = nullptr;
Notification executors_done;
std::unordered_set<string> pending_inputs;
std::unordered_set<string> pending_outputs;
@@ -138,16 +146,7 @@ class DirectSession : public Session {
pending_outputs.emplace(name);
}
}
~RunState() {
if (rendez != nullptr) {
if (!executors_done.HasBeenNotified()) {
rendez->StartAbort(errors::Cancelled("PRun cancellation"));
executors_done.WaitForNotification();
}
rendez->Unref();
}
}
~RunState();
};
struct RunStateArgs {
@@ -250,6 +249,9 @@ class DirectSession : public Session {
// Global timeout for all blocking operations in this session.
const int64 operation_timeout_in_ms_ = 0;
std::unordered_map<const Graph*, CostModel*> cost_models_
GUARDED_BY(executor_lock_);
TF_DISALLOW_COPY_AND_ASSIGN(DirectSession);
};
......
@@ -20,11 +20,13 @@ limitations under the License.
#include <unordered_map>
#include <vector>
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_util.h"
@@ -654,5 +656,65 @@ TEST(DirectSessionTest, TimeoutSession) {
session->Close();
}
TEST(DirectSessionTest, CostModelTest) {
Graph graph(OpRegistry::Global());
Tensor a_tensor(DT_FLOAT, TensorShape({2, 2}));
test::FillValues<float>(&a_tensor, {3, 2, -1, 0});
Node* a = test::graph::Constant(&graph, a_tensor);
a->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
Tensor x_tensor(DT_FLOAT, TensorShape({2, 1}));
test::FillValues<float>(&x_tensor, {1, 1});
Node* x = test::graph::Constant(&graph, x_tensor);
x->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
// y = A * x
Node* y = test::graph::Matmul(&graph, a, x, false, false);
y->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:0");
Node* y_neg = test::graph::Unary(&graph, "Neg", y);
y_neg->set_assigned_device_name("/job:localhost/replica:0/task:0/cpu:1");
GraphDef def;
test::graph::ToGraphDef(&graph, &def);
SessionOptions options;
(*options.config.mutable_device_count())["CPU"] = 2;
options.config.mutable_graph_options()->set_build_cost_model(true);
std::vector<Device*> devices;
DeviceFactory::AddDevices(options, "/job:localhost/replica:0/task:0",
&devices);
DirectSession session(options, new DeviceMgr(devices));
TF_ASSERT_OK(session.Create(def));
std::vector<std::pair<string, Tensor>> inputs;
// Request two targets: one fetch output and one non-fetched output.
std::vector<string> output_names = {y->name() + ":0"};
std::vector<string> target_nodes = {y_neg->name()};
std::vector<Tensor> outputs;
Status s = session.Run(inputs, output_names, target_nodes, &outputs);
TF_ASSERT_OK(s);
const std::unordered_map<const Graph*, CostModel*>& cost_models =
session.GetCostModels();
// We should have 2 cost models since we have 2 cpu devices.
ASSERT_EQ(2, cost_models.size());
for (auto it : cost_models) {
const Graph* g = (it).first;
const CostModel* cm = (it).second;
for (Node* node : g->nodes()) {
if (node->name() == y->name()) {
EXPECT_EQ(0, cm->MaxSize(node, 0));
EXPECT_EQ(0, cm->Aliases(node, 0));
} else if (node->name() == y_neg->name()) {
EXPECT_EQ(0, cm->MaxSize(node, 0));
EXPECT_EQ(0, cm->Aliases(node, 0));
}
}
}
}
} // namespace
} // namespace tensorflow
@@ -1454,6 +1454,7 @@ bool ExecutorState::NodeDone(const Status& s, const Node* node,
std::deque<TaggedNode>* inline_ready) {
if (stats_collector_) {
nodestats::SetAllEnd(stats);
stats_collector_->UpdateCostModel(stats, impl_->graph_, node);
if (!SetTimelineLabel(node, stats)) {
// Only record non-transfer nodes.
stats_collector_->Save(impl_->params_.device->name(), stats);
......
@@ -15,16 +15,51 @@ limitations under the License.
#include "tensorflow/core/common_runtime/step_stats_collector.h"
#include "tensorflow/core/framework/step_stats.pb.h"
#include "tensorflow/core/graph/costmodel.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
StepStatsCollector::StepStatsCollector(StepStats* ss) : step_stats_(ss) {}
StepStatsCollector::StepStatsCollector(
StepStats* ss, std::unordered_map<const Graph*, CostModel*>* cm)
: step_stats_(ss), cost_models_(cm) {}
void StepStatsCollector::UpdateCostModel(const NodeExecStats* nt,
const Graph* graph, const Node* node) {
mutex_lock l(mu_);
if (!cost_models_) {
return;
}
CostModel* cm;
auto it = cost_models_->find(graph);
if (it == cost_models_->end()) {
cm = new CostModel(false);
cm->InitFromGraph(*graph);
cost_models_->emplace(graph, cm);
} else {
cm = (*it).second;
}
for (int i = 0; i < nt->output_size(); ++i) {
cm->RecordMaxSize(node, i, Bytes(nt->output(i)
.tensor_description()
.allocation_description()
.allocated_bytes()));
cm->RecordAliases(node, i, nt->output(i)
.tensor_description()
.allocation_description()
.allocation_id());
}
}
void StepStatsCollector::Save(const string& device, NodeExecStats* nt) {
VLOG(1) << "Save dev " << device << " nt " << nt;
{
mutex_lock l(mu_);
if (!step_stats_) {
delete nt;
return;
}
DeviceStepStats* dss = nullptr;
// Slow linear scan, but it should only be called
// by a Worker in a context with < ~10 devices.
......
@@ -15,19 +15,27 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
#define TENSORFLOW_CORE_DISTRIBUTED_RUNTIME_STEP_STATS_COLLECTOR_H_
#include <unordered_map>
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
class CostModel;
class Graph;
class Node;
class NodeExecStats;
class StepStats;
class StepStatsCollector {
public:
explicit StepStatsCollector(StepStats* ss);
explicit StepStatsCollector(
StepStats* ss,
std::unordered_map<const Graph*, CostModel*>* cost_models = nullptr);
void UpdateCostModel(const NodeExecStats* nt, const Graph* graph,
const Node* node);
void Save(const string& device, NodeExecStats* nt);
void Swap(StepStats* ss);
@@ -36,6 +44,7 @@ class StepStatsCollector {
friend class StepStatsMgr;
mutex mu_;
StepStats* step_stats_ GUARDED_BY(mu_);
std::unordered_map<const Graph*, CostModel*>* cost_models_ GUARDED_BY(mu_);
};
} // namespace tensorflow
......
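The second constructor argument is optional, so existing StepStatsCollector callers are unaffected. A caller that wants cost models but no step stats can pass a null StepStats, which is the pattern DirectSession::PRunSetup uses above. A rough sketch of the flow, for illustration only (it assumes a NodeExecStats* stats, a Graph* graph, and a Node* node taken from an executing step):

    std::unordered_map<const Graph*, CostModel*> cost_models;
    StepStatsCollector collector(nullptr, &cost_models);
    // During execution, the executor feeds per-node stats into the
    // collector; a CostModel for the node's graph is created lazily on
    // first use (see UpdateCostModel above).
    collector.UpdateCostModel(stats, graph, node);
    // The map now holds one CostModel per executed graph; DirectSession
    // deletes these entries in its destructor.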
@@ -124,6 +124,8 @@ void CostModel::Ensure(int id) {
slot_bytes_.resize(id + 1);
count_.resize(id + 1);
time_.resize(id + 1);
max_mem_usage_.resize(id + 1);
output_port_alias_.resize(id + 1);
}
}
@@ -132,11 +134,16 @@ void CostModel::SetNumOutputs(const Node* node, int num_outputs) {
if (id < 0) return;
Ensure(id);
auto perslot = &slot_bytes_[id];
auto max_mem_usage = &max_mem_usage_[id];
auto output_port_alias = &output_port_alias_[id];
if (perslot->size() > 0) {
CHECK_EQ(num_outputs, perslot->size()) << "Cannot resize slot_bytes, node="
<< node->name();
} else {
perslot->resize(num_outputs, Bytes(-1));
max_mem_usage->output_port_mem.resize(num_outputs, Bytes(-1));
max_mem_usage->temp_memory_size = Bytes(-1);
output_port_alias->resize(num_outputs, -1);
}
}
@@ -224,6 +231,39 @@ void CostModel::CheckInitialized(const Graph& graph) const {
}
}
void CostModel::RecordMaxSize(const Node* node, int output_slot, Bytes bytes) {
const int id = Id(node);
if (id < 0) return;
Ensure(id);
max_mem_usage_[id].output_port_mem[output_slot] = bytes;
}
Bytes CostModel::MaxSize(const Node* node, int slot) const {
const int id = Id(node);
if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
return Bytes(0);
}
return max_mem_usage_[id].output_port_mem[slot];
}
void CostModel::RecordAliases(const Node* node, int output_slot,
int64 alias_id) {
const int id = Id(node);
if (id < 0) return;
Ensure(id);
output_port_alias_[id][output_slot] = alias_id;
}
int64 CostModel::Aliases(const Node* node, int slot) const {
const int id = Id(node);
if (id < 0 || static_cast<size_t>(id) >= slot_bytes_.size() ||
slot_bytes_[id].size() <= static_cast<size_t>(slot)) {
return -1;
}
return output_port_alias_[id][slot];
}
Microseconds CostModel::CopyTimeEstimate(Bytes b, double network_latency_millis,
double estimated_gbps) {
// TODO(jeff,sanjay): estimate cost based on bandwidth along the
......
@@ -95,6 +95,22 @@ class CostModel {
// Check that an estimate is available for every OP node in graph.
void CheckInitialized(const Graph& graph) const;
// Records the maximum size in bytes of the tensor generated by "output_slot"
// of "node".
void RecordMaxSize(const Node* node, int output_slot, Bytes bytes);
// Returns the maximum size in bytes of the tensor generated by "output_slot"
// of "node".
Bytes MaxSize(const Node* node, int output_slot) const;
// Record the unique id of the tensor generated by "output_slot" of "node".
// Any other tensor sharing the same id will be an alias, i.e. it will share
// the same underlying memory storage area.
void RecordAliases(const Node* node, int output_slot, int64 alias_id);
// Return the unique id of the tensor generated by "output_slot" of "node".
int64 Aliases(const Node* node, int output_slot) const;
// Helper routines to encapsulate static estimation heuristics
// Compute an estimate of the time to copy "b" bytes over the network,
@@ -131,6 +147,15 @@ class CostModel {
// Cumulative Bytes output on each channel.
std::vector<gtl::InlinedVector<Bytes, 2> > slot_bytes_;
// Maximum memory usage
struct MemUsage {
Bytes temp_memory_size;
gtl::InlinedVector<Bytes, 2> output_port_mem;
};
std::vector<MemUsage> max_mem_usage_;
std::vector<gtl::InlinedVector<int64, 2> > output_port_alias_;
TF_DISALLOW_COPY_AND_ASSIGN(CostModel);
};
......
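Taken together, the four new CostModel methods form record/query pairs keyed by (node, output slot). A rough usage sketch, not part of this change (it assumes a Graph g whose nodes have been registered, e.g. via InitFromGraph as UpdateCostModel does above, and a Node* n from g):

    CostModel cm(false);                   // false: per-step, not global
    cm.InitFromGraph(g);                   // reserve per-node storage
    cm.RecordMaxSize(n, 0, Bytes(4096));   // output slot 0 peaked at 4KB
    cm.RecordAliases(n, 0, 7);             // slot 0 shares buffer id 7
    CHECK_EQ(Bytes(4096), cm.MaxSize(n, 0));
    CHECK_EQ(7, cm.Aliases(n, 0));

For unknown nodes or slots, MaxSize falls back to Bytes(0) and Aliases to -1, so callers need no existence check before querying.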
@@ -76,6 +76,10 @@ message GraphOptions {
// Options controlling how graph is optimized.
OptimizerOptions optimizer_options = 3;
// Build a cost model detailing the memory usage and performance of
// each node of the graph.
bool build_cost_model = 4;
};
// Session configuration parameters.
......
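Clients opt in through GraphOptions. A minimal C++ sketch, mirroring the pattern in CostModelTest above (GetCostModels() is the temporary, test-only accessor on DirectSession):

    SessionOptions options;
    options.config.mutable_graph_options()->set_build_cost_model(true);
    // After Session::Run, a DirectSession exposes one CostModel per
    // partitioned graph via GetCostModels().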
@@ -58,7 +58,7 @@ Building a local Docker container
---------------------------------
cd tensorflow/examples/udacity
docker build -t $USER/assignments .
docker build --pull -t $USER/assignments .
Running the local container
---------------------------
......
@@ -537,12 +537,13 @@ class FunctionInlineControlTest(tf.test.TestCase):
y = Forward(x)
dx, = tf.gradients([y], [x])
np.random.seed(12345)
inp = np.random.uniform(-1, 1, [2 * 1024, 1]).astype(np.float32)
np.random.seed(321)
inp = np.random.uniform(-1, 1, [16, 1]).astype(np.float32)
with tf.Session(graph=g, config=cfg) as sess:
ans = sess.run([y, dx], {x: inp})
self.assertAllClose(ans[0], 1384849.5, rtol=1e-3)
self.assertAllClose(np.sum(ans[1]), 7127613.5, rtol=1e-3)
print(ans[0], np.sum(ans[1]))
self.assertAllClose(ans[0], 255.971, rtol=1e-3)
self.assertAllClose(np.sum(ans[1]), 13.0408, rtol=1e-3)
if __name__ == "__main__":
......
@@ -897,7 +897,7 @@ def ZerosLikeOutsideLoop(op, index):
branch = op_ctxt.branch
switch_val = switch(op.inputs[0], pred)[1 - branch]
zeros_shape = array_ops.shape(switch_val)
return array_ops.zeros(zeros_shape)
return array_ops.zeros(zeros_shape, dtype=val.dtype)
class ControlFlowContext(object):
......
@@ -48,4 +48,4 @@ Alternately, you can use the `docker_run_gpu.sh` script in this directory.
Just pick the dockerfile corresponding to the container you want to build, and run:
$ docker build -t $USER/tensorflow-suffix -f Dockerfile.suffix .
$ docker build --pull -t $USER/tensorflow-suffix -f Dockerfile.suffix .