Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wux_labs
Tensorflow
提交
c64f92f6
T
Tensorflow
项目概览
wux_labs
/
Tensorflow
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
T
Tensorflow
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
c64f92f6
编写于
6月 04, 2019
作者:
A
Anna R
提交者:
TensorFlower Gardener
6月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Automated rollback of commit
3125d75c
PiperOrigin-RevId: 251572054
上级
be57b1a6
变更
10
隐藏空白更改
内联
并排
Showing
10 changed files
with
41 additions
and
412 deletions
+41
-412
tensorflow/compiler/tests/BUILD
tensorflow/compiler/tests/BUILD
+0
-20
tensorflow/compiler/tests/unary_ops_composition_test.cc
tensorflow/compiler/tests/unary_ops_composition_test.cc
+0
-137
tensorflow/compiler/tf2xla/kernels/BUILD
tensorflow/compiler/tf2xla/kernels/BUILD
+0
-4
tensorflow/compiler/tf2xla/kernels/elu_op.cc
tensorflow/compiler/tf2xla/kernels/elu_op.cc
+16
-22
tensorflow/compiler/tf2xla/kernels/elu_op.h
tensorflow/compiler/tf2xla/kernels/elu_op.h
+0
-26
tensorflow/compiler/tf2xla/kernels/relu_op.cc
tensorflow/compiler/tf2xla/kernels/relu_op.cc
+9
-14
tensorflow/compiler/tf2xla/kernels/relu_op.h
tensorflow/compiler/tf2xla/kernels/relu_op.h
+0
-26
tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc
tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc
+0
-122
tensorflow/core/kernels/ops_testutil.cc
tensorflow/core/kernels/ops_testutil.cc
+1
-8
tensorflow/core/kernels/ops_testutil.h
tensorflow/core/kernels/ops_testutil.h
+15
-33
未找到文件。
tensorflow/compiler/tests/BUILD
浏览文件 @
c64f92f6
...
...
@@ -1197,26 +1197,6 @@ tf_cuda_cc_test(
deps
=
[
":randomized_tests_library"
],
)
tf_cuda_cc_test
(
name
=
"unary_ops_composition_test"
,
srcs
=
[
"unary_ops_composition_test.cc"
],
tags
=
tf_cuda_tests_tags
(),
deps
=
[
"//tensorflow/cc:cc_ops"
,
"//tensorflow/compiler/jit"
,
"//tensorflow/compiler/jit:common"
,
"//tensorflow/compiler/jit:xla_kernel_creator"
,
"//tensorflow/compiler/tf2xla:xla_compiler"
,
"//tensorflow/core:framework"
,
"//tensorflow/core:graph"
,
"//tensorflow/core:test"
,
"//tensorflow/core:test_main"
,
"//tensorflow/core:testlib"
,
"//tensorflow/core/kernels:ops_testutil"
,
"//tensorflow/core/kernels:ops_util"
,
],
)
py_library
(
name
=
"lstm"
,
testonly
=
1
,
...
...
tensorflow/compiler/tests/unary_ops_composition_test.cc
已删除
100644 → 0
浏览文件 @
be57b1a6
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
#include "tensorflow/core/util/port.h"
namespace
tensorflow
{
namespace
{
class
UnaryOpsCompositionTest
:
public
OpsTestBase
{
protected:
template
<
typename
T
>
void
RunComposedOp
(
const
std
::
vector
<
string
>
op_names
,
T
input_scalar_value
,
T
expected_scalar_value
)
{
string
xla_device_name
=
tensorflow
::
IsGoogleCudaEnabled
()
?
DEVICE_XLA_GPU
:
DEVICE_XLA_CPU
;
SetDevice
(
DeviceType
(
xla_device_name
),
std
::
unique_ptr
<
tensorflow
::
Device
>
(
DeviceFactory
::
NewDevice
(
xla_device_name
,
{},
"/job:a/replica:0/task:0"
)));
TF_ASSERT_OK
(
NodeDefBuilder
(
"unary_op_composition"
,
"_UnaryOpsComposition"
)
.
Input
(
FakeInput
(
DataTypeToEnum
<
T
>::
v
()))
.
Attr
(
"T"
,
DataTypeToEnum
<
T
>::
v
())
.
Attr
(
"op_names"
,
op_names
)
.
Finalize
(
node_def
()));
TF_ASSERT_OK
(
InitOp
());
// We're using an XLA device here which allocates XlaTensors. We can't
// inspect XlaTensors directly so we create the input on the host and copy
// it over to the XLA device. We do the inverse on the output.
TensorShape
shape
({});
AllocatorAttributes
host_alloc_attrs
;
host_alloc_attrs
.
set_gpu_compatible
(
true
);
host_alloc_attrs
.
set_on_host
(
true
);
Allocator
*
cpu_allocator
=
device_
->
GetAllocator
(
host_alloc_attrs
);
DataType
dtype
=
DataTypeToEnum
<
T
>::
value
;
Tensor
input_on_host
(
cpu_allocator
,
dtype
,
shape
);
test
::
FillValues
<
T
>
(
&
input_on_host
,
{
input_scalar_value
});
Tensor
*
input
=
AddInput
(
dtype
,
shape
);
DeviceContext
*
device_context
=
device_
->
tensorflow_gpu_device_info
()
->
default_context
;
TF_CHECK_OK
(
BlockingCopy
([
&
](
StatusCallback
cb
)
{
device_context
->
CopyCPUTensorToDevice
(
&
input_on_host
,
device_
,
input
,
cb
);
}));
TF_ASSERT_OK
(
RunOpKernel
());
Tensor
expected_tensor
(
cpu_allocator
,
dtype
,
shape
);
test
::
FillValues
<
T
>
(
&
expected_tensor
,
{
expected_scalar_value
});
Tensor
*
output
=
GetOutput
(
0
);
Tensor
output_on_host
(
cpu_allocator
,
output
->
dtype
(),
output
->
shape
());
TF_CHECK_OK
(
BlockingCopy
([
&
](
StatusCallback
cb
)
{
device_context
->
CopyDeviceTensorToCPU
(
output
,
"output 0"
,
device_
,
&
output_on_host
,
cb
);
}));
test
::
ExpectClose
(
expected_tensor
,
output_on_host
,
/*atol=*/
1e-5
,
/*rtol=*/
1e-5
);
}
private:
template
<
typename
CopyFnTy
>
Status
BlockingCopy
(
CopyFnTy
copy_fn
)
{
Notification
n
;
Status
status
;
copy_fn
([
&
](
Status
s
)
{
status
=
s
;
n
.
Notify
();
});
n
.
WaitForNotification
();
return
status
;
}
};
// Sqrt applied twice to 81 is expected to give 3 (float).
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_F) {
  const std::vector<string> ops = {"Sqrt", "Sqrt"};
  RunComposedOp<float>(ops, 81.0, 3.0);
}
// Sqrt applied twice to 81 is expected to give 3 (double).
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sqrt_D) {
  const std::vector<string> ops = {"Sqrt", "Sqrt"};
  RunComposedOp<double>(ops, 81.0, 3.0);
}
// Sqrt followed by Sin on 81: the expected value is sin(9) computed on the
// host in single precision.
TEST_F(UnaryOpsCompositionTest, Compose_Sqrt_Sin_F) {
  // `9.0f` must be one token; with the suffix detached it does not compile.
  RunComposedOp<float>({"Sqrt", "Sin"}, 81.0, std::sin(9.0f));
}
// Cos followed by Acos on 0.5: the expected value is acos(cos(0.5)) computed
// on the host in single precision.
TEST_F(UnaryOpsCompositionTest, Compose_Cos_Acos_F) {
  // `0.5f` must be one token; with the suffix detached it does not compile.
  RunComposedOp<float>({"Cos", "Acos"}, 0.5, std::acos(std::cos(0.5f)));
}
// Tanh followed by Relu on 0.5: the expected value is max(0, tanh(0.5))
// computed on the host in single precision.
TEST_F(UnaryOpsCompositionTest, Compose_Tanh_Relu_F) {
  // Both arguments of std::max need the `f` suffix attached so that they are
  // float literals of the same type.
  RunComposedOp<float>({"Tanh", "Relu"}, 0.5,
                       std::max(0.0f, std::tanh(0.5f)));
}
// Tanh followed by Relu on 0.5: the expected value is max(0, tanh(0.5))
// computed on the host in double precision.
TEST_F(UnaryOpsCompositionTest, Compose_Tanh_Relu_D) {
  const double expected = std::max(0.0, std::tanh(0.5));
  RunComposedOp<double>({"Tanh", "Relu"}, 0.5, expected);
}
// Relu6 on 11 is expected to give 6 (float).
// NOTE(review): the test name mentions Tanh, but only "Relu6" is composed
// here - confirm whether a Tanh stage was intended.
TEST_F(UnaryOpsCompositionTest, Compose_Tanh_Relu6_F) {
  // `11.0f` / `6.0f` must each be one token; with the suffix detached the
  // call does not compile.
  RunComposedOp<float>({"Relu6"}, 11.0f, 6.0f);
}
}
// namespace
}
// end namespace tensorflow
tensorflow/compiler/tf2xla/kernels/BUILD
浏览文件 @
c64f92f6
...
...
@@ -34,7 +34,6 @@ tf_kernel_library(
"dynamic_stitch_op.cc"
,
"einsum_op.cc"
,
"elu_op.cc"
,
"elu_op.h"
,
"empty_op.cc"
,
"extract_image_patches_op.cc"
,
"fake_param_op.cc"
,
...
...
@@ -72,7 +71,6 @@ tf_kernel_library(
"reduction_ops.h"
,
"reduction_ops_common.cc"
,
"relu_op.cc"
,
"relu_op.h"
,
"replica_id_op.cc"
,
"reshape_op.cc"
,
"retval_op.cc"
,
...
...
@@ -104,7 +102,6 @@ tf_kernel_library(
"training_ops.cc"
,
"transpose_op.cc"
,
"unary_ops.cc"
,
"unary_ops_composition.cc"
,
"unpack_op.cc"
,
"variable_ops.cc"
,
"xla_broadcast_helper_op.cc"
,
...
...
@@ -196,7 +193,6 @@ tf_kernel_library(
"//tensorflow/core/kernels:stateful_random_ops"
,
"//tensorflow/core/kernels:training_ops"
,
"@com_google_absl//absl/algorithm:container"
,
"@com_google_absl//absl/container:flat_hash_map"
,
"@com_google_absl//absl/strings"
,
"@com_google_absl//absl/types:optional"
,
"@com_google_absl//absl/types:span"
,
...
...
tensorflow/compiler/tf2xla/kernels/elu_op.cc
浏览文件 @
c64f92f6
...
...
@@ -15,33 +15,14 @@ limitations under the License.
// Native XLA implementations of XLA Elu Ops
#include "tensorflow/compiler/tf2xla/kernels/elu_op.h"
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
namespace
xla
{
// Builds an op that yields `x` where `x > 0` and `expm1(x)` elsewhere.
XlaOp Elu(XlaOp x) {
  const XlaOp zero_like_x = ScalarLike(x, 0);
  const XlaOp is_positive = Gt(x, zero_like_x);
  const XlaOp exp_minus_one = Expm1(x);
  return Select(is_positive, x, exp_minus_one);
}
// Builds an op that yields `scale * x` where `x > 0` and
// `scale_alpha * expm1(x)` elsewhere, using the two constants below.
XlaOp Selu(XlaOp x) {
  const XlaOp zero_like_x = ScalarLike(x, 0);
  const XlaOp scale_factor = ScalarLike(x, 1.0507009873554804934193349852946);
  const XlaOp scale_alpha_factor =
      ScalarLike(x, 1.7580993408473768599402175208123);
  const XlaOp is_positive = Gt(x, zero_like_x);
  const XlaOp exp_minus_one = Expm1(x);
  return Select(is_positive, Mul(scale_factor, x),
                Mul(scale_alpha_factor, exp_minus_one));
}
}
// namespace xla
namespace
tensorflow
{
namespace
{
...
...
@@ -50,7 +31,11 @@ class EluOp : public XlaOpKernel {
explicit
EluOp
(
OpKernelConstruction
*
ctx
)
:
XlaOpKernel
(
ctx
)
{}
// Computes Elu: x where x > 0, expm1(x) elsewhere.
void
Compile
(
XlaOpKernelContext
*
ctx
)
override
{
ctx
->
SetOutput
(
0
,
xla
::
Elu
(
ctx
->
Input
(
0
)));
xla
::
XlaBuilder
*
b
=
ctx
->
builder
();
const
auto
zero
=
XlaHelpers
::
Zero
(
b
,
input_type
(
0
));
const
auto
pred
=
xla
::
Gt
(
ctx
->
Input
(
0
),
zero
);
const
auto
expm1
=
xla
::
Expm1
(
ctx
->
Input
(
0
));
ctx
->
SetOutput
(
0
,
xla
::
Select
(
pred
,
ctx
->
Input
(
0
),
expm1
));
}
};
...
...
@@ -79,7 +64,16 @@ class SeluOp : public XlaOpKernel {
explicit
SeluOp
(
OpKernelConstruction
*
ctx
)
:
XlaOpKernel
(
ctx
)
{}
// Computes Selu: scale * x where x > 0, scale_alpha * expm1(x) elsewhere.
void
Compile
(
XlaOpKernelContext
*
ctx
)
override
{
ctx
->
SetOutput
(
0
,
xla
::
Selu
(
ctx
->
Input
(
0
)));
xla
::
XlaBuilder
*
b
=
ctx
->
builder
();
const
auto
zero
=
XlaHelpers
::
Zero
(
b
,
input_type
(
0
));
const
auto
scale
=
XlaHelpers
::
FloatLiteral
(
b
,
input_type
(
0
),
1.0507009873554804934193349852946
);
const
auto
scale_alpha
=
XlaHelpers
::
FloatLiteral
(
b
,
input_type
(
0
),
1.7580993408473768599402175208123
);
const
auto
pred
=
xla
::
Gt
(
ctx
->
Input
(
0
),
zero
);
const
auto
expm1
=
xla
::
Expm1
(
ctx
->
Input
(
0
));
ctx
->
SetOutput
(
0
,
xla
::
Select
(
pred
,
xla
::
Mul
(
scale
,
ctx
->
Input
(
0
)),
xla
::
Mul
(
scale_alpha
,
expm1
)));
}
};
...
...
tensorflow/compiler/tf2xla/kernels/elu_op.h
已删除
100644 → 0
浏览文件 @
be57b1a6
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace
xla
{
// Returns an op producing `x` where `x > 0` and `expm1(x)` elsewhere.
XlaOp
Elu
(
XlaOp
x
);
// Returns an op producing `scale * x` where `x > 0` and
// `scale_alpha * expm1(x)` elsewhere.
XlaOp
Selu
(
XlaOp
x
);
}
// namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_ELU_OP_H_
tensorflow/compiler/tf2xla/kernels/relu_op.cc
浏览文件 @
c64f92f6
...
...
@@ -15,23 +15,13 @@ limitations under the License.
// Native XLA implementations of XLA Relu Ops
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
namespace
xla
{
// Builds an op computing the maximum of zero and `x`.
XlaOp Relu(XlaOp x) {
  const XlaOp zero_like_x = ScalarLike(x, 0);
  return Max(zero_like_x, x);
}
// Builds an op clamping `x` between zero and six.
XlaOp Relu6(XlaOp x) {
  const XlaOp lower_bound = ScalarLike(x, 0);
  const XlaOp upper_bound = ScalarLike(x, 6);
  return Clamp(lower_bound, x, upper_bound);
}
}
// namespace xla
namespace
tensorflow
{
namespace
{
...
...
@@ -40,7 +30,9 @@ class ReluOp : public XlaOpKernel {
explicit
ReluOp
(
OpKernelConstruction
*
ctx
)
:
XlaOpKernel
(
ctx
)
{}
// Computes the max of the scalar input x and 0.
void
Compile
(
XlaOpKernelContext
*
ctx
)
override
{
ctx
->
SetOutput
(
0
,
xla
::
Relu
(
ctx
->
Input
(
0
)));
xla
::
XlaBuilder
*
builder
=
ctx
->
builder
();
auto
zero
=
XlaHelpers
::
Zero
(
builder
,
input_type
(
0
));
ctx
->
SetOutput
(
0
,
xla
::
Max
(
zero
,
ctx
->
Input
(
0
)));
}
};
REGISTER_XLA_OP
(
Name
(
"Relu"
),
ReluOp
);
...
...
@@ -50,7 +42,10 @@ class Relu6Op : public XlaOpKernel {
explicit
Relu6Op
(
OpKernelConstruction
*
ctx
)
:
XlaOpKernel
(
ctx
)
{}
// Clamp the scalar input between 0 and 6.
void
Compile
(
XlaOpKernelContext
*
ctx
)
override
{
ctx
->
SetOutput
(
0
,
xla
::
Relu6
(
ctx
->
Input
(
0
)));
xla
::
XlaBuilder
*
builder
=
ctx
->
builder
();
auto
zero
=
XlaHelpers
::
Zero
(
builder
,
input_type
(
0
));
auto
six
=
XlaHelpers
::
IntegerLiteral
(
builder
,
input_type
(
0
),
6
);
ctx
->
SetOutput
(
0
,
xla
::
Clamp
(
zero
,
ctx
->
Input
(
0
),
six
));
}
};
REGISTER_XLA_OP
(
Name
(
"Relu6"
),
Relu6Op
);
...
...
tensorflow/compiler/tf2xla/kernels/relu_op.h
已删除
100644 → 0
浏览文件 @
be57b1a6
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_
#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
namespace
xla
{
// Returns an op computing the maximum of zero and `x`.
XlaOp
Relu
(
XlaOp
x
);
// Returns an op clamping `x` between zero and six.
XlaOp
Relu6
(
XlaOp
x
);
}
// namespace xla
#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_RELU_OP_H_
tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc
已删除
100644 → 0
浏览文件 @
be57b1a6
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/kernels/elu_op.h"
#include "tensorflow/compiler/tf2xla/kernels/relu_op.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/client_library.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/lib/math.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
namespace
tensorflow
{
namespace
{
// A callable that takes one XlaOp and returns the XlaOp built from it.
using
XlaUnaryOpGenerator
=
std
::
function
<
xla
::
XlaOp
(
xla
::
XlaOp
)
>
;
// Maps an op name to the generator that builds that op.
using
XlaOpGeneratorMap
=
absl
::
flat_hash_map
<
string
,
XlaUnaryOpGenerator
>
;
// Fills `op_generator_map` with one generator per supported unary op name.
// Each name may be registered only once; a duplicate trips the CHECK.
void PopulateXlaOpGeneratorMap(XlaOpGeneratorMap* op_generator_map) {
  auto register_generator = [op_generator_map](
                                std::string name,
                                XlaUnaryOpGenerator xla_op_generator) {
    const bool newly_inserted =
        op_generator_map->insert({name, xla_op_generator}).second;
    CHECK(newly_inserted);
  };

// Registers xla::Name under the string "Name".
#define REGISTER_SAME_NAME_GENERATOR(Name) register_generator(#Name, xla::Name)

  REGISTER_SAME_NAME_GENERATOR(Abs);
  REGISTER_SAME_NAME_GENERATOR(Acos);
  REGISTER_SAME_NAME_GENERATOR(Acosh);
  REGISTER_SAME_NAME_GENERATOR(Asin);
  REGISTER_SAME_NAME_GENERATOR(Asinh);
  REGISTER_SAME_NAME_GENERATOR(Atan);
  REGISTER_SAME_NAME_GENERATOR(Atanh);
  REGISTER_SAME_NAME_GENERATOR(Ceil);
  REGISTER_SAME_NAME_GENERATOR(Cos);
  REGISTER_SAME_NAME_GENERATOR(Cosh);
  REGISTER_SAME_NAME_GENERATOR(Expm1);
  REGISTER_SAME_NAME_GENERATOR(Exp);
  REGISTER_SAME_NAME_GENERATOR(Floor);
  // "Inv" has no same-named generator here; it is built as 1.0 / x.
  register_generator(
      "Inv", [](xla::XlaOp x) { return xla::ScalarLike(x, 1.0) / x; });
  REGISTER_SAME_NAME_GENERATOR(Log);
  REGISTER_SAME_NAME_GENERATOR(Log1p);
  REGISTER_SAME_NAME_GENERATOR(Neg);
  REGISTER_SAME_NAME_GENERATOR(Reciprocal);
  // Op names that map onto a differently named xla:: function are registered
  // explicitly.
  register_generator("Rint", xla::RoundToEven);
  REGISTER_SAME_NAME_GENERATOR(Round);
  REGISTER_SAME_NAME_GENERATOR(Rsqrt);
  register_generator("Sigmoid", xla::Logistic);
  REGISTER_SAME_NAME_GENERATOR(Sin);
  REGISTER_SAME_NAME_GENERATOR(Sinh);
  REGISTER_SAME_NAME_GENERATOR(Sqrt);
  REGISTER_SAME_NAME_GENERATOR(Square);
  REGISTER_SAME_NAME_GENERATOR(Tan);
  REGISTER_SAME_NAME_GENERATOR(Tanh);

  REGISTER_SAME_NAME_GENERATOR(Elu);
  REGISTER_SAME_NAME_GENERATOR(Relu);
  REGISTER_SAME_NAME_GENERATOR(Relu6);
  REGISTER_SAME_NAME_GENERATOR(Selu);

#undef REGISTER_SAME_NAME_GENERATOR
}
// Returns the shared name -> generator map. The map is allocated and
// populated on the first call and is not deleted by this function.
const XlaOpGeneratorMap& GetXlaOpGeneratorMap() {
  static XlaOpGeneratorMap* const generators = [] {
    XlaOpGeneratorMap* fresh_map = new XlaOpGeneratorMap;
    PopulateXlaOpGeneratorMap(fresh_map);
    return fresh_map;
  }();
  return *generators;
}
class
UnaryOpsCompositionOp
:
public
XlaOpKernel
{
public:
explicit
UnaryOpsCompositionOp
(
OpKernelConstruction
*
ctx
)
:
XlaOpKernel
(
ctx
)
{
OP_REQUIRES_OK
(
ctx
,
ctx
->
GetAttr
(
"op_names"
,
&
op_names_
));
const
XlaOpGeneratorMap
&
op_generator_map
=
GetXlaOpGeneratorMap
();
for
(
absl
::
string_view
op_name
:
op_names_
)
{
OP_REQUIRES
(
ctx
,
op_generator_map
.
contains
(
op_name
),
errors
::
Unimplemented
(
op_name
,
" not supported in _UnaryOpsComposition"
));
}
}
void
Compile
(
XlaOpKernelContext
*
ctx
)
override
{
xla
::
XlaOp
x
=
ctx
->
Input
(
0
);
const
XlaOpGeneratorMap
&
op_generator_map
=
GetXlaOpGeneratorMap
();
for
(
absl
::
string_view
op_name
:
op_names_
)
{
x
=
op_generator_map
.
find
(
op_name
)
->
second
(
x
);
}
ctx
->
SetOutput
(
0
,
x
);
}
private:
std
::
vector
<
string
>
op_names_
;
};
// Registers UnaryOpsCompositionOp under the op name "_UnaryOpsComposition".
REGISTER_XLA_OP
(
Name
(
"_UnaryOpsComposition"
),
UnaryOpsCompositionOp
);
}
// namespace
}
// namespace tensorflow
tensorflow/core/kernels/ops_testutil.cc
浏览文件 @
c64f92f6
...
...
@@ -25,14 +25,8 @@ namespace tensorflow {
void
OpsTestBase
::
SetDevice
(
const
DeviceType
&
device_type
,
std
::
unique_ptr
<
Device
>
device
)
{
CHECK
(
device_
)
<<
"No device provided"
;
device_
=
device
.
get
();
device_mgr_
=
absl
::
make_unique
<
DeviceMgr
>
(
std
::
move
(
device
));
pflr_
=
absl
::
make_unique
<
ProcessFunctionLibraryRuntime
>
(
device_mgr_
.
get
(),
Env
::
Default
(),
TF_GRAPH_DEF_VERSION
,
flib_def_
.
get
(),
OptimizerOptions
());
device_type_
=
device_type
;
device_
=
std
::
move
(
device
);
#ifdef GOOGLE_CUDA
if
(
device_type
==
DEVICE_GPU
)
{
managed_allocator_
.
reset
(
new
GpuManagedAllocator
());
...
...
@@ -44,7 +38,6 @@ void OpsTestBase::SetDevice(const DeviceType& device_type,
#else
CHECK_NE
(
device_type
,
DEVICE_GPU
)
<<
"Requesting GPU on binary compiled without GOOGLE_CUDA."
;
allocator_
=
device_
->
GetAllocator
(
AllocatorAttributes
());
#endif
}
...
...
tensorflow/core/kernels/ops_testutil.h
浏览文件 @
c64f92f6
...
...
@@ -21,8 +21,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/device_mgr.h"
#include "tensorflow/core/common_runtime/process_function_library_runtime.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
...
...
@@ -72,21 +70,11 @@ inline void SetOutputAttrs(OpKernelContext::Params* params,
// to use the BrainClient interface.
class
OpsTestBase
:
public
::
testing
::
Test
{
public:
OpsTestBase
()
:
device_type_
(
DEVICE_CPU
)
{
auto
device
=
DeviceFactory
::
NewDevice
(
"CPU"
,
{},
"/job:a/replica:0/task:0"
);
CHECK
(
device
)
<<
"Could not create CPU device"
;
device_
=
device
.
get
();
device_mgr_
=
absl
::
make_unique
<
DeviceMgr
>
(
std
::
move
(
device
));
OpsTestBase
()
:
device_
(
DeviceFactory
::
NewDevice
(
"CPU"
,
{},
"/job:a/replica:0/task:0"
)),
device_type_
(
DEVICE_CPU
)
{
CHECK
(
device_
.
get
())
<<
"Could not create CPU device"
;
allocator_
=
device_
->
GetAllocator
(
AllocatorAttributes
());
flib_def_
=
absl
::
make_unique
<
FunctionLibraryDefinition
>
(
OpRegistry
::
Global
(),
FunctionDefLibrary
{});
pflr_
=
absl
::
make_unique
<
ProcessFunctionLibraryRuntime
>
(
device_mgr_
.
get
(),
Env
::
Default
(),
TF_GRAPH_DEF_VERSION
,
flib_def_
.
get
(),
OptimizerOptions
());
}
~
OpsTestBase
()
override
{
...
...
@@ -113,8 +101,8 @@ class OpsTestBase : public ::testing::Test {
// Only use this directly if you have a deprecated op that you need to test.
Status
InitOpWithGraphVersion
(
int
graph_def_version
)
{
Status
status
;
kernel_
=
CreateOpKernel
(
device_type_
,
device_
,
allocator
(),
node_def_
,
graph_def_version
,
&
status
);
kernel_
=
CreateOpKernel
(
device_type_
,
device_
.
get
(),
allocator
()
,
node_def_
,
graph_def_version
,
&
status
);
if
(
kernel_
!=
nullptr
)
input_types_
=
kernel_
->
input_types
();
return
status
;
}
...
...
@@ -177,18 +165,17 @@ class OpsTestBase : public ::testing::Test {
context_
.
reset
(
nullptr
);
params_
.
reset
(
new
OpKernelContext
::
Params
);
params_
->
device
=
device_
;
params_
->
frame_iter
=
FrameAndIter
(
0
,
0
);
params_
->
inputs
=
&
inputs_
;
params_
->
op_kernel
=
kernel_
.
get
();
params_
.
get
()
->
device
=
device_
.
get
()
;
params_
.
get
()
->
frame_iter
=
FrameAndIter
(
0
,
0
);
params_
.
get
()
->
inputs
=
&
inputs_
;
params_
.
get
()
->
op_kernel
=
kernel_
.
get
();
step_container_
.
reset
(
new
ScopedStepContainer
(
0
,
[](
const
string
&
)
{}));
params_
->
step_container
=
step_container_
.
get
();
std
::
vector
<
AllocatorAttributes
>
attrs
;
test
::
SetOutputAttrs
(
params_
.
get
(),
&
attrs
);
checkpoint
::
TensorSliceReaderCacheWrapper
slice_reader_cache_wrapper
;
params_
->
slice_reader_cache
=
&
slice_reader_cache_wrapper
;
params_
->
resource_manager
=
device_
->
resource_manager
();
params_
->
function_library
=
pflr_
->
GetFLR
(
device_
->
name
());
params_
.
get
()
->
slice_reader_cache
=
&
slice_reader_cache_wrapper
;
params_
.
get
()
->
resource_manager
=
device_
.
get
()
->
resource_manager
();
context_
.
reset
(
new
OpKernelContext
(
params_
.
get
()));
device_
->
Compute
(
kernel_
.
get
(),
context_
.
get
());
...
...
@@ -217,7 +204,7 @@ class OpsTestBase : public ::testing::Test {
const
DataTypeVector
&
output_types
()
const
{
return
kernel_
->
output_types
();
}
pr
otected
:
pr
ivate
:
Tensor
*
AddInput
(
DataType
dtype
,
const
TensorShape
&
shape
)
{
CHECK_GT
(
input_types_
.
size
(),
inputs_
.
size
())
<<
"Adding more inputs than types; perhaps you need to call MakeOp"
;
...
...
@@ -234,10 +221,8 @@ class OpsTestBase : public ::testing::Test {
return
input
;
}
// device_mgr_ owns device_.
std
::
unique_ptr
<
DeviceMgr
>
device_mgr_
;
Device
*
device_
;
protected:
std
::
unique_ptr
<
Device
>
device_
;
// The device allocator, or the managed_allocator_ below if running on GPU.
Allocator
*
allocator_
;
...
...
@@ -260,9 +245,6 @@ class OpsTestBase : public ::testing::Test {
// Unified memory allocator, only used when running on GPU.
std
::
unique_ptr
<
Allocator
>
managed_allocator_
;
std
::
unique_ptr
<
FunctionLibraryDefinition
>
flib_def_
;
std
::
unique_ptr
<
ProcessFunctionLibraryRuntime
>
pflr_
;
private:
TF_DISALLOW_COPY_AND_ASSIGN
(
OpsTestBase
);
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录