Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
a4d07bb9
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2305
Star
20932
Fork
5423
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a4d07bb9
编写于
12月 24, 2021
作者:
Z
zhangbo9674
提交者:
GitHub
12月 24, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[AMP] Add multi_precision for sgd (#38231)
上级
08941eda
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
408 addition
and
36 deletion
+408
-36
paddle/fluid/operators/optimizers/sgd_op.cc
paddle/fluid/operators/optimizers/sgd_op.cc
+11
-0
paddle/fluid/operators/optimizers/sgd_op.cu
paddle/fluid/operators/optimizers/sgd_op.cu
+42
-18
paddle/fluid/pybind/op_function_generator.h
paddle/fluid/pybind/op_function_generator.h
+3
-1
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+70
-9
python/paddle/fluid/tests/unittests/test_sgd_op.py
python/paddle/fluid/tests/unittests/test_sgd_op.py
+208
-0
python/paddle/optimizer/sgd.py
python/paddle/optimizer/sgd.py
+74
-8
未找到文件。
paddle/fluid/operators/optimizers/sgd_op.cc
浏览文件 @
a4d07bb9
...
...
@@ -126,13 +126,24 @@ class SGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Param"
,
"(Tensor or SelectedRows) Input parameter"
);
AddInput
(
"LearningRate"
,
"(Tensor) Learning rate of SGD"
);
AddInput
(
"Grad"
,
"(Tensor or SelectedRows) Input gradient"
);
AddInput
(
"MasterParam"
,
"FP32 master weight for AMP."
).
AsDispensable
();
AddOutput
(
"ParamOut"
,
"(Tensor or SelectedRows, same with Param) "
"Output parameter, should share the same memory with Param"
);
AddOutput
(
"MasterParamOut"
,
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam)."
)
.
AsDispensable
();
AddAttr
<
bool
>
(
"use_mkldnn"
,
"(bool, default false) Indicates if MKL-DNN kernel will be used"
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"multi_precision"
,
"(bool, default false) "
"Whether to use multi-precision during weight updating."
)
.
SetDefault
(
false
);
AddComment
(
R"DOC(
SGD operator
...
...
paddle/fluid/operators/optimizers/sgd_op.cu
浏览文件 @
a4d07bb9
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
...
...
@@ -21,14 +22,19 @@ namespace operators {
namespace
{
template
<
typename
T
>
__global__
void
SGDKernel
(
const
T
*
g
,
const
T
*
p
,
const
T
*
learning_rate
,
const
int
num
,
T
*
p_out
)
{
T
lr
=
learning_rate
[
0
];
template
<
typename
T
,
typename
MT
>
__global__
void
SGDKernelMT
(
const
T
*
param
,
const
T
*
grad
,
const
T
*
learning_rate
,
const
int
num
,
T
*
param_out
,
const
MT
*
master_param
,
MT
*
master_param_out
)
{
MT
lr
=
static_cast
<
MT
>
(
learning_rate
[
0
]);
CUDA_KERNEL_LOOP
(
i
,
num
)
{
T
g_data
=
g
[
i
];
T
p_data
=
p
[
i
];
p_out
[
i
]
=
p_data
-
lr
*
g_data
;
MT
p_data
=
master_param
?
master_param
[
i
]
:
static_cast
<
MT
>
(
param
[
i
]);
MT
g_data
=
static_cast
<
MT
>
(
grad
[
i
]);
p_data
=
p_data
-
lr
*
g_data
;
param_out
[
i
]
=
static_cast
<
T
>
(
p_data
);
if
(
master_param_out
)
{
master_param_out
[
i
]
=
p_data
;
}
}
}
...
...
@@ -63,30 +69,48 @@ class SGDOpKernel<platform::CUDADeviceContext, T>
"but the received is %s"
,
ctx
.
InputNames
(
"Param"
).
front
(),
paddle
::
framework
::
ToTypeName
(
param_var
->
Type
())));
using
paddle
::
framework
::
Tensor
;
using
MPDType
=
typename
details
::
MPTypeTrait
<
T
>::
Type
;
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
const
bool
multi_precision
=
ctx
.
Attr
<
bool
>
(
"multi_precision"
);
const
Tensor
*
master_param
=
nullptr
;
Tensor
*
master_param_out
=
nullptr
;
if
(
multi_precision
)
{
bool
has_master
=
ctx
.
HasInput
(
"MasterParam"
)
&&
ctx
.
HasOutput
(
"MasterParamOut"
);
PADDLE_ENFORCE_EQ
(
has_master
,
true
,
platform
::
errors
::
InvalidArgument
(
"The Input(MasterParam) and Output(MasterParamOut) "
"should not be null when "
"the attr `multi_precision` is true"
));
master_param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"MasterParam"
);
master_param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"MasterParamOut"
);
}
const
MPDType
*
master_in_data
=
multi_precision
?
master_param
->
data
<
MPDType
>
()
:
nullptr
;
MPDType
*
master_out_data
=
multi_precision
?
master_param_out
->
mutable_data
<
MPDType
>
(
ctx
.
GetPlace
())
:
nullptr
;
// Actually, all tensors are LoDTensor except SelectedRows.
if
(
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
*
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
// LOG(ERROR) << "grad";
// LOG(ERROR) << ctx.op().Input("Grad");
auto
*
grad_data
=
grad
->
data
<
T
>
();
// LOG(ERROR) << "param";
auto
*
param_data
=
param
->
data
<
T
>
();
// LOG(ERROR) << "fin";
auto
*
param_out_data
=
param_out
->
data
<
T
>
();
int
block
=
512
;
int
grid
=
(
param
->
numel
()
+
block
-
1
)
/
block
;
SGDKernel
<
T
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
grad_data
,
param_data
,
learning_rate
->
data
<
T
>
(),
param
->
numel
(),
param_out_data
);
SGDKernelMT
<
T
,
MPDType
><<<
grid
,
block
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
param
->
data
<
T
>
(),
grad
->
data
<
T
>
(),
learning_rate
->
data
<
T
>
(),
param
->
numel
(),
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
master_in_data
,
master_out_data
);
}
else
if
(
grad_var
->
IsType
<
framework
::
SelectedRows
>
())
{
// TODO(qijun): In Sparse SGD operator, in-place update is enforced.
...
...
paddle/fluid/pybind/op_function_generator.h
浏览文件 @
a4d07bb9
...
...
@@ -79,6 +79,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
"Beta2Pow"
,
"MasterParam"
}},
{
"sparse_attention"
,
{
"Q"
,
"K"
,
"V"
,
"Offset"
,
"Columns"
,
"KeyPaddingMask"
,
"AttnMask"
}},
{
"sgd"
,
{
"Param"
,
"LearningRate"
,
"Grad"
,
"MasterParam"
}},
};
// NOTE(zhiqiu): Like op_ins_map.
...
...
@@ -125,6 +126,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
{
"adamw"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
{
"sgd"
,
{
"ParamOut"
,
"MasterParamOut"
}},
{
"lamb"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
...
...
@@ -142,7 +144,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
// especially in declarative mode.
// For those OPs, we need to manually specify the outs need to pass in this map.
std
::
map
<
std
::
string
,
std
::
set
<
std
::
string
>>
op_passing_outs_map
=
{
{
"sgd"
,
{
"ParamOut"
}},
{
"sgd"
,
{
"ParamOut"
,
"MasterParamOut"
}},
{
"adam"
,
{
"ParamOut"
,
"Moment1Out"
,
"Moment2Out"
,
"Beta1PowOut"
,
"Beta2PowOut"
,
"MasterParamOut"
}},
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
a4d07bb9
...
...
@@ -1296,6 +1296,7 @@ class SGDOptimizer(Optimizer):
parameter_list
=
None
,
regularization
=
None
,
grad_clip
=
None
,
multi_precision
=
False
,
name
=
None
):
assert
learning_rate
is
not
None
super
(
SGDOptimizer
,
self
).
__init__
(
...
...
@@ -1306,26 +1307,86 @@ class SGDOptimizer(Optimizer):
name
=
name
)
self
.
type
=
"sgd"
self
.
_use_mkldnn
=
False
self
.
_multi_precision
=
multi_precision
self
.
_master_weights
=
{}
def
_create_master_weight
(
self
,
param
):
if
param
.
name
in
self
.
_master_weights
:
var
=
self
.
_master_weights
[
param
.
name
]
else
:
assert
isinstance
(
self
.
helper
,
LayerHelper
)
var_name
=
param
.
name
+
"_fp32_master"
var_name
=
unique_name
.
generate
(
var_name
)
var
=
layers
.
create_global_var
(
name
=
var_name
,
shape
=
param
.
shape
,
value
=
0
,
dtype
=
'float32'
,
persistable
=
True
)
block
=
self
.
helper
.
startup_program
.
global_block
()
block
.
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
[
param
]},
outputs
=
{
"Out"
:
[
var
]},
attrs
=
{
"in_dtype"
:
param
.
dtype
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
})
self
.
_master_weights
[
param
.
name
]
=
var
return
var
def
_create_accumulators
(
self
,
block
,
parameters
):
assert
isinstance
(
block
,
framework
.
Block
)
if
isinstance
(
parameters
,
dict
):
parameters
=
self
.
_update_param_group
(
parameters
)
# Create accumulator tensors for first and second moments
for
p
in
parameters
:
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
master_p
=
self
.
_create_master_weight
(
p
)
continue
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
and
not
self
.
_multi_precision
:
warnings
.
warn
(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Adam optimizer."
)
@
no_grad
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
find_master
=
self
.
_multi_precision
and
param_and_grad
[
0
].
dtype
==
core
.
VarDesc
.
VarType
.
FP16
master_weight
=
(
self
.
_master_weights
[
param_and_grad
[
0
].
name
]
if
find_master
else
None
)
lr
=
self
.
_create_param_lr
(
param_and_grad
)
if
framework
.
in_dygraph_mode
():
_C_ops
.
sgd
(
param_and_grad
[
0
],
lr
,
param_and_grad
[
1
],
param_and_grad
[
0
])
_C_ops
.
sgd
(
param_and_grad
[
0
],
lr
,
param_and_grad
[
1
],
master_weight
,
param_and_grad
[
0
]
,
master_weight
)
return
None
assert
isinstance
(
block
,
framework
.
Block
)
# create the optimize op
inputs
=
{
"Param"
:
param_and_grad
[
0
],
"Grad"
:
param_and_grad
[
1
],
"LearningRate"
:
lr
}
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]}
attrs
=
{
"multi_precision"
:
find_master
}
if
find_master
:
inputs
[
"MasterParam"
]
=
master_weight
outputs
[
"MasterParamOut"
]
=
master_weight
sgd_op
=
block
.
append_op
(
type
=
self
.
type
,
inputs
=
{
"Param"
:
param_and_grad
[
0
],
"Grad"
:
param_and_grad
[
1
],
"LearningRate"
:
lr
},
attrs
=
{
"use_mkldnn"
:
self
.
_use_mkldnn
},
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]},
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
,
stop_gradient
=
True
)
return
sgd_op
...
...
python/paddle/fluid/tests/unittests/test_sgd_op.py
浏览文件 @
a4d07bb9
...
...
@@ -192,6 +192,7 @@ class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
class
TestSGDOpWithLargeInput
(
unittest
.
TestCase
):
def
runTest
(
self
):
paddle
.
enable_static
()
data
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
value
=
128
,
dtype
=
'int64'
)
label
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
,
150
],
value
=
0.5
,
dtype
=
'float32'
)
...
...
@@ -291,5 +292,212 @@ class TestSGDV2(unittest.TestCase):
adam
.
clear_gradients
()
class
TestSGDMultiPrecision2_0
(
unittest
.
TestCase
):
def
dygraph_sgd_mp
(
self
,
mp
):
paddle
.
disable_static
()
paddle
.
seed
(
10
)
paddle
.
set_device
(
'gpu'
)
input
=
paddle
.
randn
((
2
,
2
))
model
=
paddle
.
nn
.
Linear
(
2
,
2
)
optimizer
=
paddle
.
optimizer
.
SGD
(
parameters
=
model
.
parameters
(),
multi_precision
=
mp
)
if
mp
==
True
:
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
1024
)
for
idx
in
range
(
5
):
if
mp
==
True
:
with
paddle
.
amp
.
auto_cast
(
level
=
'O2'
):
output
=
model
(
input
)
loss
=
paddle
.
mean
(
output
)
scaled
=
scaler
.
scale
(
loss
)
scaled
.
backward
()
scaler
.
minimize
(
optimizer
,
scaled
)
optimizer
.
clear_grad
()
else
:
output
=
model
(
input
)
loss
=
paddle
.
mean
(
output
)
optimizer
.
step
()
optimizer
.
clear_grad
()
return
output
,
model
.
parameters
()
def
static_sgd_mp
(
self
,
mp
):
paddle
.
enable_static
()
paddle
.
seed
(
10
)
np
.
random
.
seed
(
10
)
exe
=
paddle
.
static
.
Executor
(
'gpu'
)
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
optimizer
=
paddle
.
optimizer
.
SGD
(
multi_precision
=
mp
)
if
mp
:
optimizer
=
paddle
.
static
.
amp
.
decorate
(
optimizer
,
init_loss_scaling
=
128.0
,
use_dynamic_loss_scaling
=
True
,
use_pure_fp16
=
True
,
use_fp16_guard
=
False
)
with
paddle
.
static
.
program_guard
(
train_program
,
startup_program
):
if
mp
:
data
=
paddle
.
static
.
data
(
shape
=
[
2
,
2
],
name
=
'X'
,
dtype
=
'float16'
)
else
:
data
=
paddle
.
static
.
data
(
shape
=
[
2
,
2
],
name
=
'X'
,
dtype
=
'float32'
)
hidden
=
paddle
.
static
.
nn
.
fc
(
x
=
data
,
size
=
10
)
loss
=
paddle
.
fluid
.
layers
.
mean
(
hidden
)
optimizer
.
minimize
(
loss
)
exe
.
run
(
startup_program
)
if
mp
:
optimizer
.
amp_init
(
place
=
'gpu'
,
scope
=
paddle
.
static
.
global_scope
())
x
=
np
.
random
.
random
(
size
=
(
2
,
2
)).
astype
(
'float16'
)
else
:
x
=
np
.
random
.
random
(
size
=
(
2
,
2
)).
astype
(
'float32'
)
out
=
[]
for
idx
in
range
(
5
):
loss_data
,
=
exe
.
run
(
train_program
,
feed
=
{
"X"
:
x
},
fetch_list
=
[
loss
.
name
])
out
.
append
(
loss_data
)
return
out
def
test_main
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
"Test dygraph mode"
output1_dy
,
params1_dy
=
self
.
dygraph_sgd_mp
(
mp
=
True
)
output2_dy
,
params2_dy
=
self
.
dygraph_sgd_mp
(
mp
=
False
)
self
.
assertEqual
(
np
.
allclose
(
output1_dy
.
astype
(
'float32'
).
numpy
(),
output2_dy
.
astype
(
'float32'
).
numpy
(),
atol
=
1e-01
),
True
)
for
idx
in
range
(
len
(
params1_dy
)):
self
.
assertEqual
(
np
.
allclose
(
params1_dy
[
idx
].
astype
(
'float32'
).
numpy
(),
params2_dy
[
idx
].
astype
(
'float32'
).
numpy
(),
atol
=
1e-01
),
True
)
"Test static mode"
output1_st
=
self
.
static_sgd_mp
(
mp
=
True
)
output2_st
=
self
.
static_sgd_mp
(
mp
=
False
)
for
idx
in
range
(
len
(
output1_st
)):
self
.
assertEqual
(
np
.
allclose
(
output1_st
[
idx
].
astype
(
'float32'
),
output2_st
[
idx
].
astype
(
'float32'
),
atol
=
1e-01
),
True
)
class
TestSGDMultiPrecision1_0
(
unittest
.
TestCase
):
def
dygraph_sgd_mp
(
self
,
mp
):
paddle
.
disable_static
()
paddle
.
seed
(
10
)
paddle
.
set_device
(
'gpu'
)
input
=
paddle
.
randn
((
2
,
2
))
model
=
paddle
.
nn
.
Linear
(
2
,
2
)
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
parameter_list
=
model
.
parameters
(),
multi_precision
=
mp
)
if
mp
==
True
:
model
=
paddle
.
amp
.
decorate
(
models
=
model
,
level
=
'O2'
)
scaler
=
paddle
.
amp
.
GradScaler
(
init_loss_scaling
=
1024
)
for
idx
in
range
(
5
):
if
mp
==
True
:
with
paddle
.
amp
.
auto_cast
(
level
=
'O2'
):
output
=
model
(
input
)
loss
=
paddle
.
mean
(
output
)
scaled
=
scaler
.
scale
(
loss
)
scaled
.
backward
()
scaler
.
minimize
(
optimizer
,
scaled
)
optimizer
.
clear_gradients
()
else
:
output
=
model
(
input
)
loss
=
paddle
.
mean
(
output
)
optimizer
.
minimize
(
loss
)
optimizer
.
clear_gradients
()
return
output
,
model
.
parameters
()
def
static_sgd_mp
(
self
,
mp
):
paddle
.
enable_static
()
paddle
.
seed
(
10
)
np
.
random
.
seed
(
10
)
exe
=
paddle
.
static
.
Executor
(
'gpu'
)
train_program
=
paddle
.
static
.
Program
()
startup_program
=
paddle
.
static
.
Program
()
optimizer
=
paddle
.
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
,
multi_precision
=
mp
)
if
mp
:
optimizer
=
paddle
.
static
.
amp
.
decorate
(
optimizer
,
init_loss_scaling
=
128.0
,
use_dynamic_loss_scaling
=
True
,
use_pure_fp16
=
True
,
use_fp16_guard
=
False
)
with
paddle
.
static
.
program_guard
(
train_program
,
startup_program
):
if
mp
:
data
=
paddle
.
static
.
data
(
shape
=
[
2
,
2
],
name
=
'X'
,
dtype
=
'float16'
)
else
:
data
=
paddle
.
static
.
data
(
shape
=
[
2
,
2
],
name
=
'X'
,
dtype
=
'float32'
)
hidden
=
paddle
.
static
.
nn
.
fc
(
x
=
data
,
size
=
10
)
loss
=
paddle
.
fluid
.
layers
.
mean
(
hidden
)
optimizer
.
minimize
(
loss
)
exe
.
run
(
startup_program
)
if
mp
:
optimizer
.
amp_init
(
place
=
'gpu'
,
scope
=
paddle
.
static
.
global_scope
())
x
=
np
.
random
.
random
(
size
=
(
2
,
2
)).
astype
(
'float16'
)
else
:
x
=
np
.
random
.
random
(
size
=
(
2
,
2
)).
astype
(
'float32'
)
out
=
[]
for
idx
in
range
(
5
):
loss_data
,
=
exe
.
run
(
train_program
,
feed
=
{
"X"
:
x
},
fetch_list
=
[
loss
.
name
])
out
.
append
(
loss_data
)
return
out
def
test_main
(
self
):
if
not
paddle
.
is_compiled_with_cuda
():
return
"Test dygraph mode"
output1_dy
,
params1_dy
=
self
.
dygraph_sgd_mp
(
mp
=
True
)
output2_dy
,
params2_dy
=
self
.
dygraph_sgd_mp
(
mp
=
False
)
self
.
assertEqual
(
np
.
allclose
(
output1_dy
.
astype
(
'float32'
).
numpy
(),
output2_dy
.
astype
(
'float32'
).
numpy
(),
atol
=
1e-01
),
True
)
for
idx
in
range
(
len
(
params1_dy
)):
self
.
assertEqual
(
np
.
allclose
(
params1_dy
[
idx
].
astype
(
'float32'
).
numpy
(),
params2_dy
[
idx
].
astype
(
'float32'
).
numpy
(),
atol
=
1e-01
),
True
)
"Test static mode"
output1_st
=
self
.
static_sgd_mp
(
mp
=
True
)
output2_st
=
self
.
static_sgd_mp
(
mp
=
False
)
for
idx
in
range
(
len
(
output1_st
)):
self
.
assertEqual
(
np
.
allclose
(
output1_st
[
idx
].
astype
(
'float32'
),
output2_st
[
idx
].
astype
(
'float32'
),
atol
=
1e-01
),
True
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/optimizer/sgd.py
浏览文件 @
a4d07bb9
...
...
@@ -18,6 +18,10 @@ from ..fluid import framework
from
..fluid.framework
import
Variable
,
name_scope
from
..fluid.dygraph
import
no_grad
from
paddle
import
_C_ops
import
warnings
from
..fluid.layer_helper
import
LayerHelper
from
..fluid
import
unique_name
from
..fluid
import
layers
__all__
=
[]
...
...
@@ -75,6 +79,7 @@ class SGD(Optimizer):
parameters
=
None
,
weight_decay
=
None
,
grad_clip
=
None
,
multi_precision
=
False
,
name
=
None
):
if
learning_rate
is
None
:
raise
ValueError
(
"learning_rate is not set"
)
...
...
@@ -85,27 +90,88 @@ class SGD(Optimizer):
grad_clip
=
grad_clip
,
name
=
name
)
self
.
type
=
"sgd"
self
.
_multi_precision
=
multi_precision
self
.
_master_weights
=
{}
def
_create_master_weight
(
self
,
param
):
if
param
.
name
in
self
.
_master_weights
:
var
=
self
.
_master_weights
[
param
.
name
]
else
:
assert
isinstance
(
self
.
helper
,
LayerHelper
)
var_name
=
param
.
name
+
"_fp32_master"
var_name
=
unique_name
.
generate
(
var_name
)
var
=
layers
.
create_global_var
(
name
=
var_name
,
shape
=
param
.
shape
,
value
=
0
,
dtype
=
'float32'
,
persistable
=
True
)
block
=
self
.
helper
.
startup_program
.
global_block
()
block
.
append_op
(
type
=
"cast"
,
inputs
=
{
"X"
:
[
param
]},
outputs
=
{
"Out"
:
[
var
]},
attrs
=
{
"in_dtype"
:
param
.
dtype
,
"out_dtype"
:
core
.
VarDesc
.
VarType
.
FP32
})
self
.
_master_weights
[
param
.
name
]
=
var
return
var
def
_create_accumulators
(
self
,
block
,
parameters
):
assert
isinstance
(
block
,
framework
.
Block
)
if
isinstance
(
parameters
,
dict
):
parameters
=
self
.
_update_param_group
(
parameters
)
# Create accumulator tensors for first and second moments
for
p
in
parameters
:
if
self
.
_multi_precision
and
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
master_p
=
self
.
_create_master_weight
(
p
)
continue
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
and
not
self
.
_multi_precision
:
warnings
.
warn
(
"Accumulating with FP16 in optimizer can lead to poor accuracy or slow convergence."
"Consider using multi_precision=True option of the Adam optimizer."
)
@
no_grad
def
_append_optimize_op
(
self
,
block
,
param_and_grad
):
if
isinstance
(
param_and_grad
,
dict
):
param_and_grad
=
self
.
_update_param_group
(
param_and_grad
)
find_master
=
self
.
_multi_precision
and
param_and_grad
[
0
].
dtype
==
core
.
VarDesc
.
VarType
.
FP16
master_weight
=
(
self
.
_master_weights
[
param_and_grad
[
0
].
name
]
if
find_master
else
None
)
lr
=
self
.
_create_param_lr
(
param_and_grad
)
if
framework
.
in_dygraph_mode
():
_C_ops
.
sgd
(
param_and_grad
[
0
],
lr
,
param_and_grad
[
1
],
param_and_grad
[
0
])
_C_ops
.
sgd
(
param_and_grad
[
0
],
lr
,
param_and_grad
[
1
],
master_weight
,
param_and_grad
[
0
]
,
master_weight
)
return
None
assert
isinstance
(
block
,
framework
.
Block
)
# create the optimize op
inputs
=
{
"Param"
:
param_and_grad
[
0
],
"Grad"
:
param_and_grad
[
1
],
"LearningRate"
:
lr
}
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]}
attrs
=
{
"multi_precision"
:
find_master
}
if
find_master
:
inputs
[
"MasterParam"
]
=
master_weight
outputs
[
"MasterParamOut"
]
=
master_weight
sgd_op
=
block
.
append_op
(
type
=
self
.
type
,
inputs
=
{
"Param"
:
param_and_grad
[
0
],
"Grad"
:
param_and_grad
[
1
],
"LearningRate"
:
lr
},
outputs
=
{
"ParamOut"
:
param_and_grad
[
0
]},
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
,
stop_gradient
=
True
)
return
sgd_op
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录