Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
25989425
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
410
Star
4707
Fork
583
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
25989425
编写于
4月 15, 2021
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(imperative/autodiff): add grad clip
GitOrigin-RevId: f344f4a2330ca2f560f45241d5aebcbab0f959b6
上级
601a33a8
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
273 addition
and
0 deletion
+273
-0
imperative/python/megengine/optimizer/__init__.py
imperative/python/megengine/optimizer/__init__.py
+1
-0
imperative/python/megengine/optimizer/clip_grad.py
imperative/python/megengine/optimizer/clip_grad.py
+72
-0
imperative/python/test/integration/test_converge_with_gradient_clip.py
...thon/test/integration/test_converge_with_gradient_clip.py
+120
-0
imperative/python/test/unit/optimizer/test_clip_grad.py
imperative/python/test/unit/optimizer/test_clip_grad.py
+80
-0
未找到文件。
imperative/python/megengine/optimizer/__init__.py
浏览文件 @
25989425
...
...
@@ -10,6 +10,7 @@ from .adadelta import Adadelta
from
.adagrad
import
Adagrad
from
.adam
import
Adam
from
.adamw
import
AdamW
from
.clip_grad
import
*
from
.lr_scheduler
import
LRScheduler
from
.multi_step_lr
import
MultiStepLR
from
.optimizer
import
Optimizer
...
...
imperative/python/megengine/optimizer/clip_grad.py
0 → 100644
浏览文件 @
25989425
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=redefined-builtin
from
typing
import
Iterable
,
Union
from
..core._imperative_rt.core2
import
pop_scope
,
push_scope
from
..functional
import
clip
,
concat
,
minimum
,
norm
from
..tensor
import
Tensor
__all__
=
[
"clip_grad_norm"
,
"clip_grad_value"
]
def
clip_grad_norm
(
tensors
:
Union
[
Tensor
,
Iterable
[
Tensor
]],
max_norm
:
float
,
ord
:
float
=
2.0
,
):
r
"""Clips gradient norm of an iterable of parameters.
The norm is computed over all gradients together, as if they were
concatenated into a single vector. Gradients are modified in-place.
:param tensors: an iterable of Tensors or a single Tensor.
:param max_norm: max norm of the gradients.
:param ord: type of the used p-norm. Can be ``'inf'`` for infinity norm.
:return: total norm of the parameters (viewed as a single vector).
"""
push_scope
(
"clip_grad_norm"
)
if
isinstance
(
tensors
,
Tensor
):
tensors
=
[
tensors
]
tensors
=
[
t
for
t
in
tensors
if
t
.
grad
is
not
None
]
if
len
(
tensors
)
==
0
:
pop_scope
(
"clip_grad_norm"
)
return
Tensor
(
0.0
)
norm_
=
[
norm
(
t
.
grad
.
flatten
(),
ord
=
ord
)
for
t
in
tensors
]
if
len
(
norm_
)
>
1
:
norm_
=
norm
(
concat
(
norm_
),
ord
=
ord
)
else
:
norm_
=
norm_
[
0
]
scale
=
max_norm
/
(
norm_
+
1e-6
)
scale
=
minimum
(
scale
,
1
)
for
tensor
in
tensors
:
tensor
.
grad
.
_reset
(
tensor
.
grad
*
scale
)
pop_scope
(
"clip_grad_norm"
)
return
norm_
def
clip_grad_value
(
tensors
:
Union
[
Tensor
,
Iterable
[
Tensor
]],
lower
:
float
,
upper
:
float
):
r
"""Clips gradient of an iterable of parameters to a specified lower and
upper. Gradients are modified in-place.
The gradients are clipped in the range:
.. math:: \left[\text{lower}, \text{upper}\right]
:param tensors: an iterable of Tensors or a single Tensor.
:param lower: minimum allowed value of the gradients.
:param upper: maximum allowed value of the gradients.
"""
push_scope
(
"clip_grad_value"
)
if
isinstance
(
tensors
,
Tensor
):
tensors
=
[
tensors
]
for
tensor
in
tensors
:
if
tensor
.
grad
is
None
:
continue
tensor
.
grad
.
_reset
(
clip
(
tensor
.
grad
,
lower
,
upper
))
pop_scope
(
"clip_grad_value"
)
imperative/python/test/integration/test_converge_with_gradient_clip.py
0 → 100644
浏览文件 @
25989425
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
itertools
import
numpy
as
np
import
pytest
import
megengine
as
mge
import
megengine.autodiff
as
ad
import
megengine.functional
as
F
import
megengine.optimizer
as
optim
from
megengine
import
Tensor
from
megengine.jit
import
trace
from
megengine.module
import
Linear
,
Module
from
megengine.optimizer
import
SGD
batch_size
=
64
data_shape
=
(
batch_size
,
2
)
label_shape
=
(
batch_size
,)
def
minibatch_generator
():
while
True
:
inp_data
=
np
.
zeros
((
batch_size
,
2
))
label
=
np
.
zeros
(
batch_size
,
dtype
=
np
.
int32
)
for
i
in
range
(
batch_size
):
# [x0, x1], sampled from U[-1, 1]
inp_data
[
i
,
:]
=
np
.
random
.
rand
(
2
)
*
2
-
1
label
[
i
]
=
0
if
np
.
prod
(
inp_data
[
i
])
<
0
else
1
yield
inp_data
.
astype
(
np
.
float32
),
label
.
astype
(
np
.
int32
)
def
calculate_precision
(
data
:
np
.
ndarray
,
pred
:
np
.
ndarray
)
->
float
:
""" Calculate precision for given data and prediction.
:type data: [[x, y], ...]
:param data: Input data
:type pred: [[x_pred, y_pred], ...]
:param pred: Network output data
"""
correct
=
0
assert
len
(
data
)
==
len
(
pred
)
for
inp_data
,
pred_output
in
zip
(
data
,
pred
):
label
=
0
if
np
.
prod
(
inp_data
)
<
0
else
1
pred_label
=
np
.
argmax
(
pred_output
)
if
pred_label
==
label
:
correct
+=
1
return
float
(
correct
)
/
len
(
data
)
class
XORNet
(
Module
):
def
__init__
(
self
):
self
.
mid_layers
=
14
self
.
num_class
=
2
super
().
__init__
()
self
.
fc0
=
Linear
(
self
.
num_class
,
self
.
mid_layers
,
bias
=
True
)
self
.
fc1
=
Linear
(
self
.
mid_layers
,
self
.
mid_layers
,
bias
=
True
)
self
.
fc2
=
Linear
(
self
.
mid_layers
,
self
.
num_class
,
bias
=
True
)
def
forward
(
self
,
x
):
x
=
self
.
fc0
(
x
)
x
=
F
.
tanh
(
x
)
x
=
self
.
fc1
(
x
)
x
=
F
.
tanh
(
x
)
x
=
self
.
fc2
(
x
)
return
x
def
test_training_converge
():
net
=
XORNet
()
opt
=
SGD
(
net
.
parameters
(),
lr
=
0.01
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
gm
=
ad
.
GradManager
().
attach
(
net
.
parameters
())
@
trace
(
symbolic
=
False
)
def
train
(
data
,
label
):
with
gm
:
pred
=
net
(
data
)
loss
=
F
.
nn
.
cross_entropy
(
pred
,
label
)
gm
.
backward
(
loss
)
optim
.
clip_grad_norm
(
net
.
parameters
(),
max_norm
=
0.2
,
ord
=
2.0
)
return
loss
def
infer
(
data
):
return
net
(
data
)
train_dataset
=
minibatch_generator
()
losses
=
[]
for
data
,
label
in
itertools
.
islice
(
train_dataset
,
2000
):
data
=
Tensor
(
data
,
dtype
=
np
.
float32
)
label
=
Tensor
(
label
,
dtype
=
np
.
int32
)
opt
.
clear_grad
()
loss
=
train
(
data
,
label
)
optim
.
clip_grad_value
(
net
.
parameters
(),
lower
=-
0.1
,
upper
=
0.1
)
opt
.
step
()
losses
.
append
(
loss
.
numpy
())
print
(
np
.
mean
(
losses
[
-
100
:]))
assert
np
.
mean
(
losses
[
-
100
:])
<
0.1
,
"Final training Loss must be low enough"
ngrid
=
10
x
=
np
.
linspace
(
-
1.0
,
1.0
,
ngrid
)
xx
,
yy
=
np
.
meshgrid
(
x
,
x
)
xx
=
xx
.
reshape
((
ngrid
*
ngrid
,
1
))
yy
=
yy
.
reshape
((
ngrid
*
ngrid
,
1
))
data
=
np
.
concatenate
((
xx
,
yy
),
axis
=
1
).
astype
(
np
.
float32
)
pred
=
infer
(
data
).
numpy
()
precision
=
calculate_precision
(
data
,
pred
)
print
(
"precision="
,
precision
)
assert
precision
==
1.0
,
"Test precision must be high enough, get {}"
.
format
(
precision
)
imperative/python/test/unit/optimizer/test_clip_grad.py
0 → 100644
浏览文件 @
25989425
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import
platform
import
weakref
import
numpy
as
np
import
pytest
import
megengine
as
mge
import
megengine.autodiff
as
ad
import
megengine.functional
as
F
import
megengine.module
as
M
import
megengine.optimizer
as
optim
class
Net
(
M
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
M
.
Conv2d
(
3
,
64
,
kernel_size
=
7
,
stride
=
2
,
padding
=
3
,
bias
=
False
)
self
.
bn1
=
M
.
BatchNorm2d
(
64
)
self
.
avgpool
=
M
.
AvgPool2d
(
kernel_size
=
5
,
stride
=
5
,
padding
=
0
)
self
.
fc
=
M
.
Linear
(
64
,
10
)
def
forward
(
self
,
x
):
x
=
self
.
conv1
(
x
)
x
=
self
.
bn1
(
x
)
x
=
F
.
relu
(
x
)
x
=
self
.
avgpool
(
x
)
x
=
F
.
avg_pool2d
(
x
,
22
)
x
=
F
.
flatten
(
x
,
1
)
x
=
self
.
fc
(
x
)
return
x
def
save_grad_value
(
net
):
for
param
in
net
.
parameters
():
param
.
grad_backup
=
param
.
grad
.
numpy
().
copy
()
def
test_clip_grad_norm
():
net
=
Net
()
x
=
mge
.
tensor
(
np
.
random
.
randn
(
10
,
3
,
224
,
224
))
gm
=
ad
.
GradManager
().
attach
(
net
.
parameters
())
opt
=
optim
.
SGD
(
net
.
parameters
(),
1e-3
,
momentum
=
0.9
)
with
gm
:
loss
=
net
(
x
).
sum
()
gm
.
backward
(
loss
)
save_grad_value
(
net
)
max_norm
=
1.0
original_norm
=
optim
.
clip_grad_norm
(
net
.
parameters
(),
max_norm
=
max_norm
,
ord
=
2
)
scale
=
max_norm
/
original_norm
for
param
in
net
.
parameters
():
np
.
testing
.
assert_almost_equal
(
param
.
grad
.
numpy
(),
param
.
grad_backup
*
scale
)
opt
.
step
().
clear_grad
()
def
test_clip_grad_value
():
net
=
Net
()
x
=
np
.
random
.
randn
(
10
,
3
,
224
,
224
).
astype
(
"float32"
)
gm
=
ad
.
GradManager
().
attach
(
net
.
parameters
())
opt
=
optim
.
SGD
(
net
.
parameters
(),
1e-3
,
momentum
=
0.9
)
with
gm
:
y
=
net
(
x
)
y
=
y
.
mean
()
gm
.
backward
(
y
)
save_grad_value
(
net
)
max_val
=
5
min_val
=
-
2
optim
.
clip_grad_value
(
net
.
parameters
(),
lower
=
min_val
,
upper
=
max_val
)
for
param
in
net
.
parameters
():
np
.
testing
.
assert_almost_equal
(
param
.
grad
.
numpy
(),
np
.
maximum
(
np
.
minimum
(
param
.
grad_backup
,
max_val
),
min_val
),
)
opt
.
step
().
clear_grad
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录