Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
cdfcea08
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
cdfcea08
编写于
9月 11, 2023
作者:
K
Kuangyuan Chen
提交者:
TensorFlower Gardener
9月 11, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refactor the step id generation so that it can be reused by other library.
PiperOrigin-RevId: 564441419
上级
907fb5bd
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
172 addition
and
35 deletion
+172
-35
tensorflow/core/tfrt/graph_executor/BUILD
tensorflow/core/tfrt/graph_executor/BUILD
+1
-0
tensorflow/core/tfrt/graph_executor/graph_executor.cc
tensorflow/core/tfrt/graph_executor/graph_executor.cc
+8
-2
tensorflow/core/tfrt/kernels/BUILD
tensorflow/core/tfrt/kernels/BUILD
+4
-2
tensorflow/core/tfrt/runtime/BUILD
tensorflow/core/tfrt/runtime/BUILD
+11
-0
tensorflow/core/tfrt/runtime/step_id.cc
tensorflow/core/tfrt/runtime/step_id.cc
+37
-0
tensorflow/core/tfrt/runtime/step_id.h
tensorflow/core/tfrt/runtime/step_id.h
+110
-0
tensorflow/core/tfrt/runtime/stream.h
tensorflow/core/tfrt/runtime/stream.h
+1
-31
未找到文件。
tensorflow/core/tfrt/graph_executor/BUILD
浏览文件 @
cdfcea08
...
...
@@ -102,6 +102,7 @@ cc_library(
"//tensorflow/core/tfrt/mlrt/interpreter:execute"
,
"//tensorflow/core/tfrt/mlrt/kernel:context"
,
"//tensorflow/core/tfrt/runtime"
,
"//tensorflow/core/tfrt/runtime:step_id"
,
"//tensorflow/core/tfrt/runtime:stream"
,
"//tensorflow/core/tfrt/runtime:work_queue_interface"
,
"//tensorflow/core/tfrt/stubs:tfrt_native_lowering_stub"
,
...
...
tensorflow/core/tfrt/graph_executor/graph_executor.cc
浏览文件 @
cdfcea08
...
...
@@ -75,6 +75,7 @@ limitations under the License.
#include "tensorflow/core/tfrt/mlrt/interpreter/execute.h"
#include "tensorflow/core/tfrt/mlrt/kernel/context.h"
#include "tensorflow/core/tfrt/runtime/runtime.h"
#include "tensorflow/core/tfrt/runtime/step_id.h"
#include "tensorflow/core/tfrt/runtime/stream.h"
#include "tensorflow/core/tfrt/runtime/work_queue_interface.h"
#include "tensorflow/core/tfrt/stubs/tfrt_native_lowering_stub.h"
...
...
@@ -109,6 +110,11 @@ constexpr char kArgumentTypeJoiningDelimiter[] = "^";
constexpr
char
kFallbackInitFunction
[]
=
"_tfrt_fallback_init"
;
constexpr
char
kResourceInitFunction
[]
=
"_tfrt_resource_init"
;
StepId
GetNextStepId
()
{
static
StepIdGenerator
gen
;
return
gen
.
GetNextStepId
();
}
}
// namespace
tensorflow
::
Status
RunMlrtFunction
(
...
...
@@ -206,11 +212,11 @@ StatusOr<std::unique_ptr<RequestInfo>> CreateRequestInfo(
// If the user provides a work_queue, we use it for inter-op tasks.
request_id
=
work_queue
->
id
();
// If the user does not provide a valid id, we need to generate one.
if
(
request_id
==
0
)
request_id
=
tfrt
::
GetUniqueInt
()
;
if
(
request_id
==
0
)
request_id
=
GetNextStepId
().
id
;
request_info
->
request_queue
=
work_queue
;
}
else
{
request_id
=
GetNextStepId
().
id
;
// Otherwise we use the global queue in `runtime`.
request_id
=
tfrt
::
GetUniqueInt
();
TF_ASSIGN_OR_RETURN
(
request_info
->
request_queue_owner
,
runtime
.
CreateRequestQueue
(
request_id
));
request_info
->
request_queue
=
request_info
->
request_queue_owner
.
get
();
...
...
tensorflow/core/tfrt/kernels/BUILD
浏览文件 @
cdfcea08
load
(
"//tensorflow:tensorflow.bzl"
,
"tf_cc_test"
)
load
(
"//tensorflow:tensorflow.bzl"
,
"
tf_cc_test
"
,
"//tensorflow:tensorflow.
default.
bzl"
,
"
get_compatible_with_portable
"
,
)
package
(
...
...
@@ -38,6 +39,7 @@ cc_library(
cc_library
(
name
=
"stream_ops_util_constants"
,
hdrs
=
[
"stream_ops_util_constants.h"
],
compatible_with
=
get_compatible_with_portable
(),
visibility
=
[
"//visibility:public"
,
],
...
...
tensorflow/core/tfrt/runtime/BUILD
浏览文件 @
cdfcea08
...
...
@@ -110,6 +110,7 @@ cc_library(
hdrs
=
[
"stream.h"
],
deps
=
[
":channel"
,
":step_id"
,
"//tensorflow/compiler/mlir/tensorflow"
,
"//tensorflow/core/framework:tensor"
,
"//tensorflow/core/framework:tensor_proto_cc"
,
...
...
@@ -141,6 +142,16 @@ cc_library(
],
)
cc_library
(
name
=
"step_id"
,
srcs
=
[
"step_id.cc"
],
hdrs
=
[
"step_id.h"
],
deps
=
[
"//tensorflow/core/tfrt/kernels:stream_ops_util_constants"
,
"@com_google_absl//absl/strings:str_format"
,
],
)
tf_cc_shared_test
(
name
=
"stream_test"
,
srcs
=
[
"stream_test.cc"
],
...
...
tensorflow/core/tfrt/runtime/step_id.cc
0 → 100644
浏览文件 @
cdfcea08
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tfrt/runtime/step_id.h"
#include <atomic>
#include <cstdint>
namespace
tensorflow
{
namespace
tfrt_stub
{
std
::
atomic
<
uint64_t
>&
GetGlobalInitialStepId
()
{
static
std
::
atomic
<
uint64_t
>
global_step_id
=
0
;
return
global_step_id
;
}
TEST_ScopedInitialStepId
::
TEST_ScopedInitialStepId
(
uint64_t
step_id
)
{
step_id_
=
GetGlobalInitialStepId
().
exchange
(
step_id
);
}
TEST_ScopedInitialStepId
::~
TEST_ScopedInitialStepId
()
{
GetGlobalInitialStepId
().
store
(
step_id_
);
}
}
// namespace tfrt_stub
}
// namespace tensorflow
tensorflow/core/tfrt/runtime/step_id.h
0 → 100644
浏览文件 @
cdfcea08
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_
#define TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_
#include <atomic>
#include <cstdint>
#include "absl/strings/str_format.h"
#include "tensorflow/core/tfrt/kernels/stream_ops_util_constants.h"
namespace
tensorflow
{
namespace
tfrt_stub
{
// A base template for common utilities for a type safe id.
template
<
typename
Derived
>
struct
SafeId
{
SafeId
()
:
id
(
0
)
{}
explicit
constexpr
SafeId
(
int64_t
id
)
:
id
(
id
)
{}
using
Base
=
SafeId
;
int64_t
id
;
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
x
.
id
==
y
.
id
;
}
template
<
typename
Sink
>
friend
void
AbslStringify
(
Sink
&
sink
,
const
Derived
&
x
)
{
absl
::
Format
(
&
sink
,
"%d"
,
x
.
id
);
}
template
<
typename
H
>
friend
H
AbslHashValue
(
H
h
,
const
Derived
&
x
)
{
return
H
::
combine
(
std
::
move
(
h
),
x
.
id
);
}
};
// A type-safe step id.
struct
StepId
:
SafeId
<
StepId
>
{
using
Base
::
Base
;
bool
valid
()
const
{
return
id
!=
0
;
}
static
constexpr
StepId
GetInvalidStepId
()
{
return
StepId
(
0
);
}
};
// The initial value of the step id.
std
::
atomic
<
uint64_t
>&
GetGlobalInitialStepId
();
// StepIdGenerator provides the utility to generate a monotonically increasing
// step id. And the number of bits can be configured at compile time. The step
// id is positive and the maximum value is 2^(kStepIdBitSize)-1.
class
StepIdGenerator
{
public:
StepIdGenerator
()
:
next_id_
(
GetGlobalInitialStepId
().
load
())
{}
StepIdGenerator
(
const
StepIdGenerator
&
)
=
delete
;
StepIdGenerator
&
operator
=
(
const
StepIdGenerator
&
)
=
delete
;
// Generates a positive step id that is within the bit-range specified by
// `kStepIdBitSize`.
StepId
GetNextStepId
()
{
uint64_t
next_id
=
next_id_
.
fetch_add
(
1
,
std
::
memory_order_relaxed
);
// Use kStepIdBitSize bits because we need to pack it with batch id if batch
// function is used.
static_assert
(
kStepIdBitSize
<=
32
);
next_id
=
(
next_id
&
((
1ull
<<
kStepIdBitSize
)
-
1
));
if
(
next_id
==
0
)
{
return
GetNextStepId
();
}
return
StepId
(
static_cast
<
int64_t
>
(
next_id
));
}
private:
std
::
atomic
<
uint64_t
>
next_id_
{
0
};
};
// Set up the initial step_id used by StepIdGenerator. This class is
// test-only.
class
TEST_ScopedInitialStepId
{
public:
explicit
TEST_ScopedInitialStepId
(
uint64_t
step_id
);
~
TEST_ScopedInitialStepId
();
TEST_ScopedInitialStepId
(
const
TEST_ScopedInitialStepId
&
)
=
delete
;
TEST_ScopedInitialStepId
&
operator
=
(
const
TEST_ScopedInitialStepId
&
)
=
delete
;
private:
uint64_t
step_id_
=
0
;
};
}
// namespace tfrt_stub
}
// namespace tensorflow
#endif // TENSORFLOW_CORE_TFRT_RUNTIME_STEP_ID_H_
tensorflow/core/tfrt/runtime/stream.h
浏览文件 @
cdfcea08
...
...
@@ -37,35 +37,12 @@ under the License.
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/tfrt/runtime/channel.h"
#include "tensorflow/core/tfrt/runtime/step_id.h"
#include "tsl/platform/env.h"
namespace
tensorflow
{
namespace
tfrt_stub
{
template
<
typename
Derived
>
struct
SafeId
{
SafeId
()
:
id
(
0
)
{}
explicit
constexpr
SafeId
(
int64_t
id
)
:
id
(
id
)
{}
using
Base
=
SafeId
;
int64_t
id
;
friend
bool
operator
==
(
const
Derived
&
x
,
const
Derived
&
y
)
{
return
x
.
id
==
y
.
id
;
}
template
<
typename
Sink
>
friend
void
AbslStringify
(
Sink
&
sink
,
const
Derived
&
x
)
{
absl
::
Format
(
&
sink
,
"%d"
,
x
.
id
);
}
template
<
typename
H
>
friend
H
AbslHashValue
(
H
h
,
const
Derived
&
x
)
{
return
H
::
combine
(
std
::
move
(
h
),
x
.
id
);
}
};
struct
StreamedResult
{
absl
::
flat_hash_map
<
std
::
string
,
tensorflow
::
Tensor
>
tensors
;
absl
::
Time
enqueued_time
;
...
...
@@ -75,13 +52,6 @@ struct StreamCallbackId : SafeId<StreamCallbackId> {
using
Base
::
Base
;
};
struct
StepId
:
SafeId
<
StepId
>
{
using
Base
::
Base
;
bool
valid
()
const
{
return
id
!=
0
;
}
static
constexpr
StepId
GetInvalidStepId
()
{
return
StepId
(
0
);
}
};
// An interface that abstracts communication between the
// `StreamCallbackRegistry` and the stream controller backend.
class
StreamControllerInterface
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录