Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
xxadev
tensorflow
提交
3689c213
T
tensorflow
项目概览
xxadev
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
3
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
3689c213
编写于
1月 12, 2017
作者:
P
Peter Hawkins
提交者:
TensorFlower Gardener
1月 12, 2017
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XLA] Add support for multiple computations to CompileAheadOfTime.
Change: 144362931
上级
725da748
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
207 addition
and
153 deletion
+207
-153
tensorflow/compiler/aot/compile.cc
tensorflow/compiler/aot/compile.cc
+15
-12
tensorflow/compiler/xla/client/local_client.cc
tensorflow/compiler/xla/client/local_client.cc
+17
-6
tensorflow/compiler/xla/client/local_client.h
tensorflow/compiler/xla/client/local_client.h
+17
-10
tensorflow/compiler/xla/service/compiler.h
tensorflow/compiler/xla/service/compiler.h
+5
-4
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+74
-65
tensorflow/compiler/xla/service/cpu/cpu_compiler.h
tensorflow/compiler/xla/service/cpu/cpu_compiler.h
+5
-4
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
+5
-4
tensorflow/compiler/xla/service/gpu/gpu_compiler.h
tensorflow/compiler/xla/service/gpu/gpu_compiler.h
+5
-4
tensorflow/compiler/xla/service/local_service.cc
tensorflow/compiler/xla/service/local_service.cc
+40
-32
tensorflow/compiler/xla/service/local_service.h
tensorflow/compiler/xla/service/local_service.h
+15
-7
tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
...orflow/compiler/xla/tests/local_client_aot_test_helper.cc
+9
-5
未找到文件。
tensorflow/compiler/aot/compile.cc
浏览文件 @
3689c213
...
...
@@ -204,23 +204,23 @@ Status RewriteAndPruneGraph(Graph* graph, const Config& config,
string
feed_id
;
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
n
->
def
(),
kFeedIdAttr
,
&
feed_id
));
if
(
missing_feeds
.
erase
(
feed_id
)
==
0
)
{
return
errors
::
Aborted
(
kArgOp
,
" node found with unknown feed id: "
,
feed_id
);
return
errors
::
Aborted
(
kArgOp
,
" node found with unknown feed id: "
,
feed_id
);
}
}
else
if
(
n
->
type_string
()
==
kRetvalOp
)
{
string
fetch_id
;
TF_RETURN_IF_ERROR
(
GetNodeAttr
(
n
->
def
(),
kFetchIdAttr
,
&
fetch_id
));
if
(
missing_fetches
.
erase
(
fetch_id
)
==
0
)
{
return
errors
::
Aborted
(
kRetvalOp
,
" node found with unknown fetch id: "
,
fetch_id
);
return
errors
::
Aborted
(
kRetvalOp
,
" node found with unknown fetch id: "
,
fetch_id
);
}
}
}
if
(
!
missing_feeds
.
empty
()
||
!
missing_fetches
.
empty
())
{
return
errors
::
Aborted
(
"Post graph-pruning"
,
", missing feeds: "
,
str_util
::
Join
(
missing_feeds
,
", "
)
,
", missing fetches: "
,
str_util
::
Join
(
missing_fetches
,
", "
));
return
errors
::
Aborted
(
"Post graph-pruning"
,
", missing feeds: "
,
str_util
::
Join
(
missing_feeds
,
", "
)
,
", missing fetches: "
,
str_util
::
Join
(
missing_fetches
,
", "
));
}
return
Status
::
OK
();
}
...
...
@@ -351,16 +351,19 @@ Status CompileXla(xla::LocalClient* client, const xla::Computation& computation,
for
(
int
i
=
0
;
i
<
pshape
->
parameters_size
();
++
i
)
{
arg_layouts
.
push_back
(
pshape
->
mutable_parameters
(
i
));
}
xla
::
StatusOr
<
std
::
unique_ptr
<
xla
::
AotCompilationResult
>>
aot_or
=
client
->
CompileAheadOfTime
(
computation
,
arg_layouts
,
pshape
->
result
(),
aot_opts
);
xla
::
LocalClient
::
AheadOfTimeComputationInstance
instance
;
instance
.
computation
=
&
computation
;
instance
.
argument_layouts
=
std
::
move
(
arg_layouts
);
instance
.
result_layout
=
&
pshape
->
result
();
xla
::
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
xla
::
AotCompilationResult
>>>
aot_or
=
client
->
CompileAheadOfTime
({
instance
},
aot_opts
);
if
(
!
aot_or
.
ok
())
{
return
errors
::
Unknown
(
"XLA compilation failed: "
,
aot_or
.
status
().
error_message
());
}
compile_result
->
aot
=
xla
::
unique_ptr_static_cast
<
xla
::
cpu
::
CpuAotCompilationResult
>
(
aot_or
.
ConsumeValueOrDie
(
));
std
::
move
(
aot_or
.
ValueOrDie
().
back
()
));
compile_result
->
entry_point
=
aot_opts
.
entry_point_name
();
compile_result
->
pointer_size
=
xla
::
LocalClient
::
PointerSizeForTriple
(
aot_opts
.
triple
());
...
...
tensorflow/compiler/xla/client/local_client.cc
浏览文件 @
3689c213
...
...
@@ -314,12 +314,23 @@ tensorflow::Status LocalClient::ExecuteLocally(
options
,
result
);
}
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
LocalClient
::
CompileAheadOfTime
(
const
Computation
&
computation
,
const
tensorflow
::
gtl
::
ArraySlice
<
const
Shape
*>
argument_layouts
,
const
Shape
&
result_layout
,
const
AotCompilationOptions
&
options
)
{
return
local_service_
->
CompileAheadOfTime
(
computation
.
handle
(),
argument_layouts
,
result_layout
,
options
);
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
LocalClient
::
CompileAheadOfTime
(
const
tensorflow
::
gtl
::
ArraySlice
<
AheadOfTimeComputationInstance
>
computations
,
const
AotCompilationOptions
&
options
)
{
std
::
vector
<
LocalService
::
AheadOfTimeComputationInstance
>
service_instances
;
service_instances
.
reserve
(
computations
.
size
());
for
(
const
AheadOfTimeComputationInstance
&
instance
:
computations
)
{
service_instances
.
push_back
({});
LocalService
::
AheadOfTimeComputationInstance
&
service_instance
=
service_instances
.
back
();
TF_RET_CHECK
(
instance
.
computation
!=
nullptr
);
service_instance
.
computation
=
instance
.
computation
->
handle
();
service_instance
.
argument_layouts
=
instance
.
argument_layouts
;
service_instance
.
result_layout
=
instance
.
result_layout
;
}
return
local_service_
->
CompileAheadOfTime
(
service_instances
,
options
);
}
int64
LocalClient
::
PointerSizeForTriple
(
tensorflow
::
StringPiece
target_triple
)
{
...
...
tensorflow/compiler/xla/client/local_client.h
浏览文件 @
3689c213
...
...
@@ -219,19 +219,26 @@ class LocalClient : public Client {
const
tensorflow
::
gtl
::
ArraySlice
<
const
Shape
*>
argument_layouts
,
const
ExecutableBuildOptions
&
options
);
// Compiles the computation for ahead-of-time execution. This is intended for
// use in static compilation. The |argument_layouts| parameter is used to
// inform the compiler of the expected layout for arguments while
// |result_layout| is used to signal the layout of the result. The |options|
// parameter is used to request which target the compiler should emit code
// for.
// A description of a computation to compile using CompileAheadOfTime.
struct
AheadOfTimeComputationInstance
{
const
Computation
*
computation
;
// Inform the compiler of the expected layout for arguments.
std
::
vector
<
const
Shape
*>
argument_layouts
;
// Specifies the expected result layout.
const
Shape
*
result_layout
;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. The |options| parameter describes
// the target for which the compiler should emit code.
//
// TODO(b/31222190): This doesn't really belong in LocalClient. Move it to its
// own library.
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
CompileAheadOfTime
(
const
Computation
&
computation
,
const
tensorflow
::
gtl
::
ArraySlice
<
const
Shape
*>
argument_layouts
,
const
Shape
&
result_layout
,
const
AotCompilationOptions
&
options
);
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
CompileAheadOfTime
(
const
tensorflow
::
gtl
::
ArraySlice
<
AheadOfTimeComputationInstance
>
computations
,
const
AotCompilationOptions
&
options
);
// Returns the size of a pointer in bytes for a given triple.
static
int64
PointerSizeForTriple
(
tensorflow
::
StringPiece
triple
);
...
...
tensorflow/compiler/xla/service/compiler.h
浏览文件 @
3689c213
...
...
@@ -128,10 +128,11 @@ class Compiler {
// Compiles the HLO module for ahead-of-time execution. This is intended for
// use in static compilation.
virtual
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
CompileAheadOfTime
(
std
::
unique_ptr
<
HloModule
>
module
,
std
::
unique_ptr
<
HloModuleConfig
>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
options
)
=
0
;
virtual
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
CompileAheadOfTime
(
std
::
vector
<
std
::
unique_ptr
<
HloModule
>>
module
,
std
::
vector
<
std
::
unique_ptr
<
HloModuleConfig
>>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
options
)
=
0
;
/////
// The Compiler class also serves as a point to register compiler objects
...
...
tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
浏览文件 @
3689c213
...
...
@@ -478,10 +478,13 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> CpuCompiler::Compile(
"Compilation of multiple HLO modules is not yet supported on CPU."
);
}
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
CpuCompiler
::
CompileAheadOfTime
(
std
::
unique_ptr
<
HloModule
>
hlo_module
,
std
::
unique_ptr
<
HloModuleConfig
>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
aot_options
)
{
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
CpuCompiler
::
CompileAheadOfTime
(
std
::
vector
<
std
::
unique_ptr
<
HloModule
>>
hlo_modules
,
std
::
vector
<
std
::
unique_ptr
<
HloModuleConfig
>>
module_configs
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
aot_options
)
{
TF_RET_CHECK
(
hlo_modules
.
size
()
==
module_configs
.
size
());
if
(
aot_options
.
PlatformId
()
!=
se
::
host
::
kHostPlatformId
)
{
return
InvalidArgument
(
"Incompatible AOT compilation platform"
);
}
...
...
@@ -549,72 +552,78 @@ StatusOr<std::unique_ptr<AotCompilationResult>> CpuCompiler::CompileAheadOfTime(
const
llvm
::
DataLayout
&
data_layout
=
llvm_module
.
getDataLayout
();
int64
pointer_size
=
data_layout
.
getPointerSize
();
TF_RETURN_IF_ERROR
(
RunHloPasses
(
hlo_module
.
get
(),
module_config
.
get
(),
dump_hlo
));
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>
results
;
for
(
int
i
=
0
;
i
<
hlo_modules
.
size
();
++
i
)
{
HloModule
*
hlo_module
=
hlo_modules
[
i
].
get
();
HloModuleConfig
*
module_config
=
module_configs
[
i
].
get
();
SequentialHloOrdering
::
HloModuleSequence
module_sequence
=
CreateModuleSequence
(
hlo_module
.
get
());
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN
(
std
::
unique_ptr
<
BufferAssignment
>
assignment
,
BufferAssigner
::
Run
(
hlo_module
.
get
(),
MakeUnique
<
SequentialHloOrdering
>
(
hlo_module
.
get
(),
module_sequence
),
pointer_size
));
IrEmitter
ir_emitter
(
*
hlo_module
,
*
module_config
,
*
assignment
,
&
llvm_module
,
/*hlo_to_profile_idx=*/
nullptr
);
HloComputation
*
computation
=
hlo_module
->
entry_computation
();
for
(
auto
embedded_computation
:
computation
->
MakeEmbeddedComputationsList
())
{
TF_RETURN_IF_ERROR
(
ir_emitter
.
EmitComputation
(
embedded_computation
,
embedded_computation
->
name
(),
/*is_entry_computation=*/
false
,
&
module_sequence
.
at
(
embedded_computation
))
.
status
());
}
const
string
&
entry_point_name
=
options
.
entry_point_name
();
TF_ASSIGN_OR_RETURN
(
llvm
::
Function
*
entry_function
,
ir_emitter
.
EmitComputation
(
computation
,
entry_point_name
,
/*is_entry_computation=*/
true
));
entry_function
->
setName
(
llvm_ir
::
AsStringRef
(
entry_point_name
));
Disassembler
disassembler
(
*
target_machine
);
CompilerFunctor
compiler_functor
(
target_machine
.
get
(),
&
disassembler
,
opt_level
,
CompilerFunctor
::
AllIntrinsics
());
llvm
::
object
::
OwningBinary
<
llvm
::
object
::
ObjectFile
>
object_file
=
compiler_functor
(
llvm_module
);
llvm
::
StringRef
object_file_data_ref
=
object_file
.
getBinary
()
->
getData
();
ObjectFileData
object_file_data
(
object_file_data_ref
.
begin
(),
object_file_data_ref
.
end
());
BufferSizes
buffer_sizes
;
for
(
const
BufferAllocation
&
allocation
:
assignment
->
Allocations
())
{
// Callers don't need to allocate temporary buffers for parameters.
if
(
allocation
.
is_entry_computation_parameter
())
{
buffer_sizes
.
push_back
(
-
1
);
continue
;
TF_RETURN_IF_ERROR
(
RunHloPasses
(
hlo_module
,
module_config
,
dump_hlo
));
SequentialHloOrdering
::
HloModuleSequence
module_sequence
=
CreateModuleSequence
(
hlo_module
);
// Run buffer analysis on the HLO graph. This analysis figures out which
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN
(
std
::
unique_ptr
<
BufferAssignment
>
assignment
,
BufferAssigner
::
Run
(
hlo_module
,
MakeUnique
<
SequentialHloOrdering
>
(
hlo_module
,
module_sequence
),
pointer_size
));
IrEmitter
ir_emitter
(
*
hlo_module
,
*
module_config
,
*
assignment
,
&
llvm_module
,
/*hlo_to_profile_idx=*/
nullptr
);
HloComputation
*
computation
=
hlo_module
->
entry_computation
();
for
(
auto
embedded_computation
:
computation
->
MakeEmbeddedComputationsList
())
{
TF_RETURN_IF_ERROR
(
ir_emitter
.
EmitComputation
(
embedded_computation
,
embedded_computation
->
name
(),
/*is_entry_computation=*/
false
,
&
module_sequence
.
at
(
embedded_computation
))
.
status
());
}
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if
(
allocation
.
is_thread_local
())
{
buffer_sizes
.
push_back
(
-
1
);
continue
;
const
string
&
entry_point_name
=
options
.
entry_point_name
();
TF_ASSIGN_OR_RETURN
(
llvm
::
Function
*
entry_function
,
ir_emitter
.
EmitComputation
(
computation
,
entry_point_name
,
/*is_entry_computation=*/
true
));
entry_function
->
setName
(
llvm_ir
::
AsStringRef
(
entry_point_name
));
Disassembler
disassembler
(
*
target_machine
);
CompilerFunctor
compiler_functor
(
target_machine
.
get
(),
&
disassembler
,
opt_level
,
CompilerFunctor
::
AllIntrinsics
());
llvm
::
object
::
OwningBinary
<
llvm
::
object
::
ObjectFile
>
object_file
=
compiler_functor
(
llvm_module
);
llvm
::
StringRef
object_file_data_ref
=
object_file
.
getBinary
()
->
getData
();
ObjectFileData
object_file_data
(
object_file_data_ref
.
begin
(),
object_file_data_ref
.
end
());
BufferSizes
buffer_sizes
;
for
(
const
BufferAllocation
&
allocation
:
assignment
->
Allocations
())
{
// Callers don't need to allocate temporary buffers for parameters.
if
(
allocation
.
is_entry_computation_parameter
())
{
buffer_sizes
.
push_back
(
-
1
);
continue
;
}
// Callers don't need to allocate anything for thread-local temporary
// buffers. They are lowered to allocas.
if
(
allocation
.
is_thread_local
())
{
buffer_sizes
.
push_back
(
-
1
);
continue
;
}
buffer_sizes
.
push_back
(
allocation
.
size
());
}
buffer_sizes
.
push_back
(
allocation
.
size
());
}
TF_ASSIGN_OR_RETURN
(
const
BufferAllocation
*
result_allocation
,
assignment
->
GetUniqueTopLevelOutputAllocation
());
TF_ASSIGN_OR_RETURN
(
const
BufferAllocation
*
result_allocation
,
assignment
->
GetUniqueTopLevelOutputAllocation
());
return
std
::
unique_ptr
<
AotCompilationResult
>
(
MakeUnique
<
CpuAotCompilationResult
>
(
std
::
move
(
object_file_data
),
std
::
move
(
buffer_sizes
),
result_allocation
->
index
()));
results
.
emplace_back
(
MakeUnique
<
CpuAotCompilationResult
>
(
std
::
move
(
object_file_data
),
std
::
move
(
buffer_sizes
),
result_allocation
->
index
()));
}
return
std
::
move
(
results
);
}
se
::
Platform
::
Id
CpuCompiler
::
PlatformId
()
const
{
...
...
tensorflow/compiler/xla/service/cpu/cpu_compiler.h
浏览文件 @
3689c213
...
...
@@ -123,10 +123,11 @@ class CpuCompiler : public Compiler {
HloDumper
dump_hlo
,
std
::
vector
<
perftools
::
gputools
::
StreamExecutor
*>
stream_exec
)
override
;
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
CompileAheadOfTime
(
std
::
unique_ptr
<
HloModule
>
module
,
std
::
unique_ptr
<
HloModuleConfig
>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
options
)
override
;
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
CompileAheadOfTime
(
std
::
vector
<
std
::
unique_ptr
<
HloModule
>>
module
,
std
::
vector
<
std
::
unique_ptr
<
HloModuleConfig
>>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
options
)
override
;
perftools
::
gputools
::
Platform
::
Id
PlatformId
()
const
override
;
...
...
tensorflow/compiler/xla/service/gpu/gpu_compiler.cc
浏览文件 @
3689c213
...
...
@@ -312,10 +312,11 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> GpuCompiler::Compile(
"Compilation of multiple HLO modules is not yet supported on GPU."
);
}
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
GpuCompiler
::
CompileAheadOfTime
(
std
::
unique_ptr
<
HloModule
>
module
,
std
::
unique_ptr
<
HloModuleConfig
>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
options
)
{
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
GpuCompiler
::
CompileAheadOfTime
(
std
::
vector
<
std
::
unique_ptr
<
HloModule
>>
module
,
std
::
vector
<
std
::
unique_ptr
<
HloModuleConfig
>>
module_config
,
HloDumper
dump_hlo
,
const
AotCompilationOptions
&
options
)
{
return
Unimplemented
(
"not yet implemented: GpuCompiler::CompileAheadOfTime"
);
}
...
...
tensorflow/compiler/xla/service/gpu/gpu_compiler.h
浏览文件 @
3689c213
...
...
@@ -52,10 +52,11 @@ class GpuCompiler : public Compiler {
HloDumper
dump_hlo
,
std
::
vector
<
perftools
::
gputools
::
StreamExecutor
*>
stream_exec
)
override
;
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
CompileAheadOfTime
(
std
::
unique_ptr
<
HloModule
>
module
,
std
::
unique_ptr
<
HloModuleConfig
>
module_config
,
HloDumper
dump_hlo
,
AotCompilationOptions
const
&
options
)
override
;
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
CompileAheadOfTime
(
std
::
vector
<
std
::
unique_ptr
<
HloModule
>>
module
,
std
::
vector
<
std
::
unique_ptr
<
HloModuleConfig
>>
module_config
,
HloDumper
dump_hlo
,
AotCompilationOptions
const
&
options
)
override
;
perftools
::
gputools
::
Platform
::
Id
PlatformId
()
const
override
;
...
...
tensorflow/compiler/xla/service/local_service.cc
浏览文件 @
3689c213
...
...
@@ -206,42 +206,49 @@ tensorflow::Status LocalService::ExecuteLocally(
return
tensorflow
::
Status
::
OK
();
}
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>
>>
LocalService
::
CompileAheadOfTime
(
const
ComputationHandle
&
computation
,
const
tensorflow
::
gtl
::
ArraySlice
<
const
Shape
*>
argument_layouts
,
const
Shape
&
result_layout
,
const
AotCompilationOptions
&
options
)
{
TF_ASSIGN_OR_RETURN
(
UserComputation
*
user_computation
,
computation_tracker_
.
Resolve
(
computation
));
VersionedComputationHandle
versioned_handle
=
user_computation
->
GetVersionedHandle
();
TF_ASSIGN_OR_RETURN
(
std
::
unique_ptr
<
HloModule
>
hlo_module
,
computation_tracker_
.
BuildHloModule
(
versioned_handle
,
/*include_unused_parameters=*/
true
));
TF_ASSIGN_OR_RETURN
(
std
::
shared_ptr
<
const
ProgramShape
>
program_shape
,
user_computation
->
ComputeProgramShape
(
versioned_handle
.
version
));
auto
module_config
=
MakeUnique
<
HloModuleConfig
>
(
*
program_shape
);
auto
*
computation_layout
=
module_config
->
mutable_entry_computation_layout
();
for
(
int
i
=
0
;
i
<
argument_layouts
.
size
();
++
i
)
{
const
Shape
&
argument_layout
=
*
argument_layouts
[
i
];
if
(
ShapeUtil
::
IsTuple
(
argument_layout
))
{
return
Unimplemented
(
"tuple arguments not supported yet"
);
const
tensorflow
::
gtl
::
ArraySlice
<
AheadOfTimeComputationInstance
>
computations
,
const
AotCompilationOptions
&
options
)
{
std
::
vector
<
std
::
unique_ptr
<
HloModule
>>
hlo_modules
;
std
::
vector
<
std
::
unique_ptr
<
HloModuleConfig
>>
module_configs
;
for
(
const
AheadOfTimeComputationInstance
&
instance
:
computations
)
{
TF_ASSIGN_OR_RETURN
(
UserComputation
*
user_computation
,
computation_tracker_
.
Resolve
(
instance
.
computation
));
VersionedComputationHandle
versioned_handle
=
user_computation
->
GetVersionedHandle
();
TF_ASSIGN_OR_RETURN
(
std
::
unique_ptr
<
HloModule
>
hlo_module
,
computation_tracker_
.
BuildHloModule
(
versioned_handle
,
/*include_unused_parameters=*/
true
));
hlo_modules
.
push_back
(
std
::
move
(
hlo_module
));
TF_ASSIGN_OR_RETURN
(
std
::
shared_ptr
<
const
ProgramShape
>
program_shape
,
user_computation
->
ComputeProgramShape
(
versioned_handle
.
version
));
module_configs
.
push_back
(
MakeUnique
<
HloModuleConfig
>
(
*
program_shape
));
HloModuleConfig
*
module_config
=
module_configs
.
back
().
get
();
auto
*
computation_layout
=
module_config
->
mutable_entry_computation_layout
();
for
(
int
i
=
0
;
i
<
instance
.
argument_layouts
.
size
();
++
i
)
{
const
Shape
&
argument_layout
=
*
instance
.
argument_layouts
[
i
];
if
(
ShapeUtil
::
IsTuple
(
argument_layout
))
{
return
Unimplemented
(
"tuple arguments not supported yet"
);
}
TF_RETURN_IF_ERROR
(
computation_layout
->
mutable_parameter_layout
(
i
)
->
CopyLayoutFromShape
(
argument_layout
));
}
TF_RETURN_IF_ERROR
(
computation_layout
->
mutable_
parameter_layout
(
i
)
->
CopyLayoutFromShape
(
argumen
t_layout
));
computation_layout
->
mutable_
result_layout
(
)
->
CopyLayoutFromShape
(
*
instance
.
resul
t_layout
));
}
TF_RETURN_IF_ERROR
(
computation_layout
->
mutable_result_layout
()
->
CopyLayoutFromShape
(
result_layout
));
return
execute_backend_
->
compiler
()
->
CompileAheadOfTime
(
std
::
move
(
hlo_module
),
std
::
move
(
module_config
),
->
CompileAheadOfTime
(
std
::
move
(
hlo_module
s
),
std
::
move
(
module_configs
),
MakeHloDumper
(),
options
)
.
ConsumeValueOrDie
();
}
...
...
@@ -426,8 +433,9 @@ StatusOr<std::unique_ptr<ShapedBuffer>> LocalService::ExecuteLocallyInternal(
}
else
{
se
::
StreamExecutor
*
stream_executor
;
if
(
options
.
device_ordinal
()
>=
0
)
{
TF_ASSIGN_OR_RETURN
(
stream_executor
,
execute_backend_
->
stream_executor
(
options
.
device_ordinal
()));
TF_ASSIGN_OR_RETURN
(
stream_executor
,
execute_backend_
->
stream_executor
(
options
.
device_ordinal
()));
}
else
{
stream_executor
=
execute_backend_
->
default_stream_executor
();
}
...
...
tensorflow/compiler/xla/service/local_service.h
浏览文件 @
3689c213
...
...
@@ -139,13 +139,21 @@ class LocalService : public Service {
tensorflow
::
gtl
::
ArraySlice
<
const
ShapedBuffer
*>
arguments
,
const
LocalExecuteOptions
&
options
,
ShapedBuffer
*
result_buffer
);
// Compiles the computation for ahead-of-time execution. This is intended for
// use in static compilation. See |LocalClient::CompileAheadOfTime| for
// additional details.
StatusOr
<
std
::
unique_ptr
<
AotCompilationResult
>>
CompileAheadOfTime
(
const
ComputationHandle
&
computation
,
const
tensorflow
::
gtl
::
ArraySlice
<
const
Shape
*>
argument_layouts
,
const
Shape
&
result_layout
,
const
AotCompilationOptions
&
Options
);
// A description of a computation to compile using CompileAheadOfTime.
struct
AheadOfTimeComputationInstance
{
ComputationHandle
computation
;
std
::
vector
<
const
Shape
*>
argument_layouts
;
const
Shape
*
result_layout
=
nullptr
;
};
// Compiles a list of computations for ahead-of-time execution. This is
// intended for use in static compilation. See
// |LocalClient::CompileAheadOfTime| for additional details.
StatusOr
<
std
::
vector
<
std
::
unique_ptr
<
AotCompilationResult
>>>
CompileAheadOfTime
(
const
tensorflow
::
gtl
::
ArraySlice
<
AheadOfTimeComputationInstance
>
computations
,
const
AotCompilationOptions
&
Options
);
// Builds an Executable with the given argument layouts and options. If
// result_layout is non-null, then the executable is compiled to produce a
...
...
tensorflow/compiler/xla/tests/local_client_aot_test_helper.cc
浏览文件 @
3689c213
...
...
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_compiler.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/platform/logging.h"
...
...
@@ -72,16 +73,19 @@ int main(int argc, char** argv) {
llvm
::
Triple
triple
(
xla
::
llvm_ir
::
AsStringRef
(
triple_string
));
xla
::
Computation
computation
=
builder
.
Build
().
ConsumeValueOrDie
();
xla
::
LocalClient
::
AheadOfTimeComputationInstance
instance
{
&
computation
,
/*argument_layouts=*/
{
&
opaque_shape
},
&
r0f32
};
xla
::
cpu
::
CpuAotCompilationOptions
options
(
triple_string
,
/*cpu_name=*/
""
,
/*features=*/
""
,
"SumAndDouble"
,
xla
::
cpu
::
CpuAotCompilationOptions
::
RelocationModel
::
Static
);
auto
results
=
client
->
CompileAheadOfTime
({
instance
},
options
).
ConsumeValueOrDie
();
auto
result
=
xla
::
unique_ptr_static_cast
<
xla
::
cpu
::
CpuAotCompilationResult
>
(
client
->
CompileAheadOfTime
(
builder
.
Build
().
ValueOrDie
(),
/*argument_layouts=*/
{
&
opaque_shape
},
r0f32
,
options
)
.
ConsumeValueOrDie
());
std
::
move
(
results
.
front
()));
// We should have two buffers, one for the result and one temporary buffer,
// and both should be float-sized. It's lame to hard-code this, but we need
// local_client_aot_test.cc to be able to easily invoke the function.
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录