Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
qq_38905368
tensorflow
提交
4a2abacb
T
tensorflow
项目概览
qq_38905368
/
tensorflow
与 Fork 源项目一致
从无法访问的项目Fork
通知
5
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
4a2abacb
编写于
12月 12, 2018
作者:
P
Peter Hawkins
提交者:
TensorFlower Gardener
12月 12, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XLA:Python] Add CustomCall support to Python LocalComputationBuilder.
PiperOrigin-RevId: 225205868
上级
46afcd06
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
50 addition
and
1 deletion
+50
-1
tensorflow/compiler/xla/python/local_computation_builder.cc
tensorflow/compiler/xla/python/local_computation_builder.cc
+15
-0
tensorflow/compiler/xla/python/local_computation_builder.h
tensorflow/compiler/xla/python/local_computation_builder.h
+6
-0
tensorflow/compiler/xla/python/local_computation_builder.i
tensorflow/compiler/xla/python/local_computation_builder.i
+1
-0
tensorflow/compiler/xla/python/xla_client.py
tensorflow/compiler/xla/python/xla_client.py
+25
-0
tensorflow/compiler/xla/service/hlo_verifier.cc
tensorflow/compiler/xla/service/hlo_verifier.cc
+3
-1
未找到文件。
tensorflow/compiler/xla/python/local_computation_builder.cc
浏览文件 @
4a2abacb
...
...
@@ -783,6 +783,21 @@ LocalOp LocalComputationBuilder::Call(const LocalComputation& local_computation,
return
xla
::
Call
(
&
builder_
,
local_computation
.
computation
(),
xla_ops
);
}
LocalOp
LocalComputationBuilder
::
CustomCall
(
const
string
&
call_target_name
,
absl
::
Span
<
const
LocalOp
>
operands
,
const
Shape
&
shape_with_layout
,
const
std
::
vector
<
Shape
>&
operand_shapes_with_layout
,
const
string
&
opaque
)
{
std
::
vector
<
XlaOp
>
xla_ops
;
xla_ops
.
reserve
(
operands
.
size
());
for
(
const
auto
&
op
:
operands
)
{
xla_ops
.
push_back
(
op
.
op
());
}
return
xla
::
CustomCallWithLayout
(
&
builder_
,
call_target_name
,
xla_ops
,
shape_with_layout
,
operand_shapes_with_layout
,
opaque
);
}
LocalOp
LocalComputationBuilder
::
Transpose
(
const
LocalOp
&
operand
,
absl
::
Span
<
const
int64
>
permutation
)
{
return
xla
::
Transpose
(
operand
.
op
(),
permutation
);
...
...
tensorflow/compiler/xla/python/local_computation_builder.h
浏览文件 @
4a2abacb
...
...
@@ -352,6 +352,12 @@ class LocalComputationBuilder {
LocalOp
Call
(
const
LocalComputation
&
local_computation
,
absl
::
Span
<
const
LocalOp
>
operands
);
LocalOp
CustomCall
(
const
string
&
call_target_name
,
absl
::
Span
<
const
LocalOp
>
operands
,
const
Shape
&
shape_with_layout
,
const
std
::
vector
<
Shape
>&
operand_shapes_with_layout
,
const
string
&
opaque
);
LocalOp
Transpose
(
const
LocalOp
&
operand
,
absl
::
Span
<
const
int64
>
permutation
);
...
...
tensorflow/compiler/xla/python/local_computation_builder.i
浏览文件 @
4a2abacb
...
...
@@ -1147,6 +1147,7 @@ tensorflow::ImportNumpy();
%
unignore
xla
::
swig
::
LocalComputationBuilder
::
Cholesky
;
%
unignore
xla
::
swig
::
LocalComputationBuilder
::
QR
;
%
unignore
xla
::
swig
::
LocalComputationBuilder
::
TriangularSolve
;
%
unignore
xla
::
swig
::
LocalComputationBuilder
::
CustomCall
;
%
unignore
xla
::
swig
::
DeleteLocalComputation
;
%
unignore
xla
::
swig
::
DestructureLocalShapedBufferTuple
;
%
unignore
xla
::
swig
::
DestructureXrtAllocationTuple
;
...
...
tensorflow/compiler/xla/python/xla_client.py
浏览文件 @
4a2abacb
...
...
@@ -1102,6 +1102,31 @@ class ComputationBuilder(object):
"""
return
self
.
_client
.
Call
(
computation_to_apply
.
computation
,
operands
)
def
CustomCall
(
self
,
call_target_name
,
operands
,
shape_with_layout
,
operand_shapes_with_layout
,
opaque
=
None
):
"""Enqueues a custom call operation onto the computation.
Args:
call_target_name: the name of the function to call.
operands: an iterable of LocalOp. The number and types of operands must
match the arity of `operand_shapes_with_layout`.
shape_with_layout: the shape of the operator's output, with layout.
operand_shapes_with_layout: the shapes of `operands`, including the
expected layouts.
opaque: an opaque string passed to the backend.
Returns:
A LocalOp representing the added custom call op.
"""
opaque
=
opaque
or
''
return
self
.
_client
.
CustomCall
(
call_target_name
,
operands
,
shape_with_layout
,
operand_shapes_with_layout
,
opaque
)
def
Map
(
self
,
operands
,
computation_to_apply
,
dimensions
):
"""Enqueues a map operation onto the computation.
...
...
tensorflow/compiler/xla/service/hlo_verifier.cc
浏览文件 @
4a2abacb
...
...
@@ -481,7 +481,9 @@ Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
const
Shape
&
operand_shape_with_layout
=
custom_call
->
operand_shapes_with_layout
()[
i
];
TF_RET_CHECK
(
ShapeUtil
::
Compatible
(
custom_call
->
operand
(
i
)
->
shape
(),
operand_shape_with_layout
));
operand_shape_with_layout
))
<<
custom_call
->
operand
(
i
)
->
shape
().
ToString
()
<<
" operand "
<<
operand_shape_with_layout
.
ToString
();
TF_RET_CHECK
(
LayoutUtil
::
HasLayout
(
operand_shape_with_layout
));
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录