Commit b168b7e9 authored by Armando Ugalde Velasco, committed by TensorFlower Gardener

Use MultipleIterationsAutoScaler in DataServiceDispatcherImpl

Use MultipleIterationsAutoScaler inside the data service dispatcher implementation as follows (see the sketch after this list):
- UpdateOptimalNumberOfWorkersMetric() in the maintenance thread.
- ReportProcessingTime() when receiving processing times from a WorkerHeartbeat.
- ReportTargetProcessingTime() when receiving a target processing time from a ClientHeartbeat.
- RemoveWorker() when detecting missing workers or executing MaybeRemoveTask.
- RemoveConsumer() when releasing missing clients.
- RegisterIteration() when creating a new Iteration.
- UnregisterIteration() when garbage-collecting old Iterations.
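
A minimal sketch of how these call sites fit together for a single iteration, assuming the MultipleIterationsAutoScaler interface visible in the diff below (Status-returning methods keyed by iteration ID). The helper name, the example durations, and the exact include paths for the status macros are illustrative and not part of this change; the dispatcher itself logs warnings on failure rather than returning early.

#include <cstdint>
#include <string>

#include "absl/time/time.h"
#include "tensorflow/core/data/service/auto_scaler.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace data {

// Hypothetical helper mirroring the dispatcher call sites listed above.
Status DriveAutoScalerForOneIteration(
    MultipleIterationsAutoScaler& auto_scaler, int64_t iteration_id,
    const std::string& worker_address, int64_t iteration_client_id) {
  // CreateIteration(): register the iteration before reporting metrics.
  TF_RETURN_IF_ERROR(auto_scaler.RegisterIteration(iteration_id));
  // WorkerHeartbeat(): report a task's processing time for this worker.
  TF_RETURN_IF_ERROR(auto_scaler.ReportProcessingTime(
      iteration_id, worker_address, absl::Nanoseconds(5000000)));
  // ClientHeartbeat(): report the consumer's target processing time.
  TF_RETURN_IF_ERROR(auto_scaler.ReportTargetProcessingTime(
      iteration_id, iteration_client_id, absl::Nanoseconds(2000000)));
  // MaintenanceThread(): publish the optimal-number-of-workers estimate.
  TF_RETURN_IF_ERROR(auto_scaler.UpdateOptimalNumberOfWorkersMetric());
  // Missing worker or MaybeRemoveTask, released client, and
  // garbage-collected iteration, respectively.
  TF_RETURN_IF_ERROR(auto_scaler.RemoveWorker(iteration_id, worker_address));
  TF_RETURN_IF_ERROR(
      auto_scaler.RemoveConsumer(iteration_id, iteration_client_id));
  TF_RETURN_IF_ERROR(auto_scaler.UnregisterIteration(iteration_id));
  return OkStatus();
}

}  // namespace data
}  // namespace tensorflow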

PiperOrigin-RevId: 549469531
Parent b3510df4
......@@ -418,6 +418,7 @@ cc_library(
"//tensorflow/core/data:hash_utils",
"//tensorflow/core/data:snapshot_utils",
"//tensorflow/core/data:standalone",
"//tensorflow/core/data/service:auto_scaler",
"//tensorflow/core/data/service/snapshot:file_utils",
"//tensorflow/core/data/service/snapshot:path_utils",
"//tensorflow/core/data/service/snapshot:snapshot_manager",
......
......@@ -35,6 +35,7 @@ limitations under the License.
#include "absl/time/time.h"
#include "tensorflow/core/data/dataset_utils.h"
#include "tensorflow/core/data/hash_utils.h"
#include "tensorflow/core/data/service/auto_scaler.h"
#include "tensorflow/core/data/service/common.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/credentials_factory.h"
......@@ -360,7 +361,24 @@ void DataServiceDispatcherImpl::ReportProcessingTimesFromActiveTasks(
<< " in worker with address " << worker_address
<< ". Time in nanoseconds: " << processing_time_nsec;
// TODO(armandouv): Report processing times to AutoScaler.
std::shared_ptr<const Task> task;
Status s = state_.TaskFromId(task_id, task);
if (!s.ok()) {
LOG(WARNING) << "Could not find task with id " << task_id
<< " in tf.data service dispatcher state: " << s;
continue;
}
Status auto_scaler_status = auto_scaler_.ReportProcessingTime(
task->iteration->iteration_id, worker_address,
absl::Nanoseconds(processing_time_nsec));
if (!auto_scaler_status.ok()) {
LOG_EVERY_N(WARNING, 20)
<< "Failed to report processing time for Iteration "
<< task->iteration->iteration_id << " and worker address "
<< worker_address
<< " to tf.data service AutoScaler: " << auto_scaler_status;
}
}
}
......@@ -701,6 +719,14 @@ Status DataServiceDispatcherImpl::MaybeRemoveTask(
remove_task->set_task_id(request->task_id());
TF_RETURN_IF_ERROR(Apply(update));
}
Status auto_scaler_status = auto_scaler_.RemoveWorker(
task->iteration->iteration_id, task->worker_address);
if (!auto_scaler_status.ok()) {
LOG(WARNING) << "Failed to remove worker with address "
<< task->worker_address << " for Iteration "
<< task->iteration->iteration_id
<< " from tf.data service AutoScaler: " << auto_scaler_status;
}
VLOG(1) << "Task " << task->task_id << " successfully removed";
return OkStatus();
}
......@@ -714,6 +740,13 @@ Status DataServiceDispatcherImpl::ReleaseIterationClient(
std::shared_ptr<const Iteration> iteration;
TF_RETURN_IF_ERROR(
state_.IterationForIterationClientId(iteration_client_id, iteration));
Status auto_scaler_status =
auto_scaler_.RemoveConsumer(iteration->iteration_id, iteration_client_id);
if (!auto_scaler_status.ok()) {
LOG(WARNING) << "Failed to remove consumer with ID " << iteration_client_id
<< " for Iteration " << iteration->iteration_id
<< " from tf.data service AutoScaler: " << auto_scaler_status;
}
Update update;
ReleaseIterationClientUpdate* release_iteration_client =
update.mutable_release_iteration_client();
......@@ -802,6 +835,12 @@ Status DataServiceDispatcherImpl::CreateIteration(
create_iteration->set_num_split_providers(num_split_providers);
TF_RETURN_IF_ERROR(Apply(update));
TF_RETURN_IF_ERROR(state_.IterationFromId(iteration_id, iteration));
Status auto_scaler_status = auto_scaler_.RegisterIteration(iteration_id);
if (!auto_scaler_status.ok()) {
LOG(WARNING) << "Failed to register Iteration " << iteration_id
<< " with tf.data service AutoScaler: " << auto_scaler_status;
}
return OkStatus();
}
......@@ -1035,6 +1074,21 @@ Status DataServiceDispatcherImpl::ClientHeartbeat(
response->set_block_round(iteration->pending_tasks.front().target_round);
}
VLOG(3) << "Received target processing time for iteration "
<< iteration->iteration_id << " from iteration_client_id "
<< request->iteration_client_id() << ". Time in nanoseconds: "
<< request->target_processing_time_nsec();
Status auto_scaler_status = auto_scaler_.ReportTargetProcessingTime(
iteration->iteration_id, request->iteration_client_id(),
absl::Nanoseconds(request->target_processing_time_nsec()));
if (!auto_scaler_status.ok()) {
LOG_EVERY_N(WARNING, 20)
<< "Failed to report target processing time for Iteration "
<< iteration->iteration_id << " and consumer ID "
<< request->iteration_client_id()
<< " to tf.data service AutoScaler: " << auto_scaler_status;
}
std::vector<std::shared_ptr<const Task>> tasks;
TF_RETURN_IF_ERROR(state_.TasksForIteration(iteration->iteration_id, tasks));
for (const auto& task : tasks) {
......@@ -1216,6 +1270,14 @@ void DataServiceDispatcherImpl::MaintenanceThread() {
LOG(WARNING) << "Error releasing missing clients: " << s;
}
}
{
Status s = auto_scaler_.UpdateOptimalNumberOfWorkersMetric();
if (!s.ok()) {
LOG(WARNING) << "Error updating the optimal number of workers metric "
"in tf.data service AutoScaler: "
<< s;
}
}
{
Status s = GcOldIterations();
if (!s.ok()) {
......@@ -1228,6 +1290,25 @@ void DataServiceDispatcherImpl::MaintenanceThread() {
}
}
void DataServiceDispatcherImpl::RemoveClientFromAutoScaler(int64_t client_id)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::shared_ptr<const Iteration> iteration;
Status s = state_.IterationForIterationClientId(client_id, iteration);
if (s.ok()) {
Status auto_scaler_status =
auto_scaler_.RemoveConsumer(iteration->iteration_id, client_id);
if (!auto_scaler_status.ok()) {
LOG(WARNING) << "Failed to remove consumer with ID " << client_id
<< " for Iteration " << iteration->iteration_id
<< " from tf.data service AutoScaler: "
<< auto_scaler_status;
}
} else {
LOG(WARNING) << "Could not find Iteration for client with id " << client_id
<< " in tf.data service dispatcher state: " << s;
}
}
Status DataServiceDispatcherImpl::ReleaseMissingClients()
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
int64_t now = env_->NowMicros();
......@@ -1236,6 +1317,8 @@ Status DataServiceDispatcherImpl::ReleaseMissingClients()
latest_client_heartbeats_time_[client_id] +
absl::Milliseconds(config_.client_timeout_ms())) {
LOG(INFO) << "Releasing timed-out client with id " << client_id;
RemoveClientFromAutoScaler(client_id);
Update update;
ReleaseIterationClientUpdate* release_client =
update.mutable_release_iteration_client();
......@@ -1247,6 +1330,29 @@ Status DataServiceDispatcherImpl::ReleaseMissingClients()
return OkStatus();
}
void DataServiceDispatcherImpl::RemoveWorkerFromAutoScaler(
const std::string& worker_address) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::vector<std::shared_ptr<const Task>> tasks;
Status tasks_for_worker_status = state_.TasksForWorker(worker_address, tasks);
if (tasks_for_worker_status.ok()) {
for (const auto& task : tasks) {
Status auto_scaler_status = auto_scaler_.RemoveWorker(
task->iteration->iteration_id, worker_address);
if (!auto_scaler_status.ok()) {
LOG(WARNING) << "Failed to remove worker with address "
<< worker_address << " for Iteration "
<< task->iteration->iteration_id
<< " from tf.data service AutoScaler: "
<< auto_scaler_status;
}
}
} else {
LOG(WARNING) << "Could not find tasks for worker with address "
<< worker_address << " in tf.data service dispatcher state: "
<< tasks_for_worker_status;
}
}
// TODO(b/250921378): Once snapshots have leases, inform snapshot managers.
void DataServiceDispatcherImpl::DetectMissingWorkers()
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
......@@ -1256,6 +1362,8 @@ void DataServiceDispatcherImpl::DetectMissingWorkers()
if (absl::FromUnixMicros(now) >
it->second + absl::Milliseconds(config_.worker_timeout_ms())) {
LOG(INFO) << "Lost worker " << it->first << " due to timeout";
RemoveWorkerFromAutoScaler(it->first);
latest_worker_heartbeats_time_.erase(it++);
} else {
++it;
......@@ -1276,6 +1384,14 @@ Status DataServiceDispatcherImpl::GcOldIterations()
update.mutable_garbage_collect_iteration()->set_iteration_id(
iteration->iteration_id);
TF_RETURN_IF_ERROR(state_.Apply(update));
Status auto_scaler_status =
auto_scaler_.UnregisterIteration(iteration->iteration_id);
if (!auto_scaler_status.ok()) {
LOG(WARNING) << "Failed to unregister Iteration "
<< iteration->iteration_id
<< " with tf.data service AutoScaler: "
<< auto_scaler_status;
}
LOG(INFO) << "Garbage collected iteration " << iteration->DebugString();
}
return OkStatus();
......
......@@ -25,6 +25,7 @@ limitations under the License.
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/time/time.h"
#include "tensorflow/core/data/service/auto_scaler.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/dataset_store.h"
#include "tensorflow/core/data/service/dispatcher.pb.h"
......@@ -310,8 +311,15 @@ class DataServiceDispatcherImpl {
// used when recovering state when the dispatcher starts.
Status ApplyWithoutJournaling(const Update& update)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Removes the client with `client_id` from `auto_scaler_`
void RemoveClientFromAutoScaler(int64_t client_id)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Releases iteration clients that haven't heartbeated recently.
Status ReleaseMissingClients() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Removes the worker with `worker_address` from `auto_scaler_`, which is
// potentially associated with multiple iterations.
void RemoveWorkerFromAutoScaler(const std::string& worker_address)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Checks for workers that haven't heartbeated recently and alerts the
// snapshot managers.
void DetectMissingWorkers() TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
......@@ -375,6 +383,7 @@ class DataServiceDispatcherImpl {
// Condition variable for waking up the gc thread.
condition_variable maintenance_thread_cv_;
std::unique_ptr<Thread> maintenance_thread_;
MultipleIterationsAutoScaler auto_scaler_;
TF_DISALLOW_COPY_AND_ASSIGN(DataServiceDispatcherImpl);
};
......