Commit c8a0dfc7 authored by Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Adding support for `tf.data.AUTOTUNE` as a special value for the...

[tf.data] Adding support for `tf.data.AUTOTUNE` as a special value for the `num_parallel_calls` argument of `tf.data.Dataset.map()`, `tf.data.Dataset.interleave()`, and `tf.contrib.data.map_and_batch()`.

When `tf.data.AUTOTUNE` is specified, the level of parallelism is determined at runtime. The underlying mechanism instruments the input pipeline to build a performance model and then uses the model to find the optimal values for the parallelism knobs.
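
A minimal usage sketch (not part of this commit's diff), assuming the constant is exposed as `tf.data.AUTOTUNE` exactly as named above; the dataset contents and map functions below are placeholders:

```python
import tensorflow as tf

dataset = tf.data.Dataset.range(1000)
# Let the runtime pick the degree of parallelism for map().
dataset = dataset.map(lambda x: x + 1,
                      num_parallel_calls=tf.data.AUTOTUNE)
# The same special value works for interleave() ...
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(2),
    cycle_length=4,
    num_parallel_calls=tf.data.AUTOTUNE)
# ... and for the fused map_and_batch() transformation in tf.contrib.
dataset = dataset.apply(
    tf.contrib.data.map_and_batch(
        lambda x: x * 2, batch_size=32,
        num_parallel_calls=tf.data.AUTOTUNE))
```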

PiperOrigin-RevId: 213283297
Parent 07bc3696
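
The tuning itself is a hill-climbing loop implemented in C++ in `Model::Optimize()` (see the model.cc hunk below). The following Python sketch is purely illustrative and not part of this change; `tunables`, `output_time`, and `processing_time` stand in for the model's internal state:

```python
def optimize(tunables, output_time, processing_time, cpu_budget):
    """Hill-climbing over tunable parallelism parameters.

    tunables: objects with `value`, `max`, and `set_fn` attributes.
    output_time, processing_time: callables returning projected times.
    """
    for t in tunables:
        t.value = 1  # start every knob at minimal parallelism
    while True:
        current = output_time()
        if (current <= processing_time() / cpu_budget or
                all(t.value >= t.max for t in tunables)):
            break
        best, best_delta = None, -1
        for t in tunables:  # try bumping each knob by one
            if t.value == t.max:
                continue
            t.value += 1
            delta = current - output_time()
            t.value -= 1
            if delta > best_delta:
                best, best_delta = t, delta
        if best is None:
            break
        best.value += 1
    for t in tunables:
        t.set_fn(t.value)  # push the chosen values into the iterators
```

In the actual implementation the tunables are collected from the model graph and each `set_fn` updates the owning iterator's `num_parallel_calls_` under its lock, as the C++ hunks below show.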
......@@ -58,7 +58,8 @@ class ModelDatasetTest(test.TestCase):
dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
np.random.rand(4 * k,
1))).repeat()
dataset = dataset.map(math_ops.matmul, num_parallel_calls=56)
dataset = dataset.map(
math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
......@@ -84,7 +85,9 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.apply(
batching.map_and_batch(
math_ops.matmul, num_parallel_calls=28, batch_size=batch_size))
math_ops.matmul,
num_parallel_calls=optimization.AUTOTUNE,
batch_size=batch_size))
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
......@@ -109,7 +112,9 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.map(math_ops.matmul)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=56, num_parallel_calls=56)
lambda _: dataset,
cycle_length=10,
num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
......@@ -146,15 +151,15 @@ class ModelDatasetTest(test.TestCase):
x, y = c
return a, b, math_ops.matmul(x, y)
dataset = dataset.map(f1, num_parallel_calls=32)
dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=2)
dataset = dataset.map(f2, num_parallel_calls=16)
dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
dataset = dataset_ops.Dataset.range(1).repeat().interleave(
lambda _: dataset, cycle_length=2)
dataset = dataset.map(f3, num_parallel_calls=10)
dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
get_next = iterator.get_next()
......
......@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.cc
tensorflow/core/framework/graph_transfer_info.pb.cc
tensorflow/core/framework/kernel_def.pb.cc
tensorflow/core/framework/log_memory.pb.cc
tensorflow/core/framework/model.pb.cc
tensorflow/core/framework/node_def.pb.cc
tensorflow/core/framework/op_def.pb.cc
tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc
......
......@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb.h
tensorflow/core/framework/graph_transfer_info.pb.h
tensorflow/core/framework/kernel_def.pb.h
tensorflow/core/framework/log_memory.pb.h
tensorflow/core/framework/model.pb.h
tensorflow/core/framework/node_def.pb.h
tensorflow/core/framework/op_def.pb.h
tensorflow/core/framework/remote_fused_graph_execute_info.pb.h
......
......@@ -10,7 +10,6 @@ tensorflow/core/framework/graph.pb_text.cc
tensorflow/core/framework/graph_transfer_info.pb_text.cc
tensorflow/core/framework/kernel_def.pb_text.cc
tensorflow/core/framework/log_memory.pb_text.cc
tensorflow/core/framework/model.pb_text.cc
tensorflow/core/framework/node_def.pb_text.cc
tensorflow/core/framework/op_def.pb_text.cc
tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc
......
......@@ -14,7 +14,6 @@ tensorflow/core/framework/graph.proto
tensorflow/core/framework/graph_transfer_info.proto
tensorflow/core/framework/kernel_def.proto
tensorflow/core/framework/log_memory.proto
tensorflow/core/framework/model.proto
tensorflow/core/framework/node_def.proto
tensorflow/core/framework/op_def.proto
tensorflow/core/framework/reader_base.proto
......
......@@ -178,7 +178,6 @@ COMMON_PROTO_SRCS = [
"framework/iterator.proto",
"framework/kernel_def.proto",
"framework/log_memory.proto",
"framework/model.proto",
"framework/node_def.proto",
"framework/op_def.proto",
"framework/reader_base.proto",
......@@ -842,7 +841,6 @@ tf_cuda_library(
"framework/log_memory.h",
"framework/lookup_interface.h",
"framework/memory_types.h",
"framework/model.h",
"framework/node_def_builder.h",
"framework/node_def_util.h",
"framework/numeric_op.h",
......
......@@ -20,7 +20,6 @@ limitations under the License.
namespace tensorflow {
namespace data {
namespace {
// A wrapper class for storing a `DatasetBase` instance in a DT_VARIANT tensor.
......
......@@ -47,6 +47,8 @@ class GraphDefBuilder;
class Node;
namespace data {
// A constant that can be used to enable auto-tuning.
constexpr int kAutoTune = -1;
class DatasetBase;
class SerializationContext;
......@@ -670,13 +672,34 @@ class DatasetBaseIterator : public IteratorBase {
return strings::StrCat(params_.prefix, ":", name);
}
// When performance modeling is enabled, this method sets a metadata entry for
// the model node corresponding to this iterator.
void SetMetadata(IteratorContext* ctx, const string& key, int64 value) {
// When performance modeling is enabled, this method adds a constant parameter
// to the model node corresponding to this iterator.
void AddConstantParameter(IteratorContext* ctx, const string& name,
int64 value) {
if (ctx->model()) {
std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
if (node) {
node->set_metadata(key, value);
node->add_constant_param(name, value);
}
}
}
// When performance modeling is enabled, this method adds a tunable parameter
// to the model node corresponding to this iterator.
//
// The `set_fn` function should set the tunable parameter to the value of
// its input argument. The function should be thread-safe; in particular, the
// state it updates should be protected by a lock as the function can be
// invoked asynchronously. It is guaranteed that this function will not be
// invoked after the iterator is deleted because the model node that owns
// the function is deleted when the iterator is deleted.
void AddTunableParameter(IteratorContext* ctx, const string& name,
int64 value, int64 min, int64 max,
std::function<void(int64)>&& set_fn) {
if (ctx->model()) {
std::shared_ptr<model::Node> node = ctx->model()->LookupNode(prefix());
if (node) {
node->add_tunable_param(name, value, min, max, std::move(set_fn));
}
}
}
......
......@@ -15,52 +15,28 @@ limitations under the License.
#include "tensorflow/core/framework/model.h"
#include <memory>
#include "tensorflow/core/lib/gtl/map_util.h"
namespace tensorflow {
namespace data {
namespace model {
// TODO(jsimsa): Use `Node` subclassing instead of types and node statements.
void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
void Node::CollectTunables(
std::vector<std::shared_ptr<Node::Tunable>>* tunables) {
mutex_lock l(mu_);
for (auto input : inputs_) {
input->CollectTunables(tunables);
}
switch (type_) {
case Type::PARALLEL_INTERLEAVE_V2: {
for (auto input : inputs_) {
input->CollectKnobs(knobs);
}
int64 processing_time = static_cast<int64>(
static_cast<double>(ProcessingTimeLocked() -
inputs_.front()->ProcessingTime()) /
static_cast<double>(inputs_.size() - 1));
knobs->emplace_back(
Node::Knob{this, processing_time, metadata_["parallelism"]});
return;
}
case Type::MAP_AND_BATCH:
case Type::PARALLEL_INTERLEAVE_V2:
case Type::PARALLEL_MAP: {
for (auto input : inputs_) {
input->CollectKnobs(knobs);
}
knobs->emplace_back(
Node::Knob{this, NanosPerElementLocked(), metadata_["parallelism"]});
return;
}
case Type::BATCH:
case Type::CACHE:
case Type::CONCATENATE:
case Type::FILTER:
case Type::FLAT_MAP:
case Type::INTERLEAVE:
case Type::MAP:
case Type::PADDED_BATCH:
case Type::PARALLEL_INTERLEAVE:
case Type::PREFETCH:
case Type::REPEAT:
case Type::SHUFFLE:
case Type::SKIP:
case Type::TAKE:
case Type::ZIP: {
for (auto input : inputs_) {
input->CollectKnobs(knobs);
if (auto* tunable_param =
gtl::FindOrNull(tunable_params_, "parallelism")) {
tunables->push_back(*tunable_param);
}
return;
}
......@@ -69,12 +45,19 @@ void Node::CollectKnobs(std::vector<Node::Knob>* knobs) {
}
}
int64 Node::GetParameterValue(const string& name) {
if (auto* tunable_param = gtl::FindOrNull(tunable_params_, name)) {
return (*tunable_param)->value;
}
return constant_params_[name];
}
int64 Node::ProcessingTimeLocked() {
switch (type_) {
case Type::BATCH:
case Type::MAP_AND_BATCH:
case Type::PADDED_BATCH: {
int64 batch_size = metadata_["batch_size"];
int64 batch_size = GetParameterValue("batch_size");
return NanosPerElementLocked() + batch_size * ProcessingTimeForInputs();
}
case Type::FILTER: {
......@@ -122,7 +105,7 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
switch (type_) {
case Type::BATCH:
case Type::PADDED_BATCH: {
double batch_size = metadata_["batch_size"];
double batch_size = GetParameterValue("batch_size");
int64 old_value = (*input_times)[input_times->size() - 1];
(*input_times)[input_times->size() - 1] = static_cast<int64>(
static_cast<double>(old_value + NanosPerElementLocked()) /
......@@ -168,8 +151,8 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
static_cast<double>(inputs_.size() - 1);
}
case Type::MAP_AND_BATCH: {
double batch_size = metadata_["batch_size"];
double parallelism = metadata_["parallelism"];
double batch_size = GetParameterValue("batch_size");
double parallelism = GetParameterValue("parallelism");
int64 delta =
static_cast<int64>(static_cast<double>(NanosPerElementLocked()) /
(batch_size * parallelism));
......@@ -182,22 +165,41 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
return std::max(0LL,
output_time - input_times->at(input_times->size() - 2));
}
case Type::PARALLEL_INTERLEAVE:
case Type::PARALLEL_INTERLEAVE: {
// TODO(jsimsa): model the first input
if (inputs_.size() <= 1) {
return NanosPerElementLocked();
}
int64 delta = static_cast<double>(NanosPerElementLocked()) *
static_cast<double>(inputs_.size() - 1);
input_times->push_back(delta);
auto cleanup =
gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
int64 inputs_output_time = OutputTimeForInputs(input_times) -
inputs_.front()->OutputTime(input_times);
double parallelism = GetParameterValue("parallelism");
int64 output_time =
NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
static_cast<double>(inputs_.size() - 1)) /
parallelism);
return std::max(0LL,
output_time - input_times->at(input_times->size() - 2));
}
case Type::PARALLEL_INTERLEAVE_V2: {
// TODO(jsimsa): model the first input
if (inputs_.size() <= 1) {
return NanosPerElementLocked();
}
int64 delta =
static_cast<int64>(static_cast<double>(NanosPerElementLocked()) *
static_cast<double>(inputs_.size() - 1));
int64 delta = static_cast<double>(NanosPerElementLocked()) *
static_cast<double>(inputs_.size() - 1);
input_times->push_back(delta);
auto cleanup =
gtl::MakeCleanup([input_times]() { input_times->pop_back(); });
int64 inputs_output_time = OutputTimeForInputs(input_times) -
inputs_.front()->OutputTime(input_times);
double parallelism = std::min(port::NumSchedulableCPUs(),
static_cast<int>(metadata_["parallelism"]));
double parallelism =
std::min(static_cast<int>(GetParameterValue("cycle_length")),
static_cast<int>(GetParameterValue("parallelism")));
int64 output_time =
NanosPerElementLocked() + ((static_cast<double>(inputs_output_time) /
static_cast<double>(inputs_.size() - 1)) /
......@@ -206,8 +208,9 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
output_time - input_times->at(input_times->size() - 2));
}
case Type::PARALLEL_MAP: {
double parallelism = std::min(port::NumSchedulableCPUs(),
static_cast<int>(metadata_["parallelism"]));
double parallelism =
std::min(port::NumSchedulableCPUs(),
static_cast<int>(GetParameterValue("parallelism")));
int64 delta = static_cast<int64>(
static_cast<double>(NanosPerElementLocked()) / parallelism);
input_times->push_back(delta);
......@@ -248,23 +251,6 @@ int64 Node::OutputTimeLocked(std::vector<int64>* input_times) {
}
}
Model::Model(const proto::Model& model_proto) {
id_counter_ = model_proto.id_counter();
std::map<int64, std::shared_ptr<Node>> lookup_table;
for (auto node_proto : model_proto.node()) {
std::shared_ptr<Node> node(new Node(node_proto));
lookup_table[node_proto.id()] = node;
}
for (auto node_proto : model_proto.node()) {
std::shared_ptr<Node> node = lookup_table[node_proto.id()];
for (int64 id : node_proto.input()) {
node->add_input(lookup_table[id]);
}
node->set_output(lookup_table[node_proto.output()]);
}
output_ = lookup_table[model_proto.output()];
}
std::shared_ptr<Node> Model::AddNode(const string& name,
const string& output_name) {
mutex_lock l(mu_);
......@@ -294,94 +280,77 @@ std::shared_ptr<Node> Model::LookupNode(const string& name) {
return result;
}
void Model::Optimize() {
mutex_lock l(mu_);
int64 processing_time = ProcessingTime();
int64 num_cpus = port::NumSchedulableCPUs();
std::vector<Node::Knob> knobs = CollectKnobs();
// The optimization algorithm starts by setting all parallelism knobs to 1. It
// then repeatedly identifies the knob that, when turned up by 1, decreases
// the output time the most. This process is repeated until all knobs reach
// the number of schedulable CPUs or the projected output time is less than or
// equal to the processing time needed to produce an element divided by the
// number of schedulable CPUs.
for (auto& knob : knobs) {
LOG(INFO) << knob.node->name() << " " << knob.processing_time;
knob.value = 1;
knob.node->set_metadata("parallelism", knob.value);
}
while (true) {
int64 output_time = OutputTime();
bool all_knobs = true;
for (auto knob : knobs) {
if (knob.value < num_cpus) {
all_knobs = false;
// The optimization algorithm starts by setting all tunable parallelism
// parameters to 1. It then repeatedly identifies the parameter whose increase
// in parallelism decreases the output time the most. This process is
// repeated until all parameters reach their maximum values or the
// projected output time is less than or equal to the processing time needed to
// produce an element divided by the CPU budget.
void Model::Optimize(int64 cpu_budget) {
mutex_lock l(optimization_mu_);
std::vector<std::shared_ptr<Node::Tunable>> tunables;
{
mutex_lock l2(mu_);
const int64 processing_time = ProcessingTime();
tunables = CollectTunables();
for (auto tunable : tunables) {
tunable->value = 1;
}
while (true) {
const int64 output_time = OutputTime();
bool all_tunables = true;
for (auto& tunable : tunables) {
if (tunable->value < tunable->max) {
all_tunables = false;
break;
}
}
if (output_time < processing_time / cpu_budget || all_tunables) {
break;
}
}
if (output_time < processing_time / num_cpus || all_knobs) {
break;
}
int64 best_delta = -1;
int best_knob = -1;
for (int i = 0; i < knobs.size(); ++i) {
if (knobs[i].value == num_cpus) {
continue;
int64 best_delta = -1;
Node::Tunable* best_tunable = nullptr;
for (auto& tunable : tunables) {
if (tunable->value == tunable->max) {
continue;
}
tunable->value++;
int64 delta = output_time - OutputTime();
if (delta > best_delta) {
best_delta = delta;
best_tunable = tunable.get();
}
tunable->value--;
}
knobs[i].node->set_metadata("parallelism", knobs[i].value + 1);
int64 delta = output_time - OutputTime();
if (delta > best_delta) {
best_delta = delta;
best_knob = i;
if (!best_tunable) {
// NOTE: This can happen because we are performing the optimization
// while the model data is changing. If this becomes an issue, we should
// look into performing the optimization using a model snapshot.
break;
}
knobs[i].node->set_metadata("parallelism", knobs[i].value);
best_tunable->value++;
}
knobs[best_knob].value++;
knobs[best_knob].node->set_metadata("parallelism", knobs[best_knob].value);
}
for (auto knob : knobs) {
LOG(INFO) << knob.node->name() << " " << knob.value;
// The `set_fn` functions should be invoked without holding a lock to avoid a
// potential deadlock.
for (auto& tunable : tunables) {
tunable->set_fn(tunable->value);
}
LOG(INFO) << "output time: " << OutputTime();
LOG(INFO) << "processing time: " << ProcessingTime();
}
void Model::OutputToFile() {
proto::Model model_proto;
ToProto(&model_proto);
string filename;
Env::Default()->LocalTempFilename(&filename);
TF_CHECK_OK(WriteStringToFile(Env::Default(), filename,
model_proto.SerializeAsString()));
LOG(INFO) << filename;
}
void Model::RemoveNode(const string& prefix) {
mutex_lock l(mu_);
// Nodes are not allowed to be removed when optimization is in progress to
// prevent the optimization from trying to access an iterator that was
// concurrently deleted.
mutex_lock l(optimization_mu_);
mutex_lock l2(mu_);
lookup_table_.erase(prefix);
}
void Model::ToProto(proto::Model* model_proto) {
mutex_lock l(mu_);
model_proto->set_id_counter(id_counter_);
model_proto->set_output(output_->id());
AddNodeToProto(output_, model_proto);
}
// static
void Model::AddNodeToProto(const std::shared_ptr<Node>& node,
proto::Model* model_proto) {
proto::Node* node_proto = model_proto->add_node();
node->ToProto(node_proto);
for (const std::shared_ptr<Node>& input : node->inputs()) {
AddNodeToProto(input, model_proto);
}
}
std::vector<Node::Knob> Model::CollectKnobs() {
std::vector<Node::Knob> knobs;
output_->CollectKnobs(&knobs);
return knobs;
std::vector<std::shared_ptr<Node::Tunable>> Model::CollectTunables() {
std::vector<std::shared_ptr<Node::Tunable>> tunables;
output_->CollectTunables(&tunables);
return tunables;
}
int64 Model::OutputTime() {
......
......@@ -22,7 +22,6 @@ limitations under the License.
#include <utility>
#include <vector>
#include "tensorflow/core/framework/model.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
......@@ -61,13 +60,10 @@ class Node {
public:
Node(int64 id, std::shared_ptr<Node> output) : id_(id), output_(output) {}
explicit Node(const proto::Node& node_proto) : id_(node_proto.id()) {
name_ = node_proto.name();
type_ = TypeFromName(node_proto.name());
processing_time_ = node_proto.processing_time();
num_elements_ = node_proto.num_elements();
metadata_.insert(node_proto.metadata().begin(),
node_proto.metadata().end());
// Adds a constant parameter.
void add_constant_param(const string& name, int64 value) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
constant_params_[name] = value;
}
// Records that the node produced an element.
......@@ -88,6 +84,15 @@ class Node {
processing_time_ += delta;
}
// Adds a tunable parameter.
void add_tunable_param(const string& name, int64 value, int64 min, int64 max,
std::function<void(int64)>&& set_fn)
LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
tunable_params_[name] =
std::make_shared<Tunable>(value, min, max, std::move(set_fn));
}
// Returns the unique node ID.
int64 id() LOCKS_EXCLUDED(mu_) { return id_; }
......@@ -121,12 +126,6 @@ class Node {
inputs_.remove(input);
}
// Adds the given key-value pair to the node metadata.
void set_metadata(const string& key, int64 value) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
metadata_[key] = value;
}
// Sets the node name.
void set_name(const string& name) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
......@@ -157,11 +156,16 @@ class Node {
}
private:
// Represents a performance knob.
struct Knob {
Node* node;
int64 processing_time;
// Represents a tunable parameter.
struct Tunable {
Tunable(int64 value, int64 min, int64 max,
std::function<void(int64)> set_fn)
: value(value), min(min), max(max), set_fn(std::move(set_fn)) {}
int64 value;
int64 min;
int64 max;
std::function<void(int64)> set_fn;
};
enum class Type {
......@@ -186,8 +190,12 @@ class Node {
UNKNOWN,
};
// Collects performance knobs in the subtree rooted in this node.
void CollectKnobs(std::vector<Node::Knob>* knobs) LOCKS_EXCLUDED(mu_);
// Collects tunable parameters in the subtree rooted in this node.
void CollectTunables(std::vector<std::shared_ptr<Node::Tunable>>* tunables)
LOCKS_EXCLUDED(mu_);
// Gets a value of the given parameter (tunable or constant).
int64 GetParameterValue(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Returns the per-element processing time spent in this node.
int64 NanosPerElement() LOCKS_EXCLUDED(mu_) {
......@@ -238,22 +246,6 @@ class Node {
return sum;
}
// Serializes the node state into the given proto.
void ToProto(proto::Node* node_proto) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
node_proto->set_id(id_);
node_proto->set_name(name_);
node_proto->set_num_elements(num_elements_);
node_proto->set_processing_time(processing_time_);
for (const std::shared_ptr<Node>& input : inputs_) {
node_proto->add_input(input->id());
}
if (output_) {
node_proto->set_output(output_->id());
}
node_proto->mutable_metadata()->insert(metadata_.begin(), metadata_.end());
}
Type TypeFromName(const string& name) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
if (name_ == "Batch") {
return Type::BATCH;
......@@ -319,7 +311,9 @@ class Node {
int64 processing_time_ GUARDED_BY(mu_) = 0;
int64 num_elements_ GUARDED_BY(mu_) = 0;
std::map<std::thread::id, int64> work_start_ GUARDED_BY(mu_);
std::map<string, int64> metadata_ GUARDED_BY(mu_);
std::map<string, int64> constant_params_ GUARDED_BY(mu_);
// Tunables are shared with the model during optimization.
std::map<string, std::shared_ptr<Tunable>> tunable_params_ GUARDED_BY(mu_);
std::list<std::shared_ptr<Node>> inputs_ GUARDED_BY(mu_);
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
......@@ -330,21 +324,15 @@ class Node {
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
// create a performance model, which is in turn used to identify optimal values
// of performance knobs.
// of tunable parameters.
//
// Developers of tf.data transformations are not expected to interact with this
// class directly. Boiler plate code for creating the abstract representation of
// the input pipeline and collecting runtime information has been added to the
// implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
//
// TODO(jsimsa): Add a mechanism for feeding the result of the optimization
// into the input pipeline.
class Model {
public:
Model() = default;
explicit Model(const proto::Model& model_proto);
~Model() {}
// Returns the model output node.
std::shared_ptr<Node> output() LOCKS_EXCLUDED(mu_) {
......@@ -360,30 +348,25 @@ class Model {
std::shared_ptr<Node> LookupNode(const string& name) LOCKS_EXCLUDED(mu_);
// Runs optimization.
void Optimize() LOCKS_EXCLUDED(mu_);
// Outputs the state of a model to a file.
//
// TODO(jsimsa): Remove this method once the optimization loop is closed.
void OutputToFile() LOCKS_EXCLUDED(mu_);
void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
// Removes the node identified by the given name.
void RemoveNode(const string& prefix) LOCKS_EXCLUDED(mu_);
// Serializes the model state to the given proto.
void ToProto(proto::Model* model_proto) LOCKS_EXCLUDED(mu_);
private:
static void AddNodeToProto(const std::shared_ptr<Node>& node,
proto::Model* model_proto);
std::vector<Node::Knob> CollectKnobs() EXCLUSIVE_LOCKS_REQUIRED(mu_);
std::vector<std::shared_ptr<Node::Tunable>> CollectTunables()
EXCLUSIVE_LOCKS_REQUIRED(mu_);
int64 OutputTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
int64 ProcessingTime() EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Used for coordination between different input pipeline threads.
mutex mu_;
// Used for preventing iterator deletion when optimization is in progress
// because the optimization may try to update the values of tunable
// parameters.
mutex optimization_mu_ ACQUIRED_BEFORE(mu_);
int64 id_counter_ GUARDED_BY(mu_) = 1;
std::shared_ptr<Node> output_ GUARDED_BY(mu_);
std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
......
syntax = "proto3";
package tensorflow.data.model.proto;
option cc_enable_arenas = true;
message Model {
// Counter used for generating new node IDs.
int64 id_counter = 1;
// Nodes of this model.
repeated Node node = 2;
// The ID of the output node.
int64 output = 3;
};
message Node {
// The node ID.
int64 id = 1;
// The node name.
string name = 2;
// Input node IDs.
repeated int64 input = 3;
// Output node ID.
int64 output = 4;
// Number of elements produced by the node.
int64 num_elements = 5;
// The CPU time spent by running threads of this node.
int64 processing_time = 6;
// Key-value store for node metadata (e.g. batch size or parallelism).
map<string, int32> metadata = 7;
};
......@@ -117,7 +117,7 @@ class BatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
SetMetadata(ctx, "batch_size", dataset()->batch_size_);
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
......
......@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/tracing.h"
namespace tensorflow {
......@@ -39,7 +40,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
: UnaryDatasetOpKernel(ctx),
graph_def_version_(ctx->graph_def_version()),
op_version_(ctx->def().op() == "MapAndBatchDataset" ? 1 : 2) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("output_types", &output_types_));
......@@ -77,7 +77,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
case 2:
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
OP_REQUIRES(ctx, num_parallel_calls > 0,
OP_REQUIRES(ctx,
num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
break;
......@@ -190,7 +191,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params) {}
: DatasetIterator<Dataset>(params),
num_parallel_calls_(params.dataset->num_parallel_calls_) {}
~Iterator() override {
mutex_lock l(mu_);
......@@ -204,8 +206,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
SetMetadata(ctx, "batch_size", dataset()->batch_size_);
SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
mutex_lock l(mu_);
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
std::function<void(int64)> set_fn = [this](int64 value) {
{
mutex_lock l(mu_);
num_parallel_calls_ = value;
}
VLOG(2) << "setting parallelism knob to " << value;
cond_var_.notify_all();
};
AddTunableParameter(
ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
port::NumSchedulableCPUs() /* max */, std::move(set_fn));
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
......@@ -428,7 +446,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
int MaxBatchResults() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
return (dataset()->num_parallel_calls_ + dataset()->batch_size_ - 1) /
return (num_parallel_calls_ + dataset()->batch_size_ - 1) /
dataset()->batch_size_;
}
......@@ -480,15 +498,18 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
LOCKS_EXCLUDED(mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
new_calls.reserve(dataset()->num_parallel_calls_);
StartWork(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { StopWork(ctx.get()); });
{
tf_shared_lock l(mu_);
new_calls.reserve(num_parallel_calls_);
}
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ &&
(num_calls_ >= dataset()->num_parallel_calls_ ||
(num_calls_ >= num_parallel_calls_ ||
batch_results_.size() > MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ == 0))) {
......@@ -501,7 +522,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
return;
}
while (num_calls_ < dataset()->num_parallel_calls_ &&
while (num_calls_ < num_parallel_calls_ &&
(batch_results_.size() < MaxBatchResults() ||
(batch_results_.size() == MaxBatchResults() &&
call_counter_ % dataset()->batch_size_ != 0))) {
......@@ -648,6 +669,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// user specified level of parallelism and there are slots available in
// the `batch_results_` buffer.
condition_variable cond_var_;
// Identifies the maximum number of parallel calls.
int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Counts the number of outstanding calls for this batch.
int64 num_calls_ GUARDED_BY(mu_) = 0;
// Counts the total number of calls.
......@@ -671,7 +694,6 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
const Eigen::ThreadPoolDevice* device_; // not owned
};
const int graph_def_version_;
const int op_version_;
DataTypeVector output_types_;
std::vector<PartialTensorShape> output_shapes_;
......
......@@ -17,11 +17,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/dataset.h"
#include "tensorflow/core/lib/random/random.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
namespace {
const int kOptimizationPeriodThresholdMs = 60 * EnvTime::kSecondsToMicros;
class ModelDatasetOp : public UnaryDatasetOpKernel {
public:
explicit ModelDatasetOp(OpKernelConstruction* ctx)
......@@ -71,9 +74,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
class Iterator : public DatasetIterator<Dataset> {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params), model_(new model::Model()) {}
~Iterator() override { model_->OutputToFile(); }
: DatasetIterator<Dataset>(params),
model_(std::make_shared<model::Model>()) {}
Status Initialize(IteratorContext* ctx) override {
IteratorContext ctx_with_model(CreateParams(ctx));
......@@ -85,6 +87,21 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
int64 now = ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
if (last_optimization_ms_ + optimization_period_ms_ < now) {
model_->Optimize(port::NumSchedulableCPUs());
// Exponentially increase the period of running the optimization until
// a threshold is reached.
if (optimization_period_ms_ < kOptimizationPeriodThresholdMs) {
if (optimization_period_ms_ << 1 < kOptimizationPeriodThresholdMs) {
optimization_period_ms_ <<= 1;
} else {
optimization_period_ms_ = kOptimizationPeriodThresholdMs;
}
}
last_optimization_ms_ =
ctx->env()->NowMicros() / EnvTime::kMillisToMicros;
}
IteratorContext ctx_with_model(CreateParams(ctx));
return input_impl_->GetNext(&ctx_with_model, out_tensors,
end_of_sequence);
......@@ -113,6 +130,8 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
private:
mutex mu_;
std::shared_ptr<model::Model> model_;
int64 last_optimization_ms_ GUARDED_BY(mu_) = 0;
int64 optimization_period_ms_ GUARDED_BY(mu_) = 10;
std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
......
......@@ -207,7 +207,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
SetMetadata(ctx, "batch_size", dataset()->batch_size_);
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
return dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_);
}
......
......@@ -252,7 +252,7 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
SetMetadata(ctx, "parallelism", dataset()->cycle_length_);
AddConstantParameter(ctx, "parallelism", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
......@@ -1120,7 +1120,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
int64 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
OP_REQUIRES(ctx, num_parallel_calls > 0,
OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
OP_REQUIRES(
......@@ -1233,6 +1233,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
num_parallel_calls_(params.dataset->num_parallel_calls_),
thread_pool_(new thread::ThreadPool(
Env::Default(), ThreadOptions(), "parallel_interleave",
dataset()->cycle_length_ /* num_threads */,
......@@ -1250,7 +1251,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status Initialize(IteratorContext* ctx) override {
SetMetadata(ctx, "parallelism", dataset()->num_parallel_calls_);
mutex_lock l(mu_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
auto set_fn = [this](int64 value) {
{
mutex_lock l(mu_);
num_parallel_calls_ = value;
}
VLOG(2) << "setting parallelism knob to " << value;
cond_var_.notify_all();
};
AddTunableParameter(
ctx, "parallelism", num_parallel_calls_ /* value */, 1 /* min */,
dataset()->cycle_length_ /* max */, std::move(set_fn));
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
return dataset()->captured_func_->Instantiate(ctx);
......@@ -1459,7 +1477,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) &&
(element_in_use_[cycle_index_] ||
num_calls_ >= dataset()->num_parallel_calls_ ||
num_calls_ >= num_parallel_calls_ ||
invocation_results_.size() >= MaxInvocationResults())) {
StopWork(ctx.get());
cond_var_.wait(l);
......@@ -1472,7 +1490,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
while (!element_in_use_[cycle_index_] &&
(!end_of_input_ || num_open_ > 0) &&
num_calls_ < dataset()->num_parallel_calls_ &&
num_calls_ < num_parallel_calls_ &&
invocation_results_.size() < MaxInvocationResults()) {
if (!current_elements_[cycle_index_]) {
// Try to create a new iterator from the next input element.
......@@ -1647,6 +1665,9 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Identifies the number of open iterators.
int64 num_open_ GUARDED_BY(mu_) = 0;
// Identifies the maximum number of parallel calls.
int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Identifies the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
......
......@@ -55,7 +55,7 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
int32 num_parallel_calls;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
&num_parallel_calls));
OP_REQUIRES(ctx, num_parallel_calls > 0,
OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
errors::InvalidArgument(
"num_parallel_calls must be greater than zero."));
......
......@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/platform/cpu_info.h"
namespace tensorflow {
namespace data {
......@@ -55,7 +56,25 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status Initialize(IteratorContext* ctx) override {
SetMetadata(ctx, "parallelism", num_parallel_calls_);
mutex_lock l(mu_);
if (num_parallel_calls_ == kAutoTune) {
num_parallel_calls_ = 1;
auto set_fn = [this](int64 value) {
{
mutex_lock l(mu_);
num_parallel_calls_ = value;
}
VLOG(2) << "setting parallelism knob to " << value;
cond_var_.notify_all();
};
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
AddTunableParameter(ctx, "parallelism", num_parallel_calls_ /* value */,
1 /* min */, port::NumSchedulableCPUs() /* max */,
std::move(set_fn));
} else {
AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
if (init_func_) {
......@@ -211,8 +230,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
std::move(done));
}
int64 MaxInvocationResults() { return num_parallel_calls_; }
Status ProcessResult(const std::shared_ptr<InvocationResult>& result,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) {
......@@ -235,13 +252,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
StartWork(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { StopWork(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
new_calls.reserve(num_parallel_calls_);
{
tf_shared_lock l(mu_);
new_calls.reserve(num_parallel_calls_);
}
while (true) {
{
mutex_lock l(mu_);
while (!cancelled_ &&
(num_calls_ >= num_parallel_calls_ ||
invocation_results_.size() >= MaxInvocationResults())) {
invocation_results_.size() >= num_parallel_calls_)) {
StopWork(ctx.get());
cond_var_.wait(l);
StartWork(ctx.get());
......@@ -250,7 +270,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
return;
}
while (num_calls_ < num_parallel_calls_ &&
invocation_results_.size() < MaxInvocationResults()) {
invocation_results_.size() < num_parallel_calls_) {
invocation_results_.emplace_back(new InvocationResult());
new_calls.push_back(invocation_results_.back());
num_calls_++;
......@@ -305,7 +325,6 @@ class ParallelMapIterator : public DatasetBaseIterator {
const DatasetBase* const input_dataset_; // Not owned.
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
const int32 num_parallel_calls_;
// Used for coordination between the main thread and the runner thread.
mutex mu_;
// Used for coordination between the main thread and the runner thread. In
......@@ -314,6 +333,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
// parallelism and there are slots available in the `invocation_results_`
// buffer.
condition_variable cond_var_;
// Identifies the maximum number of parallel calls.
int64 num_parallel_calls_ GUARDED_BY(mu_) = 0;
// Counts the number of outstanding calls.
int64 num_calls_ GUARDED_BY(mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
......