提交 cdfcea08 编写于 作者: K Kuangyuan Chen 提交者: TensorFlower Gardener

Refactor the step id generation so that it can be reused by other library.

PiperOrigin-RevId: 564441419
上级 907fb5bd
......@@ -102,6 +102,7 @@ cc_library(
"//tensorflow/core/tfrt/mlrt/interpreter:execute",
"//tensorflow/core/tfrt/mlrt/kernel:context",
"//tensorflow/core/tfrt/runtime",
"//tensorflow/core/tfrt/runtime:step_id",
"//tensorflow/core/tfrt/runtime:stream",
"//tensorflow/core/tfrt/runtime:work_queue_interface",
"//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub",
......
......@@ -75,6 +75,7 @@ limitations under the License.
#include "tensorflow/core/tfrt/mlrt/interpreter/execute.h"
#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/step_id.h"
#include "tensorflow/core/tfrt/runtime/stream.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h"
......@@ -109,6 +110,11 @@ constexpr char kArgumentTypeJoiningDelimiter[] = "^";
constexpr char kFallbackInitFunction[] = "_tfrt_fallback_init";
constexpr char kResourceInitFunction[] = "_tfrt_resource_init";
StepId GetNextStepId() {
static StepIdGenerator gen;
return gen.GetNextStepId();
}
} // namespace
tensorflow::Status RunMlrtFunction(
......@@ -206,11 +212,11 @@ StatusOr<std::unique_ptr<RequestInfo>> CreateRequestInfo(
// If the user provides a work_queue, we use it for inter-op tasks.
request_id = work_queue->id();
// If the user does not provide a valid id, we need to generate one.
if (request_id == 0) request_id = tfrt::GetUniqueInt();
if (request_id == 0) request_id = GetNextStepId().id;
request_info->request_queue = work_queue;
} else {
request_id = GetNextStepId().id;
// Otherwise we use the global queue in `runtime`.
request_id = tfrt::GetUniqueInt();
TF_ASSIGN_OR_RETURN(request_info->request_queue_owner,
runtime.CreateRequestQueue(request_id));
request_info->request_queue = request_info->request_queue_owner.get();
......
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
"//tensorflow:tensorflow.default.bzl",
"get_compatible_with_portable",
)
package(
......@@ -38,6 +39,7 @@ cc_library(
cc_library(
name = "stream_ops_util_constants",
hdrs = ["stream_ops_util_constants.h"],
compatible_with = get_compatible_with_portable(),
visibility = [
"//visibility:public",
],
......
......@@ -110,6 +110,7 @@ cc_library(
hdrs = ["stream.h"],
deps = [
":channel",
":step_id",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/core/framework:tensor",
"//tensorflow/core/framework:tensor_proto_cc",
......@@ -141,6 +142,16 @@ cc_library(
],
)
cc_library(
name = "step_id",
srcs = ["step_id.cc"],
hdrs = ["step_id.h"],
deps = [
"//tensorflow/core/tfrt/kernels:stream_ops_util_constants",
"@com_google_absl//absl/strings:str_format",
],
)
tf_cc_shared_test(
name = "stream_test",
srcs = ["stream_test.cc"],
......
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/runtime/step_id.h"
#include <atomic>
#include <cstdint>
namespace tensorflow {
namespace tfrt_stub {
std::atomic<uint64_t>& GetGlobalInitialStepId() {
static std::atomic<uint64_t> global_step_id = 0;
return global_step_id;
}
TEST_ScopedInitialStepId::TEST_ScopedInitialStepId(uint64_t step_id) {
step_id_ = GetGlobalInitialStepId().exchange(step_id);
}
TEST_ScopedInitialStepId::~TEST_ScopedInitialStepId() {
GetGlobalInitialStepId().store(step_id_);
}
} // namespace tfrt_stub
} // namespace tensorflow
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_
#define TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_
#include <atomic>
#include <cstdint>
#include "absl/strings/str_format.h"
#include "tensorflow/core/tfrt/kernels/stream_ops_util_constants.h"
namespace tensorflow {
namespace tfrt_stub {
// A base template for common utilities for a type safe id.
template <typename Derived>
struct SafeId {
SafeId() : id(0) {}
explicit constexpr SafeId(int64_t id) : id(id) {}
using Base = SafeId;
int64_t id;
friend bool operator==(const Derived& x, const Derived& y) {
return x.id == y.id;
}
template <typename Sink>
friend void AbslStringify(Sink& sink, const Derived& x) {
absl::Format(&sink, "%d", x.id);
}
template <typename H>
friend H AbslHashValue(H h, const Derived& x) {
return H::combine(std::move(h), x.id);
}
};
// A type-safe step id.
struct StepId : SafeId<StepId> {
using Base::Base;
bool valid() const { return id != 0; }
static constexpr StepId GetInvalidStepId() { return StepId(0); }
};
// The initial value of the step id.
std::atomic<uint64_t>& GetGlobalInitialStepId();
// StepIdGenerator provides the utility to generate a monotonically increasing
// step id. And the number of bits can be configured at compile time. The step
// id is positive and the maximum value is 2^(kStepIdBitSize)-1.
class StepIdGenerator {
public:
StepIdGenerator() : next_id_(GetGlobalInitialStepId().load()) {}
StepIdGenerator(const StepIdGenerator&) = delete;
StepIdGenerator& operator=(const StepIdGenerator&) = delete;
// Generates a positive step id that is within the bit-range specified by
// `kStepIdBitSize`.
StepId GetNextStepId() {
uint64_t next_id = next_id_.fetch_add(1, std::memory_order_relaxed);
// Use kStepIdBitSize bits because we need to pack it with batch id if batch
// function is used.
static_assert(kStepIdBitSize <= 32);
next_id = (next_id & ((1ull << kStepIdBitSize) - 1));
if (next_id == 0) {
return GetNextStepId();
}
return StepId(static_cast<int64_t>(next_id));
}
private:
std::atomic<uint64_t> next_id_{0};
};
// Set up the initial step_id used by StepIdGenerator. This class is
// test-only.
class TEST_ScopedInitialStepId {
public:
explicit TEST_ScopedInitialStepId(uint64_t step_id);
~TEST_ScopedInitialStepId();
TEST_ScopedInitialStepId(const TEST_ScopedInitialStepId&) = delete;
TEST_ScopedInitialStepId& operator=(const TEST_ScopedInitialStepId&) = delete;
private:
uint64_t step_id_ = 0;
};
} // namespace tfrt_stub
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_
......@@ -37,35 +37,12 @@ under the License.
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tfrt/runtime/channel.h"
#include "tensorflow/core/tfrt/runtime/step_id.h"
#include "tsl/platform/env.h"
namespace tensorflow {
namespace tfrt_stub {
template <typename Derived>
struct SafeId {
SafeId() : id(0) {}
explicit constexpr SafeId(int64_t id) : id(id) {}
using Base = SafeId;
int64_t id;
friend bool operator==(const Derived& x, const Derived& y) {
return x.id == y.id;
}
template <typename Sink>
friend void AbslStringify(Sink& sink, const Derived& x) {
absl::Format(&sink, "%d", x.id);
}
template <typename H>
friend H AbslHashValue(H h, const Derived& x) {
return H::combine(std::move(h), x.id);
}
};
struct StreamedResult {
absl::flat_hash_map<std::string, tensorflow::Tensor> tensors;
absl::Time enqueued_time;
......@@ -75,13 +52,6 @@ struct StreamCallbackId : SafeId<StreamCallbackId> {
using Base::Base;
};
struct StepId : SafeId<StepId> {
using Base::Base;
bool valid() const { return id != 0; }
static constexpr StepId GetInvalidStepId() { return StepId(0); }
};
// An interface that abstracts communication between the
// `StreamCallbackRegistry` and the stream controller backend.
class StreamControllerInterface {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册