Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f8ba12e5
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f8ba12e5
编写于
2月 09, 2022
作者:
F
fwenguang
提交者:
GitHub
2月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[mlu] add mlu kernel for momentum op (#39331)
上级
d47a511a
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
717 addition
and
4 deletion
+717
-4
paddle/fluid/operators/mlu/mlu_baseop.cc
paddle/fluid/operators/mlu/mlu_baseop.cc
+5
-3
paddle/fluid/operators/mlu/mlu_baseop.h
paddle/fluid/operators/mlu/mlu_baseop.h
+4
-1
paddle/fluid/operators/optimizers/momentum_op_mlu.cc
paddle/fluid/operators/optimizers/momentum_op_mlu.cc
+96
-0
python/paddle/fluid/tests/unittests/mlu/test_momentum_op_mlu.py
.../paddle/fluid/tests/unittests/mlu/test_momentum_op_mlu.py
+612
-0
未找到文件。
paddle/fluid/operators/mlu/mlu_baseop.cc
浏览文件 @
f8ba12e5
...
@@ -873,9 +873,11 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
...
@@ -873,9 +873,11 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
const
cnnlTensorDescriptor_t
a_desc
,
const
void
*
a
,
const
cnnlTensorDescriptor_t
a_desc
,
const
void
*
a
,
const
cnnlTensorDescriptor_t
b_desc
,
const
void
*
b
,
const
cnnlTensorDescriptor_t
b_desc
,
const
void
*
b
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output
,
const
cnnlDataType_t
dtype
)
{
const
cnnlDataType_t
dtype
,
const
float
alpha1_float
,
static
const
int
alpha1_int
=
1
,
alpha2_int
=
1
,
beta_int
=
0
;
const
float
alpha2_float
,
const
float
beta_float
)
{
static
const
float
alpha1_float
=
1.
f
,
alpha2_float
=
1.
f
,
beta_float
=
0.
f
;
const
int
alpha1_int
=
static_cast
<
const
int
>
(
alpha1_float
);
const
int
alpha2_int
=
static_cast
<
const
int
>
(
alpha2_float
);
const
int
beta_int
=
static_cast
<
const
int
>
(
beta_float
);
const
void
*
alpha1_ptr
=
static_cast
<
const
void
*>
(
&
alpha1_float
);
const
void
*
alpha1_ptr
=
static_cast
<
const
void
*>
(
&
alpha1_float
);
const
void
*
alpha2_ptr
=
static_cast
<
const
void
*>
(
&
alpha2_float
);
const
void
*
alpha2_ptr
=
static_cast
<
const
void
*>
(
&
alpha2_float
);
...
...
paddle/fluid/operators/mlu/mlu_baseop.h
浏览文件 @
f8ba12e5
...
@@ -678,7 +678,10 @@ class MLUCnnl {
...
@@ -678,7 +678,10 @@ class MLUCnnl {
const
cnnlTensorDescriptor_t
a_desc
,
const
void
*
a
,
const
cnnlTensorDescriptor_t
a_desc
,
const
void
*
a
,
const
cnnlTensorDescriptor_t
b_desc
,
const
void
*
b
,
const
cnnlTensorDescriptor_t
b_desc
,
const
void
*
b
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output
,
const
cnnlDataType_t
dtype
);
const
cnnlDataType_t
dtype
,
const
float
alpha1_float
=
1.
f
,
const
float
alpha2_float
=
1.
f
,
const
float
beta_float
=
0.
f
);
static
void
BiasAddGrad
(
const
ExecutionContext
&
ctx
,
const
int
axis
,
static
void
BiasAddGrad
(
const
ExecutionContext
&
ctx
,
const
int
axis
,
const
cnnlTensorDescriptor_t
out_backprop_desc
,
const
cnnlTensorDescriptor_t
out_backprop_desc
,
...
...
paddle/fluid/operators/optimizers/momentum_op_mlu.cc
0 → 100644
浏览文件 @
f8ba12e5
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
MLUMomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
MLUDeviceContext
>();
std
::
string
regularization_method
=
ctx
.
Attr
<
std
::
string
>
(
"regularization_method"
);
auto
regularization_coeff
=
ctx
.
Attr
<
float
>
(
"regularization_coeff"
);
RegularizationType
regularization_flag
{
RegularizationType
::
kNONE
};
// disable regularization
if
(
regularization_method
==
"l2_decay"
)
{
regularization_flag
=
RegularizationType
::
kL2DECAY
;
}
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
Tensor
mu_tensor
=
ctx
.
AllocateTmpTensor
<
T
,
MLUDeviceContext
>
({
1
},
dev_ctx
);
MLUCnnlTensorDesc
mu_tensor_desc
(
mu_tensor
);
MLUCnnl
::
Fill
(
ctx
,
mu
,
mu_tensor_desc
.
get
(),
GetBasePtr
(
&
mu_tensor
));
Tensor
regularized_grad
;
MLUCnnlTensorDesc
param_desc
(
*
param
);
if
(
regularization_flag
==
RegularizationType
::
kL2DECAY
)
{
regularized_grad
=
ctx
.
AllocateTmpTensor
<
T
,
MLUDeviceContext
>
(
param
->
dims
(),
dev_ctx
);
MLUCnnlOpTensorDesc
op_tensor_desc
(
CNNL_OP_TENSOR_ADD
,
ToCnnlDataType
<
T
>
(),
CNNL_NOT_PROPAGATE_NAN
);
MLUCnnl
::
OpTensor
(
ctx
,
op_tensor_desc
.
get
(),
param_desc
.
get
(),
GetBasePtr
(
param
),
param_desc
.
get
(),
GetBasePtr
(
grad
),
param_desc
.
get
(),
GetBasePtr
(
&
regularized_grad
),
ToCnnlDataType
<
T
>
(),
regularization_coeff
);
}
else
{
regularized_grad
=
*
grad
;
}
framework
::
TensorCopy
(
*
param
,
ctx
.
GetPlace
(),
dev_ctx
,
param_out
);
framework
::
TensorCopy
(
*
velocity
,
ctx
.
GetPlace
(),
dev_ctx
,
velocity_out
);
MLUCnnl
::
ApplyMomentum
(
ctx
,
param_desc
.
get
(),
GetBasePtr
(
&
regularized_grad
),
use_nesterov
,
GetBasePtr
(
learning_rate
),
GetBasePtr
(
&
mu_tensor
),
GetBasePtr
(
param_out
),
GetBasePtr
(
velocity_out
));
}
else
if
(
grad_var
->
IsType
<
pten
::
SelectedRows
>
())
{
PADDLE_ENFORCE_EQ
(
false
,
true
,
platform
::
errors
::
PermissionDenied
(
"Unsupport SparseMomentum"
));
}
else
{
PADDLE_ENFORCE_EQ
(
false
,
true
,
platform
::
errors
::
PermissionDenied
(
"Unsupported Variable Type of Grad "
"in MomentumOp. Excepted LodTensor "
"or SelectedRows, But received [%s]"
,
paddle
::
framework
::
ToTypeName
(
grad_var
->
Type
())));
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_MLU_KERNEL
(
momentum
,
ops
::
MLUMomentumOpKernel
<
float
>
,
ops
::
MLUMomentumOpKernel
<
plat
::
float16
>
);
python/paddle/fluid/tests/unittests/mlu/test_momentum_op_mlu.py
0 → 100644
浏览文件 @
f8ba12e5
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
import
sys
sys
.
path
.
append
(
'..'
)
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
import
numpy
from
test_momentum_op
import
calculate_momentum_by_numpy
class
TestMomentumOp1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"momentum"
self
.
dtype
=
np
.
float32
self
.
init_dtype
()
self
.
set_mlu
()
param
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
self
.
dtype
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
np
.
float32
)
mu
=
0.0001
use_nesterov
=
False
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
}
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
init_dtype
(
self
):
pass
def
set_mlu
(
self
):
self
.
place
=
paddle
.
device
.
MLUPlace
(
0
)
self
.
__class__
.
use_mlu
=
True
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
class
TestMomentumOpFp16
(
TestMomentumOp1
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
,
atol
=
1e-3
)
class
TestMomentumOp2
(
OpTest
):
'''Test Momentum with default values for attributes
'''
def
setUp
(
self
):
self
.
op_type
=
"momentum"
self
.
place
=
paddle
.
device
.
MLUPlace
(
0
)
self
.
__class__
.
use_mlu
=
True
param
=
np
.
random
.
random
((
123
,
321
)).
astype
(
"float32"
)
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
"float32"
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
"float32"
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
"float32"
)
mu
=
0.0001
use_nesterov
=
True
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
,
'use_nesterov'
:
use_nesterov
}
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
class
TestMomentumV2
(
unittest
.
TestCase
):
def
test_momentum_dygraph
(
self
):
paddle
.
disable_static
()
value
=
np
.
arange
(
26
).
reshape
(
2
,
13
).
astype
(
"float32"
)
a
=
paddle
.
to_tensor
(
value
)
linear
=
paddle
.
nn
.
Linear
(
13
,
5
)
# This can be any optimizer supported by dygraph.
adam
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameters
=
linear
.
parameters
())
out
=
linear
(
a
)
out
.
backward
()
adam
.
step
()
adam
.
clear_gradients
()
def
test_momentum
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
MLUPlace
(
0
)
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
rms_optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.1
,
momentum
=
0.9
)
rms_optimizer
.
minimize
(
avg_cost
)
fetch_list
=
[
avg_cost
]
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
uci_housing
.
train
(),
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
main
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
def
test_raise_error
(
self
):
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Momentum
,
learning_rate
=
None
)
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Momentum
,
momentum
=
None
)
class
TestMomentumOpWithDecay
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"momentum"
self
.
place
=
paddle
.
device
.
MLUPlace
(
0
)
self
.
__class__
.
use_mlu
=
True
self
.
dtype
=
np
.
float32
self
.
use_nesterov
=
True
self
.
regularization_method
=
'l2_decay'
self
.
regularization_coeff
=
0.9
self
.
init_config
()
param
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
self
.
dtype
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
np
.
float32
)
mu
=
0.0001
use_nesterov
=
self
.
use_nesterov
regularization_method
=
self
.
regularization_method
regularization_coeff
=
self
.
regularization_coeff
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
,
'use_nesterov'
:
use_nesterov
,
'regularization_method'
:
regularization_method
,
'regularization_coeff'
:
regularization_coeff
}
grad
=
grad
+
regularization_coeff
*
param
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
init_config
(
self
):
pass
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output_with_place
(
self
.
place
)
class
TestMomentumOpWithDecayFP16
(
TestMomentumOpWithDecay
):
def
init_config
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
,
atol
=
1e-3
)
class
TestMomentumOpWithDecay2
(
TestMomentumOpWithDecay
):
def
init_config
(
self
):
self
.
use_nesterov
=
False
class
TestMomentumOpWithDecayAPI
(
unittest
.
TestCase
):
def
_test_momentum_dygraph_common
(
self
,
regularization
):
paddle
.
disable_static
()
inp
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
[
10
,
10
]).
astype
(
"float32"
)
linear
=
paddle
.
nn
.
Linear
(
10
,
10
)
inp
=
paddle
.
to_tensor
(
inp
)
out
=
linear
(
inp
)
loss
=
paddle
.
mean
(
out
)
# This can be any optimizer supported by dygraph.
momentum
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear
.
parameters
(),
regularization
=
regularization
)
momentum
.
minimize
(
loss
)
def
test_momentum_dygraph_1
(
self
):
self
.
_test_momentum_dygraph_common
(
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
def
test_momentum_static
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
MLUPlace
(
0
)
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
momentum_optimizer
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.1
,
momentum
=
0.9
)
momentum_optimizer
.
minimize
(
avg_cost
)
fetch_list
=
[
avg_cost
]
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
uci_housing
.
train
(),
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
main
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
class
TestFusedMomentumWithDecayAPI
(
unittest
.
TestCase
):
def
get_program
(
self
,
weight_attr
,
bias_attr
=
False
):
main_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main_program
=
main_program
,
startup_program
=
startup_program
):
x
=
paddle
.
static
.
data
(
name
=
'x'
,
shape
=
[
10
,
10
])
linear
=
paddle
.
nn
.
Linear
(
10
,
10
,
weight_attr
=
weight_attr
,
bias_attr
=
bias_attr
)
out
=
linear
(
x
)
loss
=
paddle
.
mean
(
out
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
weight_decay
=
paddle
.
regularizer
.
L2Decay
(
0.5
))
optimizer
.
minimize
(
loss
)
return
main_program
def
test_param_has_l2decay
(
self
):
paddle
.
enable_static
()
weight_attr
=
paddle
.
ParamAttr
(
name
=
"weight"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
),
regularizer
=
paddle
.
regularizer
.
L2Decay
(
0.1
))
program
=
self
.
get_program
(
weight_attr
,
bias_attr
=
False
)
ops
=
program
.
global_block
().
ops
self
.
assertEqual
(
ops
[
-
1
].
attr
(
'regularization_method'
),
'l2_decay'
)
self
.
assertEqual
(
ops
[
-
1
].
attr
(
'regularization_coeff'
),
np
.
float32
(
0.1
))
for
i
in
range
(
len
(
ops
)):
self
.
assertTrue
(
'sum'
not
in
ops
[
i
].
type
)
self
.
assertTrue
(
'scale'
not
in
ops
[
i
].
type
)
def
test_param_has_l1decay
(
self
):
paddle
.
enable_static
()
weight_attr
=
paddle
.
ParamAttr
(
name
=
"weight"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.5
),
regularizer
=
paddle
.
regularizer
.
L1Decay
(
0.1
))
bias_attr
=
paddle
.
ParamAttr
(
name
=
"bias"
,
initializer
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
0.
),
regularizer
=
None
)
program
=
self
.
get_program
(
weight_attr
,
bias_attr
)
ops
=
program
.
global_block
().
ops
self
.
assertEqual
(
ops
[
-
1
].
type
,
'momentum'
)
self
.
assertEqual
(
ops
[
-
2
].
type
,
'momentum'
)
self
.
assertEqual
(
ops
[
-
3
].
type
,
'sum'
)
self
.
assertEqual
(
ops
[
-
4
].
type
,
'scale'
)
self
.
assertEqual
(
ops
[
-
5
].
type
,
'sign'
)
self
.
assertEqual
(
ops
[
-
6
].
type
,
'matmul_v2_grad'
)
if
'weight'
in
ops
[
-
1
].
input
(
'Param'
):
self
.
assertEqual
(
ops
[
-
1
].
attr
(
'regularization_method'
),
''
)
self
.
assertEqual
(
ops
[
-
1
].
attr
(
'regularization_coeff'
),
0
)
if
'bias'
in
ops
[
-
2
].
input
(
'Param'
):
self
.
assertEqual
(
ops
[
-
2
].
attr
(
'regularization_method'
),
'l2_decay'
)
self
.
assertEqual
(
ops
[
-
2
].
attr
(
'regularization_coeff'
),
np
.
float32
(
0.5
))
def
test_param_has_no_regularizer
(
self
):
paddle
.
enable_static
()
program
=
self
.
get_program
(
weight_attr
=
None
)
ops
=
program
.
global_block
().
ops
self
.
assertEqual
(
ops
[
-
1
].
attr
(
'regularization_method'
),
'l2_decay'
)
self
.
assertEqual
(
ops
[
-
1
].
attr
(
'regularization_coeff'
),
np
.
float32
(
0.5
))
for
i
in
range
(
len
(
ops
)):
self
.
assertTrue
(
'sum'
not
in
ops
[
i
].
type
)
self
.
assertTrue
(
'scale'
not
in
ops
[
i
].
type
)
class
TestMomentumOpVsMomentumOpWithDecayAPI
(
unittest
.
TestCase
):
def
__update_params
(
self
,
momentum
,
linear
):
for
i
in
range
(
10
):
inp
=
paddle
.
full
(
shape
=
[
2
,
2
],
fill_value
=
i
,
dtype
=
'float32'
).
astype
(
"float32"
)
inp
=
paddle
.
to_tensor
(
inp
)
out
=
linear
(
inp
)
loss
=
paddle
.
mean
(
out
)
loss
.
backward
()
momentum
.
minimize
(
loss
)
linear
.
clear_gradients
()
def
__test_vs
(
self
,
place
=
fluid
.
MLUPlace
(
0
)):
paddle
.
disable_static
(
place
=
place
)
linear_old
=
paddle
.
nn
.
Linear
(
2
,
2
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
))
momentum_old
=
paddle
.
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear_old
.
parameters
(),
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
self
.
__update_params
(
momentum
=
momentum_old
,
linear
=
linear_old
)
linear_new
=
paddle
.
nn
.
Linear
(
2
,
2
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
))
momentum_new
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear_new
.
parameters
(),
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
self
.
__update_params
(
momentum
=
momentum_new
,
linear
=
linear_new
)
self
.
assertEqual
(
(
linear_old
.
weight
.
numpy
()
==
linear_new
.
weight
.
numpy
()).
all
(),
True
,
'the param weight updated by two Momentum optimizers should equal'
)
def
test_vs
(
self
,
place
=
fluid
.
MLUPlace
(
0
)):
places
=
[
fluid
.
MLUPlace
(
0
)]
for
place
in
places
:
self
.
__test_vs
(
place
=
place
)
class
TestMomentumV2Group
(
TestMomentumV2
):
def
test_momentum_dygraph
(
self
):
paddle
.
disable_static
()
value
=
np
.
arange
(
26
).
reshape
(
2
,
13
).
astype
(
"float32"
)
a
=
paddle
.
to_tensor
(
value
)
linear_1
=
paddle
.
nn
.
Linear
(
13
,
5
)
linear_2
=
paddle
.
nn
.
Linear
(
5
,
3
)
# This can be any optimizer supported by dygraph.
adam
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
parameters
=
[{
'params'
:
linear_1
.
parameters
()
},
{
'params'
:
linear_2
.
parameters
(),
'weight_decay'
:
0.001
,
'learning_rate'
:
0.1
,
'momentum'
:
0.99
}],
weight_decay
=
0.1
,
momentum
=
0.9
)
out
=
linear_1
(
a
)
out
=
linear_2
(
out
)
out
.
backward
()
adam
.
step
()
adam
.
clear_gradients
()
class
TestMultiTensorMomentumDygraph
(
unittest
.
TestCase
):
def
_momentum_optimize_dygraph
(
self
,
place
,
use_param_attr
=
False
,
use_param_group
=
False
,
use_amp
=
False
,
use_multi_tensor
=
False
):
paddle
.
disable_static
()
paddle
.
seed
(
10
)
paddle
.
set_device
(
place
)
input
=
paddle
.
randn
((
5
,
5
))
weight_attr
=
paddle
.
ParamAttr
(
learning_rate
=
0.5
,
regularizer
=
paddle
.
regularizer
.
L2Decay
(
1.0
),
trainable
=
True
)
if
use_param_attr
:
model
=
paddle
.
nn
.
Linear
(
5
,
5
,
weight_attr
)
else
:
model
=
paddle
.
nn
.
Linear
(
5
,
5
)
if
not
use_param_group
:
optimizer
=
paddle
.
optimizer
.
Momentum
(
parameters
=
model
.
parameters
(),
use_multi_tensor
=
use_multi_tensor
,
multi_precision
=
use_amp
)
else
:
optimizer
=
paddle
.
optimizer
.
Momentum
(
parameters
=
[{
'params'
:
model
.
parameters
(),
'weight_decay'
:
0.001
,
'learning_rate'
:
0.1
,
'momentum'
:
0.99
}],
use_multi_tensor
=
use_multi_tensor
,
multi_precision
=
use_amp
)
for
idx
in
range
(
5
):
if
place
==
'mlu'
and
use_amp
==
True
:
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
1024
)
if
place
==
'mlu'
and
use_amp
==
True
:
with
paddle
.
amp
.
auto_cast
(
level
=
'O2'
):
output
=
model
(
input
)
loss
=
paddle
.
mean
(
output
)
scaled
=
scaler
.
scale
(
loss
)
scaled
.
backward
()
scaler
.
step
(
optimizer
)
optimizer
.
clear_grad
(
set_to_zero
=
False
)
else
:
output
=
model
(
input
)
loss
=
paddle
.
mean
(
output
)
# This can be any optimizer supported by dygraph.
loss
.
backward
()
optimizer
.
step
()
optimizer
.
clear_grad
(
set_to_zero
=
False
)
return
output
,
model
.
parameters
()
def
_get_places
(
self
):
places
=
[
'mlu'
]
return
places
def
_check_with_place_amp
(
self
,
place
,
use_amp
):
output1
,
params1
=
self
.
_momentum_optimize_dygraph
(
place
=
place
,
use_amp
=
use_amp
,
use_multi_tensor
=
True
)
output2
,
params2
=
self
.
_momentum_optimize_dygraph
(
place
=
place
,
use_amp
=
use_amp
,
use_multi_tensor
=
False
)
self
.
assertEqual
(
np
.
allclose
(
output1
,
output2
,
rtol
=
1e-05
),
True
)
for
idx
in
range
(
len
(
params1
)):
self
.
assertEqual
(
np
.
allclose
(
params1
[
idx
],
params2
[
idx
],
rtol
=
1e-05
),
True
)
def
_check_with_param_arrt
(
self
,
place
,
use_amp
):
output1
,
params1
=
self
.
_momentum_optimize_dygraph
(
place
=
place
,
use_amp
=
use_amp
,
use_param_attr
=
True
,
use_multi_tensor
=
True
)
output2
,
params2
=
self
.
_momentum_optimize_dygraph
(
place
=
place
,
use_amp
=
use_amp
,
use_param_attr
=
True
,
use_multi_tensor
=
False
)
self
.
assertEqual
(
np
.
allclose
(
output1
,
output2
,
rtol
=
1e-05
),
True
)
for
idx
in
range
(
len
(
params1
)):
self
.
assertEqual
(
np
.
allclose
(
params1
[
idx
],
params2
[
idx
],
rtol
=
1e-05
),
True
)
def
_check_with_param_group
(
self
,
place
,
use_amp
):
output1
,
params1
=
self
.
_momentum_optimize_dygraph
(
place
=
place
,
use_amp
=
use_amp
,
use_param_group
=
True
,
use_multi_tensor
=
True
)
output2
,
params2
=
self
.
_momentum_optimize_dygraph
(
place
=
place
,
use_amp
=
use_amp
,
use_param_group
=
True
,
use_multi_tensor
=
False
)
self
.
assertEqual
(
np
.
allclose
(
output1
,
output2
,
rtol
=
1e-05
),
True
)
for
idx
in
range
(
len
(
params1
)):
self
.
assertEqual
(
np
.
allclose
(
params1
[
idx
],
params2
[
idx
],
rtol
=
1e-05
),
True
)
def
test_main
(
self
):
for
place
in
self
.
_get_places
():
# use_amp_list = [True, False]
use_amp_list
=
[
False
]
for
use_amp
in
use_amp_list
:
self
.
_check_with_place_amp
(
place
,
use_amp
)
self
.
_check_with_param_arrt
(
place
,
use_amp
)
self
.
_check_with_param_group
(
place
,
use_amp
)
class
TestMultiTensorMomentumStatic
(
unittest
.
TestCase
):
def
_momentum_optimize_static
(
self
,
place
,
use_amp
=
False
,
use_multi_tensor
=
False
):
paddle
.
enable_static
()
paddle
.
seed
(
10
)
np
.
random
.
seed
(
10
)
if
place
==
'cpu'
:
use_amp
=
False
exe
=
paddle
.
static
.
Executor
(
place
=
paddle
.
device
.
MLUPlace
(
0
))
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
optimizer
=
paddle
.
optimizer
.
Momentum
(
multi_precision
=
use_amp
,
use_multi_tensor
=
use_multi_tensor
)
if
use_amp
:
optimizer
=
paddle
.
static
.
amp
.
decorate
(
optimizer
,
init_loss_scaling
=
128.0
,
use_dynamic_loss_scaling
=
True
,
use_pure_fp16
=
True
,
use_fp16_guard
=
False
)
with
paddle
.
static
.
program_guard
(
train_program
,
startup_program
):
if
use_amp
:
data
=
paddle
.
static
.
data
(
shape
=
[
2
,
2
],
name
=
'X'
,
dtype
=
'float16'
)
else
:
data
=
paddle
.
static
.
data
(
shape
=
[
2
,
2
],
name
=
'X'
,
dtype
=
'float32'
)
hidden
=
paddle
.
static
.
nn
.
fc
(
x
=
data
,
size
=
10
)
loss
=
paddle
.
fluid
.
layers
.
mean
(
hidden
)
optimizer
.
minimize
(
loss
)
exe
.
run
(
startup_program
)
if
use_amp
:
optimizer
.
amp_init
(
place
=
place
,
scope
=
paddle
.
static
.
global_scope
())
x
=
numpy
.
random
.
random
(
size
=
(
2
,
2
)).
astype
(
'float16'
)
else
:
x
=
numpy
.
random
.
random
(
size
=
(
2
,
2
)).
astype
(
'float32'
)
out
=
[]
for
idx
in
range
(
5
):
loss_data
,
=
exe
.
run
(
train_program
,
feed
=
{
"X"
:
x
},
fetch_list
=
[
loss
.
name
])
out
.
append
(
loss_data
)
return
out
def
_get_places
(
self
):
places
=
[
'mlu'
]
return
places
def
_check_with_place_amp
(
self
,
place
,
use_amp
):
output1
=
self
.
_momentum_optimize_static
(
place
=
place
,
use_amp
=
use_amp
,
use_multi_tensor
=
True
)
output2
=
self
.
_momentum_optimize_static
(
place
=
place
,
use_amp
=
use_amp
,
use_multi_tensor
=
False
)
for
idx
in
range
(
len
(
output1
)):
self
.
assertEqual
(
np
.
allclose
(
output1
[
idx
],
output2
[
idx
],
rtol
=
1e-05
),
True
)
def
test_main
(
self
):
for
place
in
self
.
_get_places
():
# use_amp_list = [True, False]
use_amp_list
=
[
False
]
for
use_amp
in
use_amp_list
:
self
.
_check_with_place_amp
(
place
,
use_amp
)
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录