提交 4e69d936 编写于 作者: D Derek Murray 提交者: TensorFlower Gardener

Rolling forward "[tf.data] Reduce locking in node processing-time calculations."

The previous version had a non-deterministic use-after-free error in "parallel_interleave_dataset_op.cc", which TSAN testing picked up. The buggy version decremented a `BlockingCounter` before calling `RecordStop()`, but the last decrement call would unblock a thread that could lead to the iterator and context being deleted before `RecordStop()` would be called. The fix is to ensure that the `BlockingCounter` is always decremented after the call to `RecordStop()`.

PiperOrigin-RevId: 306689496
Change-Id: Ic85cf7b79a96a9d3f25ca4f1043e2c82505fcca0
上级 d9061ae1
......@@ -668,6 +668,11 @@ class IteratorBase {
virtual Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) = 0;
// Returns a pointer to the node representing this iterator in the performance
// model. It may be null, if performance modeling is not enabled for this
// iterator.
std::shared_ptr<model::Node> model_node() const { return node_; }
// Returns the number of elements produced by this iterator.
int64 num_elements() const {
if (node_) return node_->num_elements();
......@@ -684,7 +689,7 @@ class IteratorBase {
const string& output_prefix);
std::vector<std::function<void()>> cleanup_fns_;
model::Node* node_ = nullptr; // Not owned.
std::shared_ptr<model::Node> node_ = nullptr;
const IteratorBase* parent_ = nullptr; // Not owned.
int64 id_ = 0;
int64 parent_id_ = 0;
......
......@@ -696,7 +696,8 @@ string Node::DebugString() const {
"\n");
strings::StrAppend(&result, " bytes_produced=", bytes_produced_.load(),
"\n");
strings::StrAppend(&result, " processing_time=", processing_time_, "\n");
strings::StrAppend(&result, " processing_time=", processing_time_.load(),
"\n");
strings::StrAppend(&result, " num_elements=", num_elements_.load(), "\n");
string inputs;
for (auto& input : inputs_) {
......@@ -735,9 +736,9 @@ std::shared_ptr<Node> Node::Snapshot(std::shared_ptr<Node> output) {
result->bytes_produced_.store(bytes_produced_);
result->num_elements_.store(num_elements_);
result->record_metrics_.store(false);
result->processing_time_.store(processing_time_);
mutex_lock l2(result->mu_);
result->parameters_ = parameters_;
result->processing_time_ = processing_time_;
}
for (auto& input : inputs_) {
result->add_input(input->Snapshot(result));
......@@ -862,7 +863,8 @@ double Node::SelfProcessingTimeLocked() const {
}
void Model::AddNode(Node::Factory factory, const string& name,
const string& output_name, Node** out_node) {
const string& output_name,
std::shared_ptr<Node>* out_node) {
// The name captures the sequence of iterators joined by `::`. We use the full
// sequence as the key in the lookup table, but only the last element of the
// sequence as the name node.
......@@ -894,15 +896,7 @@ void Model::AddNode(Node::Factory factory, const string& name,
collect_resource_usage_ =
collect_resource_usage_ || node->has_tunable_parameters();
lookup_table_.insert(std::make_pair(name, node));
*out_node = node.get();
}
void Model::AddProcessingTime(const string& name, int64 delta) {
tf_shared_lock l(mu_);
auto node = gtl::FindOrNull(lookup_table_, name);
if (node) {
(*node)->add_processing_time(delta);
}
*out_node = node;
}
void Model::FlushMetrics() {
......@@ -912,15 +906,6 @@ void Model::FlushMetrics() {
}
}
int64 Model::NumElements(const string& name) {
tf_shared_lock l(mu_);
auto node = gtl::FindOrNull(lookup_table_, name);
if (node) {
return (*node)->num_elements();
}
return 0;
}
void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
int64 ram_budget) {
switch (algorithm) {
......@@ -933,30 +918,6 @@ void Model::Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget,
}
}
void Model::RecordStart(const string& name, bool stop_output) {
tf_shared_lock l(mu_);
auto node = gtl::FindOrNull(lookup_table_, name);
if (collect_resource_usage_ && node) {
int64 now_nanos = absl::GetCurrentTimeNanos();
if (stop_output && (*node)->output()) {
(*node)->output()->record_stop(now_nanos);
}
(*node)->record_start(now_nanos);
}
}
void Model::RecordStop(const string& name, bool start_output) {
tf_shared_lock l(mu_);
auto node = gtl::FindOrNull(lookup_table_, name);
if (collect_resource_usage_ && node) {
int64 now_nanos = absl::GetCurrentTimeNanos();
(*node)->record_stop(now_nanos);
if (start_output && (*node)->output()) {
(*node)->output()->record_start(now_nanos);
}
}
}
void Model::RemoveNode(const string& name) {
mutex_lock l(mu_);
auto node = gtl::FindOrNull(lookup_table_, name);
......
......@@ -133,6 +133,7 @@ class Node {
bytes_consumed_(0),
bytes_produced_(0),
num_elements_(0),
processing_time_(0),
record_metrics_(true),
metrics_(name_),
output_(args.output.get()) {}
......@@ -147,7 +148,6 @@ class Node {
// Increments the aggregate processing time by the given delta.
void add_processing_time(int64 delta) TF_LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
processing_time_ += delta;
}
......@@ -210,7 +210,6 @@ class Node {
// Returns the aggregate processing time.
int64 processing_time() const TF_LOCKS_EXCLUDED(mu_) {
tf_shared_lock l(mu_);
return processing_time_;
}
......@@ -418,10 +417,10 @@ class Node {
std::atomic<int64> bytes_consumed_;
std::atomic<int64> bytes_produced_;
std::atomic<int64> num_elements_;
std::atomic<int64> processing_time_;
std::atomic<bool> record_metrics_;
Metrics metrics_;
std::map<string, std::shared_ptr<Parameter>> parameters_ TF_GUARDED_BY(mu_);
int64 processing_time_ TF_GUARDED_BY(mu_) = 0;
std::map<std::thread::id, int64> work_start_ TF_GUARDED_BY(mu_);
// Statistic of inputs processing time history.
......@@ -491,31 +490,16 @@ class Model {
// Adds a node with the given name and given output. The method returns
// a pointer to the node but does not transfer ownership.
void AddNode(Node::Factory factory, const string& name,
const string& output_name, Node** out_node)
TF_LOCKS_EXCLUDED(mu_);
// Increments the processing time for the given node..
void AddProcessingTime(const string& name, int64 delta)
const string& output_name, std::shared_ptr<Node>* out_node)
TF_LOCKS_EXCLUDED(mu_);
// Flushes metrics record by the model.
void FlushMetrics() TF_LOCKS_EXCLUDED(mu_);
// Returns the number of elements that the input pipeline has produced.
int64 NumElements(const string& name) TF_LOCKS_EXCLUDED(mu_);
// Uses the given algorithm to perform the autotuning optimization.
void Optimize(AutotuneAlgorithm algorithm, int64 cpu_budget, int64 ram_budget)
TF_LOCKS_EXCLUDED(mu_);
// Records that the given node has started work. If `stop_output` is set, it
// also records that the output of the given node has stopped work.
void RecordStart(const string& name, bool stop_output) TF_LOCKS_EXCLUDED(mu_);
// Records that the given node has stopped work. If `stop_output` is set, it
// also records that the output of the given node has started work.
void RecordStop(const string& name, bool start_output) TF_LOCKS_EXCLUDED(mu_);
// Removes the given node.
void RemoveNode(const string& name) TF_LOCKS_EXCLUDED(mu_);
......
......@@ -759,7 +759,8 @@ Status InstantiatedCapturedFunction::RunInstantiated(
void InstantiatedCapturedFunction::RunAsync(
IteratorContext* ctx, std::vector<Tensor>&& args, std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done, const string& prefix) const {
FunctionLibraryRuntime::DoneCallback done,
const std::shared_ptr<model::Node>& node) const {
auto& info = captured_func_->short_circuit_info();
if (!info.indices.empty()) {
// Run the `done` callback on a threadpool thread, because it will
......@@ -792,18 +793,21 @@ void InstantiatedCapturedFunction::RunAsync(
f_opts.cancellation_manager = cancellation_manager.get();
std::shared_ptr<SimpleStepStatsCollector> stats_collector;
if (ctx->model() || ctx->stats_aggregator()) {
stats_collector = absl::make_unique<SimpleStepStatsCollector>();
if (node || ctx->stats_aggregator()) {
stats_collector = std::make_shared<SimpleStepStatsCollector>();
}
const bool collect_usage =
node && ctx->model() && ctx->model()->collect_resource_usage();
f_opts.stats_collector = stats_collector.get();
// Transfer ownership of the cancellation manager to `callback`.
CancellationManager* raw_cancellation_manager =
cancellation_manager.release();
auto callback = std::bind(
[this, rets, step_container, raw_cancellation_manager, frame](
[this, rets, step_container, raw_cancellation_manager, frame, node,
collect_usage](
const FunctionLibraryRuntime::DoneCallback& done,
IteratorContext* ctx, const string& prefix,
IteratorContext* ctx,
const std::shared_ptr<SimpleStepStatsCollector>& stats_collector,
// Begin unbound arguments.
Status s) {
......@@ -813,32 +817,30 @@ void InstantiatedCapturedFunction::RunAsync(
s = frame->ConsumeRetvals(rets);
}
delete frame;
if (ctx->model()) {
if (node) {
// TODO(b/129085499) Utilize the `node_name` which would be unique
// than the prefix for the function execution time statistics.
// prefix_with_func_name would then be node_name + func_name.
if (ctx->stats_aggregator()) {
string prefix_end =
str_util::Split(prefix, "::", str_util::SkipEmpty()).back();
string prefix_with_func_name =
strings::StrCat(prefix_end, stats_utils::kDelimiter,
strings::StrCat(node->name(), stats_utils::kDelimiter,
captured_func_->func().name());
ctx->stats_aggregator()->AddToHistogram(
stats_utils::ExecutionTimeHistogramName(prefix_with_func_name),
{static_cast<float>(stats_collector->processing_time())},
ctx->model()->NumElements(prefix));
node->num_elements());
}
ctx->model()->AddProcessingTime(prefix,
stats_collector->processing_time());
ctx->model()->RecordStart(prefix, false /* stop_output */);
node->add_processing_time(stats_collector->processing_time());
}
if (collect_usage) {
node->record_start(EnvTime::NowNanos());
}
done(s);
if (ctx->model()) {
ctx->model()->RecordStop(prefix, false /* start_output */);
if (collect_usage) {
node->record_stop(EnvTime::NowNanos());
}
},
std::move(done), ctx, prefix, std::move(stats_collector),
std::placeholders::_1);
std::move(done), ctx, std::move(stats_collector), std::placeholders::_1);
profiler::TraceMe activity(
[&] {
......@@ -846,7 +848,12 @@ void InstantiatedCapturedFunction::RunAsync(
"InstantiatedCapturedFunction::RunAsync#id=", f_opts.step_id, "#");
},
profiler::TraceMeLevel::kInfo);
// Stop the usage collection before calling `Run()` because `callback` may
// be executed synchronously, and so the `node->record_start()` call within
// `callback` would violate nesting.
if (collect_usage) node->record_stop(EnvTime::NowNanos());
lib_->Run(f_opts, f_handle_, frame, std::move(callback));
if (collect_usage) node->record_start(EnvTime::NowNanos());
}
bool InstantiatedCapturedFunction::ShouldCreateRendezvous() const {
......
......@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/cancellation.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
......@@ -95,7 +96,7 @@ class InstantiatedCapturedFunction {
void RunAsync(IteratorContext* ctx, std::vector<Tensor>&& args,
std::vector<Tensor>* rets,
FunctionLibraryRuntime::DoneCallback done,
const string& prefix) const;
const std::shared_ptr<model::Node>& node) const;
private:
InstantiatedCapturedFunction(
......
......@@ -319,6 +319,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
}
void RunnerThread(IteratorContext* ctx, InvocationResult* result, int i) {
RecordStart(ctx);
auto cleanup = gtl::MakeCleanup([this, ctx]() { RecordStop(ctx); });
int64 start = EnvTime::NowNanos();
Status s = input_impls_[i]->GetNext(ctx, &result->out_tensors,
&result->end_of_sequence);
......
......@@ -439,7 +439,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
// `return_values`, and invoking `done` when finished.
instantiated_captured_func_->RunAsync(ctx.get(), std::move(input_element),
return_values.get(),
std::move(done), prefix());
std::move(done), model_node());
}
void CancelThreads(bool wait) TF_LOCKS_EXCLUDED(mu_) {
......
......@@ -351,10 +351,12 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
Status CheckExternalState() override { return Status::OK(); }
void MapFunc(IteratorContext* ctx, const string& prefix,
void MapFunc(IteratorContext* ctx,
const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input, std::vector<Tensor>* output,
StatusCallback callback) override {
(*ctx->runner())([this, ctx, prefix, input, output, callback]() {
(*ctx->runner())([this, ctx, node, input, output,
callback = std::move(callback)]() {
thread::ThreadPool* device_threadpool =
ctx->flr()->device()->tensorflow_cpu_worker_threads()->workers;
std::vector<tstring> slice_vec;
......@@ -423,7 +425,7 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
stats_aggregator->IncrementCounter(
stats_utils::kFeatureValuesCount, "trainer",
feature_stats.feature_values_count);
int64 steps = ctx->model()->NumElements(prefix);
int64 steps = node ? node->num_elements() : 0;
stats_aggregator->AddToHistogram(
stats_utils::FeatureHistogramName(dataset_->node_name()),
{static_cast<double>(feature_stats.features_count)}, steps);
......
......@@ -714,6 +714,12 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
// Thread responsible for launching all worker threads. The thread stays
// around after startup in case autotuning increases num_parallel_calls.
void WorkerManagerThread() TF_LOCKS_EXCLUDED(mu_) {
RecordStart(ctx_.get());
auto cleanup = gtl::MakeCleanup([this]() {
RecordStop(ctx_.get());
mutex_lock l(*mu_);
DecrementOutstandingThreads();
});
int initial_current_workers;
// When elements are moved from `future_elements_` to `current_elements_`,
// the future worker which created the element may continue to process
......@@ -748,7 +754,6 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
RecordStart(ctx_.get());
}
if (cancelled_ || end_of_input_) {
DecrementOutstandingThreads();
return;
}
IncrementOutstandingThreads();
......@@ -1323,16 +1328,19 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
for (int idx = 0; idx < size; ++idx) {
threadpool->Schedule(
[this, ctx, reader, idx, name, &s, &counter, elements] {
RecordStart(ctx);
auto cleanup = gtl::MakeCleanup([this, ctx, &counter]() {
RecordStop(ctx);
counter.DecrementCount();
});
std::shared_ptr<Element> elem;
Status ret_status = ReadElement(ctx, reader, idx, name, &elem);
mutex_lock l(*mu_);
if (!ret_status.ok()) {
s.Update(ret_status);
counter.DecrementCount();
return;
}
(*elements)[idx] = elem;
counter.DecrementCount();
});
}
counter.Wait();
......
......@@ -194,22 +194,22 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
return dataset_->captured_func_->CheckExternalState();
}
void MapFunc(IteratorContext* ctx, const string& prefix,
void MapFunc(IteratorContext* ctx, const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input_element, std::vector<Tensor>* result,
StatusCallback done) override {
auto map_func = [this](IteratorContext* ctx, const string& prefix,
auto map_func = [this](IteratorContext* ctx,
const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input_element,
std::vector<Tensor>* result, StatusCallback done) {
instantiated_captured_func_->RunAsync(ctx, std::move(input_element),
result, std::move(done), prefix);
result, std::move(done), node);
};
if (!dataset_->captured_func_->use_inter_op_parallelism()) {
(*ctx->runner())(std::bind(map_func, ctx, prefix,
(*ctx->runner())(std::bind(map_func, ctx, node,
std::move(input_element), result,
std::move(done)));
} else {
map_func(ctx, prefix, std::move(input_element), result,
std::move(done));
map_func(ctx, node, std::move(input_element), result, std::move(done));
}
}
......@@ -540,7 +540,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
// Apply the map function on `input_element`, storing the result in
// `result->return_values`, and invoking `done` when finished.
parallel_map_functor_->MapFunc(ctx.get(), prefix(),
parallel_map_functor_->MapFunc(ctx.get(), model_node(),
std::move(input_element),
&result->return_values, std::move(done));
}
......
......@@ -77,7 +77,8 @@ class ParallelMapFunctor {
// 2. A `std::vector<Tensor>` containing the input element.
// 3. A `std::vector<Tensor>*` to which the function will write the result.
// 4. A `StatusCallback` that should be invoked when the function is complete.
virtual void MapFunc(IteratorContext* ctx, const string& prefix,
virtual void MapFunc(IteratorContext* ctx,
const std::shared_ptr<model::Node>& node,
std::vector<Tensor> input, std::vector<Tensor>* output,
StatusCallback callback) = 0;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册