Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
9e3e08f0
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9e3e08f0
编写于
8月 11, 2021
作者:
R
ronnywang
提交者:
GitHub
8月 11, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[NPU] add momentum_op_npu and test (#34082)
* add momentum_op_npu and test * update * fix hang
上级
f6fab559
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
432 addition
and
1 deletion
+432
-1
paddle/fluid/operators/optimizers/momentum_op_npu.cc
paddle/fluid/operators/optimizers/momentum_op_npu.cc
+96
-0
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+8
-1
python/paddle/fluid/tests/unittests/npu/test_momentum_op_npu.py
.../paddle/fluid/tests/unittests/npu/test_momentum_op_npu.py
+328
-0
未找到文件。
paddle/fluid/operators/optimizers/momentum_op_npu.cc
0 → 100644
浏览文件 @
9e3e08f0
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/optimizers/momentum_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
NPUMomentumOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
NPUDeviceContext
>();
std
::
string
regularization_method
=
ctx
.
Attr
<
std
::
string
>
(
"regularization_method"
);
auto
regularization_coeff
=
ctx
.
Attr
<
float
>
(
"regularization_coeff"
);
RegularizationType
regularization_flag
{
RegularizationType
::
kNONE
};
// disable regularization
if
(
regularization_method
==
"l2_decay"
)
{
regularization_flag
=
RegularizationType
::
kL2DECAY
;
}
T
mu
=
static_cast
<
T
>
(
ctx
.
Attr
<
float
>
(
"mu"
));
bool
use_nesterov
=
ctx
.
Attr
<
bool
>
(
"use_nesterov"
);
auto
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
velocity
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Velocity"
);
auto
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
velocity_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"VelocityOut"
);
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
velocity_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
auto
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
Tensor
mu_tensor
;
mu_tensor
.
mutable_data
<
T
>
(
framework
::
make_ddim
({
1
}),
ctx
.
GetPlace
());
FillNpuTensorWithConstant
<
T
>
(
&
mu_tensor
,
mu
);
Tensor
regularized_grad
;
if
(
regularization_flag
==
RegularizationType
::
kL2DECAY
)
{
regularized_grad
.
mutable_data
<
T
>
(
grad
->
dims
(),
ctx
.
GetPlace
());
const
auto
&
runner1
=
NpuOpRunner
(
"Muls"
,
{
*
param
},
{
regularized_grad
},
{{
"value"
,
regularization_coeff
}});
runner1
.
Run
(
dev_ctx
.
stream
());
const
auto
&
runner2
=
NpuOpRunner
(
"Add"
,
{
regularized_grad
,
*
grad
},
{
regularized_grad
},
{});
runner2
.
Run
(
dev_ctx
.
stream
());
}
else
{
regularized_grad
.
ShareDataWith
(
*
grad
);
}
framework
::
TensorCopy
(
*
param
,
ctx
.
GetPlace
(),
dev_ctx
,
param_out
);
framework
::
TensorCopy
(
*
velocity
,
ctx
.
GetPlace
(),
dev_ctx
,
velocity_out
);
// NOTE: ApplyMomentum will change the input
const
auto
&
runner
=
NpuOpRunner
(
"ApplyMomentum"
,
{
*
param_out
,
*
velocity_out
,
*
learning_rate
,
regularized_grad
,
mu_tensor
},
{
*
param_out
},
{{
"use_nesterov"
,
use_nesterov
}});
runner
.
Run
(
dev_ctx
.
stream
());
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
PADDLE_ENFORCE_EQ
(
false
,
true
,
platform
::
errors
::
PermissionDenied
(
"Unsupport SparseMomentum"
));
}
else
{
PADDLE_ENFORCE_EQ
(
false
,
true
,
platform
::
errors
::
PermissionDenied
(
"Unsupported Variable Type of Grad "
"in MomentumOp. Excepted LodTensor "
"or SelectedRows, But received [%s]"
,
paddle
::
framework
::
ToTypeName
(
grad_var
->
Type
())));
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_NPU_KERNEL
(
momentum
,
ops
::
NPUMomentumOpKernel
<
float
>
,
ops
::
NPUMomentumOpKernel
<
plat
::
float16
>
);
paddle/fluid/pybind/pybind.cc
浏览文件 @
9e3e08f0
...
@@ -2217,7 +2217,14 @@ All parameter, weight, gradient are variables in Paddle.
...
@@ -2217,7 +2217,14 @@ All parameter, weight, gradient are variables in Paddle.
#ifdef PADDLE_WITH_ASCEND_CL
#ifdef PADDLE_WITH_ASCEND_CL
m
.
def
(
"get_npu_device_count"
,
platform
::
GetNPUDeviceCount
);
m
.
def
(
"get_npu_device_count"
,
platform
::
GetNPUDeviceCount
);
m
.
def
(
"npu_finalize"
,
[]()
{
platform
::
AclInstance
::
Instance
().
Finalize
();
});
m
.
def
(
"npu_finalize"
,
[]()
{
auto
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
devices
=
platform
::
GetSelectedNPUDevices
();
for
(
size_t
i
=
0
;
i
<
devices
.
size
();
++
i
)
{
pool
.
Get
(
platform
::
NPUPlace
(
devices
[
i
]))
->
Wait
();
}
platform
::
AclInstance
::
Instance
().
Finalize
();
});
py
::
class_
<
platform
::
NPUProfConfigWrapper
>
(
m
,
"NPUProfConfigWrapper"
);
py
::
class_
<
platform
::
NPUProfConfigWrapper
>
(
m
,
"NPUProfConfigWrapper"
);
...
...
python/paddle/fluid/tests/unittests/npu/test_momentum_op_npu.py
0 → 100644
浏览文件 @
9e3e08f0
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
numpy
as
np
import
unittest
import
sys
sys
.
path
.
append
(
".."
)
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.core
as
core
from
paddle.fluid.op
import
Operator
from
test_momentum_op
import
calculate_momentum_by_numpy
paddle
.
enable_static
()
class
TestMomentumOp1
(
OpTest
):
def
set_npu
(
self
):
self
.
__class__
.
use_npu
=
True
def
setUp
(
self
):
self
.
set_npu
()
self
.
op_type
=
"momentum"
self
.
init_dtype
()
self
.
init_case
()
param
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
velocity
=
np
.
zeros
(
self
.
shape
).
astype
(
self
.
dtype
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
np
.
float32
)
mu
=
0.0001
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
,
'use_nesterov'
:
self
.
use_nesterov
}
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
self
.
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
init_case
(
self
):
self
.
shape
=
(
123
,
321
)
self
.
use_nesterov
=
False
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float32
def
test_check_output
(
self
):
self
.
check_output_with_place
(
core
.
NPUPlace
(
0
))
class
TestMomentumOpFp16
(
TestMomentumOp1
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
self
.
check_output
(
atol
=
1e-3
)
class
TestMomentumOp2
(
TestMomentumOp1
):
def
init_case
(
self
):
self
.
shape
=
(
123
,
321
)
self
.
use_nesterov
=
True
class
TestMomentumV2
(
unittest
.
TestCase
):
def
test_momentum_dygraph
(
self
):
paddle
.
disable_static
(
place
=
fluid
.
NPUPlace
(
0
))
value
=
np
.
arange
(
26
).
reshape
(
2
,
13
).
astype
(
"float32"
)
a
=
paddle
.
to_tensor
(
value
)
linear
=
paddle
.
nn
.
Linear
(
13
,
5
)
# This can be any optimizer supported by dygraph.
adam
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameters
=
linear
.
parameters
())
out
=
linear
(
a
)
out
.
backward
()
adam
.
step
()
adam
.
clear_gradients
()
def
test_momentum
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
NPUPlace
(
0
)
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
rms_optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.1
,
momentum
=
0.9
)
rms_optimizer
.
minimize
(
avg_cost
)
fetch_list
=
[
avg_cost
]
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
uci_housing
.
train
(),
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
main
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
def
test_raise_error
(
self
):
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Momentum
,
learning_rate
=
None
)
self
.
assertRaises
(
ValueError
,
paddle
.
optimizer
.
Momentum
,
momentum
=
None
)
class
TestMomentumOpWithDecay
(
OpTest
):
def
set_npu
(
self
):
self
.
__class__
.
use_npu
=
True
def
setUp
(
self
):
self
.
set_npu
()
self
.
op_type
=
"momentum"
self
.
dtype
=
np
.
float32
self
.
use_nesterov
=
True
self
.
regularization_method
=
'l2_decay'
self
.
regularization_coeff
=
0.9
self
.
init_config
()
param
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
grad
=
np
.
random
.
random
((
123
,
321
)).
astype
(
self
.
dtype
)
velocity
=
np
.
zeros
((
123
,
321
)).
astype
(
self
.
dtype
)
learning_rate
=
np
.
array
([
0.001
]).
astype
(
np
.
float32
)
mu
=
0.0001
use_nesterov
=
self
.
use_nesterov
regularization_method
=
self
.
regularization_method
regularization_coeff
=
self
.
regularization_coeff
self
.
inputs
=
{
'Param'
:
param
,
'Grad'
:
grad
,
'Velocity'
:
velocity
,
'LearningRate'
:
learning_rate
}
self
.
attrs
=
{
'mu'
:
mu
,
'use_nesterov'
:
use_nesterov
,
'regularization_method'
:
regularization_method
,
'regularization_coeff'
:
regularization_coeff
}
grad
=
grad
+
regularization_coeff
*
param
param_out
,
velocity_out
=
calculate_momentum_by_numpy
(
param
=
param
,
grad
=
grad
,
mu
=
mu
,
velocity
=
velocity
,
use_nesterov
=
use_nesterov
,
learning_rate
=
learning_rate
)
self
.
outputs
=
{
'ParamOut'
:
param_out
,
'VelocityOut'
:
velocity_out
}
def
init_config
(
self
):
pass
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output_with_place
(
core
.
NPUPlace
(
0
),
atol
=
3e-3
)
class
TestMomentumOpWithDecayFP16
(
TestMomentumOpWithDecay
):
def
init_config
(
self
):
self
.
dtype
=
np
.
float16
def
test_check_output
(
self
):
paddle
.
enable_static
()
self
.
check_output
(
atol
=
1e-3
)
class
TestMomentumOpWithDecay2
(
TestMomentumOpWithDecay
):
def
init_config
(
self
):
self
.
use_nesterov
=
False
class
TestMomentumOpWithDecayAPI
(
unittest
.
TestCase
):
def
_test_momentum_dygraph_common
(
self
,
regularization
):
paddle
.
disable_static
(
fluid
.
NPUPlace
(
0
))
inp
=
np
.
random
.
uniform
(
-
0.1
,
0.1
,
[
10
,
10
]).
astype
(
"float32"
)
linear
=
paddle
.
nn
.
Linear
(
10
,
10
)
inp
=
paddle
.
to_tensor
(
inp
)
out
=
linear
(
inp
)
loss
=
paddle
.
mean
(
out
)
# This can be any optimizer supported by dygraph.
momentum
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear
.
parameters
(),
regularization
=
regularization
)
momentum
.
minimize
(
loss
)
def
test_momentum_dygraph_1
(
self
):
self
.
_test_momentum_dygraph_common
(
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
def
test_momentum_static
(
self
):
paddle
.
enable_static
()
place
=
fluid
.
NPUPlace
(
0
)
main
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
momentum_optimizer
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.1
,
momentum
=
0.9
)
momentum_optimizer
.
minimize
(
avg_cost
)
fetch_list
=
[
avg_cost
]
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
uci_housing
.
train
(),
batch_size
=
1
)
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
fluid
.
default_startup_program
())
for
data
in
train_reader
():
exe
.
run
(
main
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
fetch_list
)
class
TestMomentumOpVsMomentumOpWithDecayAPI
(
unittest
.
TestCase
):
def
__update_params
(
self
,
momentum
,
linear
):
for
i
in
range
(
10
):
inp
=
paddle
.
full
(
shape
=
[
2
,
2
],
fill_value
=
i
,
dtype
=
'float32'
).
astype
(
"float32"
)
inp
=
paddle
.
to_tensor
(
inp
)
out
=
linear
(
inp
)
loss
=
paddle
.
mean
(
out
)
loss
.
backward
()
momentum
.
minimize
(
loss
)
linear
.
clear_gradients
()
def
__test_vs
(
self
,
place
=
fluid
.
NPUPlace
(
0
)):
paddle
.
disable_static
(
place
=
place
)
linear_old
=
paddle
.
nn
.
Linear
(
2
,
2
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
))
momentum_old
=
paddle
.
fluid
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear_old
.
parameters
(),
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
self
.
__update_params
(
momentum
=
momentum_old
,
linear
=
linear_old
)
linear_new
=
paddle
.
nn
.
Linear
(
2
,
2
,
weight_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
),
bias_attr
=
paddle
.
nn
.
initializer
.
Constant
(
value
=
2.0
))
momentum_new
=
paddle
.
fluid
.
contrib
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
momentum
=
0.9
,
parameter_list
=
linear_new
.
parameters
(),
regularization
=
paddle
.
fluid
.
regularizer
.
L2Decay
(
regularization_coeff
=
0.1
))
self
.
__update_params
(
momentum
=
momentum_new
,
linear
=
linear_new
)
self
.
assertEqual
(
(
linear_old
.
weight
.
numpy
()
==
linear_new
.
weight
.
numpy
()).
all
(),
True
,
'the param weight updated by two Momentum optimizers should equal'
)
def
test_vs
(
self
,
place
=
fluid
.
NPUPlace
(
0
)):
self
.
__test_vs
(
place
=
place
)
class
TestMomentumV2Group
(
TestMomentumV2
):
def
test_momentum_dygraph
(
self
):
paddle
.
disable_static
(
place
=
fluid
.
NPUPlace
(
0
))
value
=
np
.
arange
(
26
).
reshape
(
2
,
13
).
astype
(
"float32"
)
a
=
paddle
.
to_tensor
(
value
)
linear_1
=
paddle
.
nn
.
Linear
(
13
,
5
)
linear_2
=
paddle
.
nn
.
Linear
(
5
,
3
)
# This can be any optimizer supported by dygraph.
adam
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.01
,
parameters
=
[{
'params'
:
linear_1
.
parameters
()
},
{
'params'
:
linear_2
.
parameters
(),
'weight_decay'
:
0.001
,
'learning_rate'
:
0.1
,
'momentum'
:
0.99
}],
weight_decay
=
0.1
,
momentum
=
0.9
)
out
=
linear_1
(
a
)
out
=
linear_2
(
out
)
out
.
backward
()
adam
.
step
()
adam
.
clear_gradients
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录