xxadev / tensorflow

Unverified commit f03fe1bf, authored on Sep 11, 2019 by tanzhenyu, committed via GitHub on Sep 11, 2019.

Merge pull request #32433 from tanzhenyu/cherrypicks_WIQA6

Fix major Adamax gpu bug.

Parents: ffb00d7f, 67c17f22
Showing 4 changed files, with 42 additions and 41 deletions:

  tensorflow/core/kernels/training_ops_gpu.cu.cc               +8   -7
  tensorflow/python/keras/optimizer_v2/BUILD                    +8  -11
  tensorflow/python/keras/optimizer_v2/adamax_test.py           +9   -6
  tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py    +17  -17
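For orientation before the kernel diff: ApplyAdaMax implements the AdaMax update, the infinity-norm variant of Adam. A minimal numpy sketch of one step, assuming Keras-style default hyperparameters (the defaults are not part of this diff):

import numpy as np

def adamax_step(var, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    """One AdaMax step (t >= 1), mirroring ApplyAdaMax in training_ops_gpu.cu.cc."""
    m = m + (1.0 - beta1) * (grad - m)            # first-moment moving average
    v = np.maximum(beta2 * v, np.abs(grad))       # exponentially weighted infinity norm
    var = var - lr / (1.0 - beta1 ** t) * (m / (v + eps))  # bias-corrected step
    return var, m, v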
tensorflow/core/kernels/training_ops_gpu.cu.cc

@@ -306,15 +306,16 @@ struct ApplyAdaMax<GPUDevice, T> {
     bcast[0] = grad.dimension(0);
     Eigen::Sizes<1> single;
     const auto one = static_cast<T>(1.0);
-    m.device(d) =
-        m + (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
+    m.device(d) +=
+        (beta1.constant(one) - beta1).reshape(single).broadcast(bcast) *
         (grad - m);
     v.device(d) =
         (beta2.reshape(single).broadcast(bcast) * v).cwiseMax(grad.abs());
-    var.device(d) -= lr /
-                     (beta1_power.constant(one) - beta1_power)
-                         .reshape(single)
-                         .broadcast(bcast) *
-                     (m / (v + epsilon));
+    var.device(d) -=
+        lr.reshape(single).broadcast(bcast) /
+        (beta1_power.constant(one) - beta1_power)
+            .reshape(single)
+            .broadcast(bcast) *
+        (m / (v + epsilon.reshape(single).broadcast(bcast)));
   }
 };
...
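The kernel bug is a broadcasting one: lr and epsilon arrive as rank-0 Eigen scalars (ConstScalar), and the old expression combined them directly with rank-1 broadcast tensors, which produces wrong results on the GPU device. The fix broadcasts them explicitly with reshape(single).broadcast(bcast), exactly as the kernel already did for beta1, beta2, and beta1_power. A hedged end-to-end check, not part of the commit, assuming a TF 2.x runtime and the adamax_step sketch above:

import numpy as np
import tensorflow as tf

var = tf.Variable([1.0, 2.0])
grad = tf.constant([0.1, 0.1])
opt = tf.keras.optimizers.Adamax(learning_rate=0.001)
opt.apply_gradients([(grad, var)])  # exercises the CUDA kernel when a GPU is present

expected, _, _ = adamax_step(np.array([1.0, 2.0]), np.array([0.1, 0.1]),
                             m=np.zeros(2), v=np.zeros(2), t=1)
np.testing.assert_allclose(var.numpy(), expected, rtol=1e-5)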
tensorflow/python/keras/optimizer_v2/BUILD

@@ -201,20 +201,13 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )
 
-py_test(
+cuda_py_test(
     name = "optimizer_v2_test",
     size = "medium",
     srcs = ["optimizer_v2_test.py"],
-    python_version = "PY2",
-    shard_count = 8,
-    tags = [
-        "no_gpu",  # b/127001953
-        "no_windows",
-        # TODO(b/127092862): Re-enable this test in Kokoro.
-        "no_oss",
-    ],
-    deps = [
+    additional_deps = [
         ":optimizer_v2",
+        "@absl_py//absl/testing:parameterized",
         "//tensorflow/python:array_ops",
         "//tensorflow/python:client_testlib",
         "//tensorflow/python:clip_ops",
...
@@ -226,8 +219,12 @@ py_test(
         "//tensorflow/python:variables",
         "//tensorflow/python/eager:def_function",
         "//tensorflow/python/keras",
-        "@absl_py//absl/testing:parameterized",
     ],
+    shard_count = 8,
+    tags = [
+        "no_windows",
+    ],
+    xla_enable_strict_auto_jit = True,
 )
 
 cuda_py_test(
...
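This BUILD change is what lets the new GPU coverage run at all: optimizer_v2_test was a plain py_test tagged "no_gpu", so the CUDA kernels were never exercised; cuda_py_test builds and runs the test with GPU support. A hedged minimal sketch of the macro's use; the load line sits outside this hunk, and its path is assumed from TensorFlow convention:

load("//tensorflow:tensorflow.bzl", "cuda_py_test")

cuda_py_test(
    name = "optimizer_v2_test",
    size = "medium",
    srcs = ["optimizer_v2_test.py"],
    additional_deps = [":optimizer_v2"],  # abridged; the full list is in the diff above
    shard_count = 8,
    xla_enable_strict_auto_jit = True,
)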
tensorflow/python/keras/optimizer_v2/adamax_test.py

@@ -80,7 +80,7 @@ class AdamaxOptimizerTest(test.TestCase):
   def doTestSparse(self, use_resource=False):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         zero_slots = lambda: np.zeros((3), dtype=dtype.as_numpy_dtype)  # pylint: disable=cell-var-from-loop
         m0, v0, m1, v1 = zero_slots(), zero_slots(), zero_slots(), zero_slots()
...
@@ -176,9 +176,12 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testBasic(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.session(graph=ops.Graph()):
+      with self.session(graph=ops.Graph(), use_gpu=True):
         # Initialize variables for numpy implementation.
-        m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+        m0 = np.array([0.0, 0.0])
+        v0 = np.array([0.0, 0.0])
+        m1 = np.array([0.0, 0.0])
+        v1 = np.array([0.0, 0.0])
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
         grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
         var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
...
@@ -224,7 +227,7 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes(reset_test=True)
   def testBasicWithLearningRateDecay(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.session(graph=ops.Graph()):
+      with self.session(graph=ops.Graph(), use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
...
@@ -278,7 +281,7 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testTensorLearningRate(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
...
@@ -315,7 +318,7 @@ class AdamaxOptimizerTest(test.TestCase):
   @test_util.run_deprecated_v1
   def testSharing(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         # Initialize variables for numpy implementation.
         m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
         var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
...
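The test-side pattern is uniform: cached_session() and session(graph=...) default to use_gpu=False, pinning every op to CPU, which is how the broken GPU kernel kept passing CI; use_gpu=True allows GPU placement with soft fallback to CPU. testBasic also switches its numpy slots from Python scalars to arrays so the element-wise np.maximum in the reference tracks v per element. A hypothetical miniature of the pattern, not part of the commit, assuming a 2019-era TF 1.x-compatible runtime:

import tensorflow.compat.v1 as tf

class AdamaxGpuSmokeTest(tf.test.TestCase):  # hypothetical test, for illustration only

  def testOneStep(self):
    with self.cached_session(use_gpu=True):  # place ops on a GPU when one is available
      var = tf.Variable([1.0, 2.0])
      opt = tf.keras.optimizers.Adamax(learning_rate=0.001)
      step = opt.apply_gradients([(tf.constant([0.1, 0.1]), var)])
      self.evaluate(tf.global_variables_initializer())
      self.evaluate(step)
      # With these defaults, one AdaMax step moves each entry by ~0.001.
      self.assertAllClose([0.999, 1.999], self.evaluate(var), atol=1e-4)

if __name__ == "__main__":
  tf.test.main()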
tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py

@@ -65,7 +65,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testBasic(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
...
@@ -129,7 +129,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testPrecomputedGradient(self):
     for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = variables.Variable([1.0, 2.0], dtype=dtype)
         var1 = variables.Variable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
...
@@ -153,7 +153,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNoGradients(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0  # pylint: disable=cell-var-from-loop
...
@@ -165,7 +165,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_Minimize(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: constant_op.constant(5.0)
...
@@ -178,7 +178,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testNoGradientsForAnyVariables_ApplyGradients(self):
     for _, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         sgd_op = gradient_descent.SGD(3.0)
...
@@ -189,7 +189,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testGradientsAsVariables(self):
     for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
-      with self.cached_session():
+      with self.cached_session(use_gpu=True):
         var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
         var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
         loss = lambda: 5 * var0 + 3 * var1  # pylint: disable=cell-var-from-loop
...
@@ -227,7 +227,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testComputeGradientsWithTensors(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       x = ops.convert_to_tensor(1.0)
 
       def f():
...
@@ -247,7 +247,7 @@ class OptimizerTest(test.TestCase):
   def testConstraint(self):
     constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
     constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var0 = variables.Variable([1.0, 2.0],
                                 constraint=constraint_01)
       var1 = variables.Variable([3.0, 4.0],
...
@@ -269,14 +269,14 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testIterationWithoutMinimize(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       sgd = gradient_descent.SGD(3.0)
       self.evaluate(sgd.iterations.initializer)
       self.assertEqual(0, self.evaluate(sgd.iterations))
 
   @test_util.run_in_graph_and_eager_modes
   def testConfig(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       opt = gradient_descent.SGD(learning_rate=1.0)
       config = opt.get_config()
       opt2 = gradient_descent.SGD.from_config(config)
...
@@ -296,7 +296,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testConfigWithLearningRateDecay(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var0 = variables.Variable([[1.0], [2.0]], dtype=dtypes.float32)
       for decay_schedule in [
           learning_rate_schedule.InverseTimeDecay(
...
@@ -327,7 +327,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testGradClipValue(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var = resource_variable_ops.ResourceVariable([1.0, 2.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipvalue=1.0)
...
@@ -338,7 +338,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testGradClipNorm(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       var = resource_variable_ops.ResourceVariable([1.0])
       loss = lambda: 3 * var
       opt = gradient_descent.SGD(learning_rate=1.0, clipnorm=1.0)
...
@@ -359,7 +359,7 @@ class OptimizerTest(test.TestCase):
   @test_util.run_in_graph_and_eager_modes
   def testWeights(self):
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       opt1 = adam.Adam(learning_rate=1.0)
       var1 = resource_variable_ops.ResourceVariable([1.0, 2.0],
                                                     dtype=dtypes.float32)
...
@@ -620,7 +620,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
...
@@ -708,7 +708,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
...
@@ -769,7 +769,7 @@ class OptimizersCompatibilityTest(keras_parameterized.TestCase):
           'v1 optimizer does not run in experimental_run_tf_function mode or '
           'eager mode')
     np.random.seed(1331)
-    with self.cached_session():
+    with self.cached_session(use_gpu=True):
       train_samples = 20
       input_dim = 3
       num_classes = 2
...
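All 17 changed call sites in this file are the same one-line switch. Since use_gpu=True requests rather than forces GPU placement, these tests still pass on GPU-less workers; a hedged availability check using the era-appropriate API, not part of the commit:

import tensorflow as tf

# tf.test.is_gpu_available was the current API in 2019; newer code would use
# tf.config.list_physical_devices('GPU') instead.
if not tf.test.is_gpu_available():
    print("No GPU detected: use_gpu=True sessions will fall back to CPU.")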