Commit fb83aa5d authored by Jiri Simsa, committed by TensorFlower Gardener

[tf.data] Reduce the overhead of performance modeling when there are no autotunable knobs.

PiperOrigin-RevId: 225405978
Parent 7b4bfd90
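The commit gates all of tf.data's performance-model bookkeeping behind a single model-wide flag that flips on when the first tunable parameter is registered and, by design, never flips off. A minimal standalone sketch of that pattern, using hypothetical names that are not the TensorFlow API:

```cpp
#include <atomic>
#include <cstdint>
#include <iostream>

// Hypothetical stand-in for the model: bookkeeping is skipped entirely
// until the first tunable knob is registered, and never disabled again.
class UsageModel {
 public:
  void RegisterParameter(bool tunable) {
    // Once any tunable parameter appears, start collecting and never stop.
    if (tunable) collect_resource_usage_ = true;
  }

  void RecordStart(int64_t now_nanos) {
    if (!collect_resource_usage_) return;  // cheap early-out on the hot path
    start_nanos_ = now_nanos;
  }

  void RecordStop(int64_t now_nanos) {
    if (!collect_resource_usage_) return;
    busy_nanos_ += now_nanos - start_nanos_;
  }

  int64_t busy_nanos() const { return busy_nanos_; }

 private:
  std::atomic<bool> collect_resource_usage_{false};
  int64_t start_nanos_ = 0;
  int64_t busy_nanos_ = 0;
};

int main() {
  UsageModel model;
  model.RecordStart(100);  // ignored: no tunable knobs registered yet
  model.RecordStop(200);
  model.RegisterParameter(/*tunable=*/true);
  model.RecordStart(300);
  model.RecordStop(450);
  std::cout << model.busy_nanos() << "\n";  // prints 150
  return 0;
}
```

The payoff is that a pipeline with no autotunable knobs pays one flag check per record call instead of a clock read.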
@@ -50,8 +50,6 @@ class GraphDefBuilder;
 class Node;
 namespace data {
-// A constant that can be used to enable auto-tuning.
-constexpr int kAutoTune = -1;
 constexpr int kInfiniteCardinality = -1;
 constexpr int kUnknownCardinality = -2;
@@ -723,36 +721,36 @@ class DatasetBaseIterator : public IteratorBase {
     return model::MakeUnknownNode(std::move(args));
   }
-  // When performance modeling is enabled, this method records the fact that
-  // this iterator has dequeued a element from an internal buffer.
+  // When modeling is enabled, this method records the fact that this iterator
+  // has dequeued an element from an internal buffer.
   void RecordBufferDequeue(IteratorContext* ctx,
                            const std::vector<Tensor>& element) {
-    if (node_) {
+    if (collect_resource_usage(ctx)) {
       node_->add_buffered_bytes(-GetAllocatedBytes(element));
     }
   }
-  // When performance modeling is enabled, this method records the fact that
-  // this iterator has enqueued a element in an internal buffer.
+  // When modeling is enabled, this method records the fact that this iterator
+  // has enqueued an element in an internal buffer.
   void RecordBufferEnqueue(IteratorContext* ctx,
                            const std::vector<Tensor>& element) {
-    if (node_) {
+    if (collect_resource_usage(ctx)) {
       node_->add_buffered_bytes(GetAllocatedBytes(element));
     }
   }
-  // When performance modeling is enabled, this method records the fact that
-  // this iterator has produced an element.
+  // When modeling is enabled, this method records the fact that this iterator
+  // has produced an element.
   void RecordElement(IteratorContext* ctx) {
     if (node_) {
       node_->record_element();
     }
   }
-  // When performance modeling is enabled, this method records the fact that
-  // a thread of this iterator has started work.
+  // When modeling is enabled, this method records the fact that a thread of
+  // this iterator has started work.
   void RecordStart(IteratorContext* ctx, bool stop_output = false) {
-    if (node_) {
+    if (collect_resource_usage(ctx)) {
       int64 now_nanos = Env::Default()->NowNanos();
       if (stop_output && node_->output()) {
         node_->output()->record_stop(now_nanos);
@@ -761,10 +759,10 @@ class DatasetBaseIterator : public IteratorBase {
     }
   }
-  // When performance modeling is enabled, this method records the fact that
-  // a thread of this iterator has stopped work.
+  // When modeling is enabled, this method records the fact that a thread of
+  // this iterator has stopped work.
   void RecordStop(IteratorContext* ctx, bool start_output = false) {
-    if (node_) {
+    if (collect_resource_usage(ctx)) {
       int64 now_nanos = Env::Default()->NowNanos();
       node_->record_stop(now_nanos);
       if (start_output && node_->output()) {
@@ -774,6 +772,11 @@ class DatasetBaseIterator : public IteratorBase {
   }
  private:
+  inline bool collect_resource_usage(IteratorContext* ctx) {
+    auto model = ctx->model();
+    return model && model->collect_resource_usage() && node_;
+  }
   BaseParams params_;
 };
......
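The `collect_resource_usage(ctx)` helper introduced above relies on `&&` short-circuiting: the model pointer is tested first, then the model-wide flag, and only then the per-node state, so iterators in an untuned pipeline skip the `NowNanos()` call entirely. A hedged sketch of the same shape with stand-in types (all names here are hypothetical, not the TensorFlow API):

```cpp
#include <atomic>
#include <iostream>

// Stand-ins for Model, Node, and IteratorContext; names are illustrative.
struct Model {
  std::atomic<bool> collect_resource_usage{false};
};

struct Node {};

struct Context {
  Model* model = nullptr;  // may be null when modeling is disabled
};

// Mirrors the shape of the helper above: each && short-circuits, so the
// cheap checks run first and the per-node work is skipped for untuned
// pipelines.
bool ShouldCollect(const Context& ctx, const Node* node) {
  return ctx.model != nullptr && ctx.model->collect_resource_usage &&
         node != nullptr;
}

int main() {
  Model model;
  Node node;
  Context ctx{&model};
  std::cout << ShouldCollect(ctx, &node) << "\n";  // 0: flag still off
  model.collect_resource_usage = true;
  std::cout << ShouldCollect(ctx, &node) << "\n";  // 1: now collecting
  return 0;
}
```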
@@ -356,6 +356,8 @@ std::shared_ptr<Node> Model::AddNode(Node::Factory factory, const string& name,
   if (output) {
     output->add_input(node);
   }
+  collect_resource_usage_ =
+      collect_resource_usage_ || node->has_tunable_parameters();
   lookup_table_.insert(std::make_pair(name, node));
   return node;
 }
@@ -441,7 +443,7 @@ void Model::RecordElement(const string& name) {
 void Model::RecordStart(const string& name, bool stop_output) {
   tf_shared_lock l(mu_);
   auto node = gtl::FindOrNull(lookup_table_, name);
-  if (node) {
+  if (collect_resource_usage_ && node) {
     int64 now_nanos = Env::Default()->NowNanos();
     if (stop_output && (*node)->output()) {
       (*node)->output()->record_stop(now_nanos);
@@ -453,7 +455,7 @@ void Model::RecordStart(const string& name, bool stop_output) {
 void Model::RecordStop(const string& name, bool start_output) {
   tf_shared_lock l(mu_);
   auto node = gtl::FindOrNull(lookup_table_, name);
-  if (node) {
+  if (collect_resource_usage_ && node) {
     int64 now_nanos = Env::Default()->NowNanos();
     (*node)->record_stop(now_nanos);
     if (start_output && (*node)->output()) {
......
@@ -34,18 +34,24 @@ namespace tensorflow {
 namespace data {
 namespace model {
+// A constant that can be used to enable auto-tuning.
+constexpr int kAutoTune = -1;
 // Represents thread-safe state that can be shared between an input pipeline and
 // the performance model.
 struct SharedState {
  public:
   SharedState(int64 value, std::shared_ptr<mutex> mu,
               std::shared_ptr<condition_variable> cond_var)
-      : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}
+      : value(value),
+        mu(std::move(mu)),
+        cond_var(std::move(cond_var)),
+        tunable(value == kAutoTune) {}
   int64 value;
   std::shared_ptr<mutex> mu;
   std::shared_ptr<condition_variable> cond_var;
-  bool tunable = false;
+  const bool tunable;
 };
 // Represents a parameter.
@@ -136,6 +142,15 @@ class Node {
     return buffered_bytes_;
   }
+  // Indicates whether the node has tunable parameters.
+  bool has_tunable_parameters() const LOCKS_EXCLUDED(mu_) {
+    tf_shared_lock l(mu_);
+    for (const auto& pair : parameters_) {
+      if (pair.second->state->tunable) return true;
+    }
+    return false;
+  }
   // Returns the unique node ID.
   int64 id() const LOCKS_EXCLUDED(mu_) { return id_; }
@@ -344,7 +359,10 @@ std::shared_ptr<Node> MakeUnknownNode(Node::Args args);
 // implementation of `DatasetBase` and `DatasetBaseIterator` respectively.
 class Model {
  public:
-  Model() = default;
+  Model() : collect_resource_usage_(false) {}
+
+  // Indicates whether to collect resource usage.
+  bool collect_resource_usage() const { return collect_resource_usage_; }
   // Adds a node with the given name and given output.
   std::shared_ptr<Node> AddNode(Node::Factory factory, const string& name,
@@ -388,6 +406,14 @@ class Model {
   int64 id_counter_ GUARDED_BY(mu_) = 1;
   std::shared_ptr<Node> output_ GUARDED_BY(mu_);
   std::map<string, std::shared_ptr<Node>> lookup_table_ GUARDED_BY(mu_);
+  // Indicates whether the modeling framework should collect resource usage
+  // (e.g. CPU, memory). The logic for collecting this information assumes that
+  // the collection is not repeatedly disabled and enabled. As a consequence,
+  // the implementation starts collecting resource usage when it encounters a
+  // tunable parameter (because the information is used for tuning the value
+  // of the parameter) and never stops.
+  std::atomic<bool> collect_resource_usage_;
 };
 }  // namespace model
......
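With `tunable` now `const` and derived from the `kAutoTune` sentinel at construction time, ops can no longer toggle tunability after the fact, which is what lets `Model::AddNode` decide once and for all whether to collect resource usage. A compilable sketch of that derivation with simplified types (the mutex and condition-variable members are omitted; names are illustrative only):

```cpp
#include <cstdint>
#include <iostream>

constexpr int64_t kAutoTune = -1;  // sentinel requesting auto-tuning

// Simplified SharedState: tunability is fixed at construction time by
// checking the sentinel, so callers cannot flip it afterwards.
struct SharedState {
  explicit SharedState(int64_t value)
      : value(value), tunable(value == kAutoTune) {}

  int64_t value;
  const bool tunable;  // const: decided once, never mutated
};

int main() {
  SharedState fixed(8);              // user asked for exactly 8 parallel calls
  SharedState autotuned(kAutoTune);  // user asked the runtime to decide
  std::cout << fixed.tunable << " " << autotuned.tunable << "\n";  // 0 1
  return 0;
}
```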
@@ -71,9 +71,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     int64 num_parallel_calls;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                             &num_parallel_calls));
-    OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
-                errors::InvalidArgument(
-                    "num_parallel_calls must be greater than zero."));
+    OP_REQUIRES(
+        ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutoTune,
+        errors::InvalidArgument(
+            "num_parallel_calls must be greater than zero."));
     bool drop_remainder;
     OP_REQUIRES_OK(ctx,
@@ -268,9 +269,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
-    if (num_parallel_calls_->value == kAutoTune) {
+    if (num_parallel_calls_->value == model::kAutoTune) {
       num_parallel_calls_->value = ctx->runner_threadpool_size();
-      num_parallel_calls_->tunable = true;
     }
     TF_RETURN_IF_ERROR(
         dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
......
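The kernel above, and the NUMA, parallel-interleave, and parallel-map kernels below, all follow the same two-step sentinel protocol: validate that `num_parallel_calls` is positive or exactly `model::kAutoTune`, then have `Initialize` replace the sentinel with a concrete starting value (the threadpool size, or the cycle length for interleave). A hedged sketch of that protocol outside TensorFlow (the names and the `ThreadpoolSize` helper are hypothetical):

```cpp
#include <cstdint>
#include <iostream>
#include <stdexcept>

constexpr int64_t kAutoTune = -1;

// Step 1: mirror the OP_REQUIRES validation -- only positive values or the
// kAutoTune sentinel are legal for num_parallel_calls.
void Validate(int64_t num_parallel_calls) {
  if (!(num_parallel_calls > 0 || num_parallel_calls == kAutoTune)) {
    throw std::invalid_argument(
        "num_parallel_calls must be greater than zero.");
  }
}

// Stand-in for ctx->runner_threadpool_size().
int64_t ThreadpoolSize() { return 8; }

// Step 2: mirror Initialize() -- resolve the sentinel to a concrete starting
// value; autotuning can later adjust it because SharedState::tunable is true.
int64_t Resolve(int64_t num_parallel_calls) {
  return num_parallel_calls == kAutoTune ? ThreadpoolSize()
                                         : num_parallel_calls;
}

int main() {
  Validate(kAutoTune);
  std::cout << Resolve(kAutoTune) << "\n";  // 8: starts at threadpool size
  Validate(4);
  std::cout << Resolve(4) << "\n";          // 4: fixed by the user
  return 0;
}
```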
@@ -76,9 +76,10 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
     int64 num_parallel_calls;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                             &num_parallel_calls));
-    OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
-                errors::InvalidArgument(
-                    "num_parallel_calls must be greater than zero."));
+    OP_REQUIRES(
+        ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutoTune,
+        errors::InvalidArgument(
+            "num_parallel_calls must be greater than zero."));
     bool drop_remainder;
     OP_REQUIRES_OK(ctx,
@@ -214,9 +215,8 @@ class NumaMapAndBatchDatasetOp : public UnaryDatasetOpKernel {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
-    if (num_parallel_calls_->value == kAutoTune) {
+    if (num_parallel_calls_->value == model::kAutoTune) {
      num_parallel_calls_->value = ctx->runner_threadpool_size();
-      num_parallel_calls_->tunable = true;
     }
     TF_RETURN_IF_ERROR(
         dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
......
@@ -76,9 +76,10 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
     int64 num_parallel_calls;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                             &num_parallel_calls));
-    OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
-                errors::InvalidArgument(
-                    "num_parallel_calls must be greater than zero."));
+    OP_REQUIRES(
+        ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutoTune,
+        errors::InvalidArgument(
+            "num_parallel_calls must be greater than zero."));
     OP_REQUIRES(
         ctx, num_parallel_calls <= cycle_length,
         errors::InvalidArgument(
@@ -220,9 +221,8 @@ class ParallelInterleaveDatasetOp : public UnaryDatasetOpKernel {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
-    if (num_parallel_calls_->value == kAutoTune) {
+    if (num_parallel_calls_->value == model::kAutoTune) {
       num_parallel_calls_->value = dataset()->cycle_length_;
-      num_parallel_calls_->tunable = true;
     }
     TF_RETURN_IF_ERROR(
         dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
......
@@ -51,9 +51,10 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
     int32 num_parallel_calls;
     OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "num_parallel_calls",
                                             &num_parallel_calls));
-    OP_REQUIRES(ctx, num_parallel_calls > 0 || num_parallel_calls == kAutoTune,
-                errors::InvalidArgument(
-                    "num_parallel_calls must be greater than zero."));
+    OP_REQUIRES(
+        ctx, num_parallel_calls > 0 || num_parallel_calls == model::kAutoTune,
+        errors::InvalidArgument(
+            "num_parallel_calls must be greater than zero."));
     std::unique_ptr<CapturedFunction> captured_func;
     OP_REQUIRES_OK(ctx, CapturedFunction::Create(func_, ctx, "other_arguments",
......
@@ -76,9 +76,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
   Status Initialize(IteratorContext* ctx) override {
     mutex_lock l(*mu_);
-    if (num_parallel_calls_->value == kAutoTune) {
+    if (num_parallel_calls_->value == model::kAutoTune) {
       num_parallel_calls_->value = ctx->runner_threadpool_size();
-      num_parallel_calls_->tunable = true;
     }
     TF_RETURN_IF_ERROR(
         input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
......