Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
02175555
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
02175555
编写于
5月 17, 2019
作者:
Y
Yan Xu
提交者:
chengduo
5月 17, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
polish parallel dygraph code (#17164)
* add var grad hook test=develop
上级
d7df4e5e
变更
10
隐藏空白更改
内联
并排
Showing
10 changed file
with
411 addition
and
62 deletion
+411
-62
paddle/fluid/imperative/layer.cc
paddle/fluid/imperative/layer.cc
+5
-9
paddle/fluid/imperative/layer.h
paddle/fluid/imperative/layer.h
+1
-1
paddle/fluid/operators/distributed_ops/allreduce_op.h
paddle/fluid/operators/distributed_ops/allreduce_op.h
+1
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+3
-5
python/paddle/fluid/dygraph/parallel.py
python/paddle/fluid/dygraph/parallel.py
+27
-23
python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
...on/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
+8
-7
python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py
...ddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py
+314
-0
python/paddle/fluid/tests/unittests/test_dist_base.py
python/paddle/fluid/tests/unittests/test_dist_base.py
+14
-13
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
...ddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
+3
-2
python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py
...fluid/tests/unittests/test_parallel_dygraph_se_resnext.py
+35
-0
未找到文件。
paddle/fluid/imperative/layer.cc
浏览文件 @
02175555
...
...
@@ -150,9 +150,9 @@ class Autograd {
const
std
::
vector
<
VarBase
*>&
ingrads
=
it
->
second
;
for
(
size_t
i
=
0
;
i
<
ingrads
.
size
();
++
i
)
{
if
(
!
ingrads
[
i
])
continue
;
if
(
ready_op
->
input_vars_
[
it
->
first
][
i
]
->
IsStopGradient
())
{
continue
;
}
auto
p
=
ready_op
->
input_vars_
[
it
->
first
][
i
];
if
(
p
->
IsStopGradient
())
continue
;
OpBase
*
pre_op
=
ready_op
->
pre_ops_
[
it
->
first
][
i
];
if
(
!
pre_op
)
continue
;
...
...
@@ -415,15 +415,11 @@ void OpBase::InvokeBackwardHooks() {
}
}
void
OpBase
::
RegisterBackwardHooks
(
const
py
::
object
&
callable
,
bool
front
)
{
void
OpBase
::
RegisterBackwardHooks
(
const
py
::
object
&
callable
)
{
VLOG
(
3
)
<<
"Register backward hooks "
<<
trace_id_
;
// TODO(minqiyang): check the callable format
if
(
front
)
{
backward_hooks_
.
insert
(
backward_hooks_
.
begin
(),
callable
);
}
else
{
backward_hooks_
.
push_back
(
callable
);
}
backward_hooks_
.
push_back
(
callable
);
}
void
VarBase
::
RunBackward
(
const
detail
::
BackwardStrategy
&
bck_stratedy
)
{
...
...
paddle/fluid/imperative/layer.h
浏览文件 @
02175555
...
...
@@ -310,7 +310,7 @@ class PYBIND11_HIDDEN OpBase {
return
grad_op_descs_
[
index
]
->
Type
();
}
void
RegisterBackwardHooks
(
const
py
::
object
&
callable
,
bool
front
=
false
);
void
RegisterBackwardHooks
(
const
py
::
object
&
callable
);
void
InvokeBackwardHooks
();
...
...
paddle/fluid/operators/distributed_ops/allreduce_op.h
浏览文件 @
02175555
...
...
@@ -39,6 +39,7 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
CUDADeviceContext
>();
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
int
dtype
=
platform
::
ToNCCLDataType
(
in
->
type
());
int64_t
numel
=
in
->
numel
();
auto
*
sendbuff
=
in
->
data
<
void
>
();
...
...
@@ -66,12 +67,10 @@ class AllReduceOpKernel : public framework::OpKernel<T> {
red_type
=
ncclMin
;
break
;
}
VLOG
(
0
)
<<
"call allreduce with type: "
<<
reduce_type
;
PADDLE_ENFORCE
(
platform
::
dynload
::
ncclAllReduce
(
sendbuff
,
recvbuff
,
numel
,
static_cast
<
ncclDataType_t
>
(
dtype
),
red_type
,
comm
,
stream
));
if
(
ctx
.
Attr
<
bool
>
(
"sync_mode"
))
{
VLOG
(
0
)
<<
"sync allreduce..."
;
cudaError_t
e_sync
=
cudaStreamSynchronize
(
stream
);
if
(
e_sync
!=
0
)
{
LOG
(
FATAL
)
<<
"cudaStreamSynchronize "
<<
cudaGetErrorString
(
e_sync
);
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
02175555
...
...
@@ -252,11 +252,9 @@ PYBIND11_MODULE(core, m) {
py
::
class_
<
imperative
::
OpBase
,
PyOpBase
>
(
m
,
"OpBase"
,
R"DOC()DOC"
)
.
def
(
py
::
init
<
const
std
::
string
&>
())
.
def
(
"register_backward_hooks"
,
[](
imperative
::
OpBase
&
self
,
const
py
::
object
&
callable
,
bool
front
=
false
)
{
self
.
RegisterBackwardHooks
(
callable
,
front
);
},
py
::
arg
(
"callable"
),
py
::
arg
(
"front"
)
=
false
)
[](
imperative
::
OpBase
&
self
,
const
py
::
object
&
callable
)
{
self
.
RegisterBackwardHooks
(
callable
);
})
.
def_property
(
"_trace_id"
,
[](
const
imperative
::
OpBase
&
self
)
{
pybind11
::
gil_scoped_release
release
;
...
...
python/paddle/fluid/dygraph/parallel.py
浏览文件 @
02175555
...
...
@@ -13,12 +13,14 @@
# limitations under the License.
import
os
import
six
import
numpy
as
np
from
..
import
core
from
.
import
layers
from
..
import
framework
from
..layers
import
collective
from
.
import
to_variable
__all__
=
[
"prepare_context"
]
...
...
@@ -75,31 +77,33 @@ class Env(object):
class
DataParallel
(
layers
.
Layer
):
def
__init__
(
self
,
layers
):
def
__init__
(
self
,
layers
,
strategy
):
super
(
DataParallel
,
self
).
__init__
(
layers
.
full_name
()
+
"_data_parallel"
)
self
.
_layers
=
layers
def
build_once
(
self
,
*
inputs
,
**
kwargs
):
#TODO(Yancey1989): broadcast all the paramters
pass
self
.
_strategy
=
strategy
def
forward
(
self
,
*
inputs
,
**
kwargs
):
def
_collective_hook
(
iop
):
op
=
framework
.
_dygraph_tracer
().
_ops
[
iop
.
_trace_id
]
for
k
,
v
in
six
.
iteritems
(
op
.
inputs
):
for
ivar
in
v
:
g
=
ivar
.
_grad_ivar
()
if
g
:
g_var
=
framework
.
Variable
(
block
=
self
.
_helper
.
main_program
.
current_block
(),
name
=
ivar
.
_grad_name
(),
stop_gradient
=
True
,
ivar
=
g
)
collective
.
_allreduce
(
g_var
,
g_var
,
sync_mode
=
True
)
outs
=
self
.
_layers
(
*
inputs
,
**
kwargs
)
for
_
,
op
in
six
.
iteritems
(
framework
.
_dygraph_tracer
().
_ops
):
# hook collective ops
op
.
iop
.
register_backward_hooks
(
_collective_hook
,
front
=
True
)
return
outs
return
self
.
_layers
(
*
inputs
,
**
kwargs
)
def
scale_loss
(
self
,
loss
):
if
self
.
_strategy
.
nranks
<
2
:
return
loss
loss_scale
=
to_variable
(
np
.
array
([
self
.
_strategy
.
nranks
]).
astype
(
"float32"
))
loss_scale
.
stop_gradient
=
True
loss
=
loss
/
loss_scale
return
loss
def
apply_collective_grads
(
self
):
if
self
.
_strategy
.
nranks
<
2
:
return
for
param
in
self
.
_layers
.
parameters
():
if
param
.
trainable
and
param
.
_ivar
.
_grad_ivar
():
g_var
=
framework
.
Variable
(
block
=
self
.
_helper
.
main_program
.
current_block
(),
name
=
param
.
_ivar
.
_grad_name
(),
stop_gradient
=
True
,
ivar
=
param
.
_ivar
.
_grad_ivar
())
collective
.
_allreduce
(
g_var
,
g_var
,
sync_mode
=
True
)
python/paddle/fluid/tests/unittests/parallel_dygraph_mnist.py
浏览文件 @
02175555
...
...
@@ -101,11 +101,13 @@ class MNIST(fluid.dygraph.Layer):
loc
=
0.0
,
scale
=
scale
)),
act
=
"softmax"
)
def
forward
(
self
,
inputs
):
def
forward
(
self
,
inputs
,
label
):
x
=
self
.
_simple_img_conv_pool_1
(
inputs
)
x
=
self
.
_simple_img_conv_pool_2
(
x
)
x
=
self
.
_fc
(
x
)
return
x
cost
=
self
.
_fc
(
x
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
return
avg_loss
class
TestMnist
(
TestParallelDyGraphRunnerBase
):
...
...
@@ -113,7 +115,7 @@ class TestMnist(TestParallelDyGraphRunnerBase):
model
=
MNIST
(
"mnist"
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
mnist
.
train
(),
batch_size
=
2
,
drop_last
=
True
)
opt
=
SGDOptimizer
(
learning_rate
=
1e-3
)
opt
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
1e-3
)
return
model
,
train_reader
,
opt
def
run_one_loop
(
self
,
model
,
opt
,
data
):
...
...
@@ -126,9 +128,8 @@ class TestMnist(TestParallelDyGraphRunnerBase):
label
=
to_variable
(
y_data
)
label
.
stop_gradient
=
True
cost
=
model
(
img
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
avg_loss
=
model
(
img
,
label
)
return
avg_loss
...
...
python/paddle/fluid/tests/unittests/parallel_dygraph_se_resnext.py
0 → 100644
浏览文件 @
02175555
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
os
import
contextlib
import
unittest
import
numpy
as
np
import
six
import
pickle
import
sys
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.dygraph
as
dygraph
from
paddle.fluid
import
core
from
paddle.fluid.optimizer
import
SGDOptimizer
from
paddle.fluid.dygraph.nn
import
Conv2D
,
Pool2D
,
FC
,
BatchNorm
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.layer_helper
import
LayerHelper
from
test_dist_base
import
runtime_main
,
TestParallelDyGraphRunnerBase
class
ConvBNLayer
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
num_filters
,
filter_size
,
stride
=
1
,
groups
=
1
,
act
=
None
):
super
(
ConvBNLayer
,
self
).
__init__
(
name_scope
)
self
.
_conv
=
Conv2D
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
filter_size
,
stride
=
stride
,
padding
=
(
filter_size
-
1
)
//
2
,
groups
=
groups
,
act
=
None
,
bias_attr
=
None
)
self
.
_batch_norm
=
BatchNorm
(
self
.
full_name
(),
num_filters
,
act
=
act
,
momentum
=
0.1
)
def
forward
(
self
,
inputs
):
y
=
self
.
_conv
(
inputs
)
y
=
self
.
_batch_norm
(
y
)
return
y
class
SqueezeExcitation
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
reduction_ratio
):
super
(
SqueezeExcitation
,
self
).
__init__
(
name_scope
)
self
.
_pool
=
Pool2D
(
self
.
full_name
(),
pool_size
=
0
,
pool_type
=
'avg'
,
global_pooling
=
True
)
self
.
_squeeze
=
FC
(
self
.
full_name
(),
size
=
num_channels
//
reduction_ratio
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.05
)),
act
=
'relu'
)
self
.
_excitation
=
FC
(
self
.
full_name
(),
size
=
num_channels
,
param_attr
=
fluid
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Constant
(
value
=
0.05
)),
act
=
'sigmoid'
)
def
forward
(
self
,
input
):
y
=
self
.
_pool
(
input
)
y
=
self
.
_squeeze
(
y
)
y
=
self
.
_excitation
(
y
)
y
=
fluid
.
layers
.
elementwise_mul
(
x
=
input
,
y
=
y
,
axis
=
0
)
return
y
class
BottleneckBlock
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
num_channels
,
num_filters
,
stride
,
cardinality
,
reduction_ratio
,
shortcut
=
True
):
super
(
BottleneckBlock
,
self
).
__init__
(
name_scope
)
self
.
conv0
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
,
filter_size
=
1
)
self
.
conv1
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
num_filters
,
num_filters
=
num_filters
,
filter_size
=
3
,
stride
=
stride
,
groups
=
cardinality
)
self
.
conv2
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
num_filters
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
act
=
'relu'
)
self
.
scale
=
SqueezeExcitation
(
self
.
full_name
(),
num_channels
=
num_filters
*
4
,
reduction_ratio
=
reduction_ratio
)
if
not
shortcut
:
self
.
short
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
*
4
,
filter_size
=
1
,
stride
=
stride
)
self
.
shortcut
=
shortcut
self
.
_num_channels_out
=
num_filters
*
4
def
forward
(
self
,
inputs
):
y
=
self
.
conv0
(
inputs
)
conv1
=
self
.
conv1
(
y
)
conv2
=
self
.
conv2
(
conv1
)
scale
=
self
.
scale
(
conv2
)
if
self
.
shortcut
:
short
=
inputs
else
:
short
=
self
.
short
(
inputs
)
y
=
fluid
.
layers
.
elementwise_add
(
x
=
short
,
y
=
scale
)
layer_helper
=
LayerHelper
(
self
.
full_name
(),
act
=
'relu'
)
y
=
layer_helper
.
append_activation
(
y
)
return
y
class
SeResNeXt
(
fluid
.
dygraph
.
Layer
):
def
__init__
(
self
,
name_scope
,
layers
=
50
,
class_dim
=
102
):
super
(
SeResNeXt
,
self
).
__init__
(
name_scope
)
self
.
layers
=
layers
supported_layers
=
[
50
,
101
,
152
]
assert
layers
in
supported_layers
,
\
"supported layers are {} but input layer is {}"
.
format
(
supported_layers
,
layers
)
if
layers
==
50
:
cardinality
=
32
reduction_ratio
=
16
depth
=
[
3
,
4
,
6
,
3
]
num_filters
=
[
128
,
256
,
512
,
1024
]
self
.
conv0
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
3
,
num_filters
=
64
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
)
self
.
pool
=
Pool2D
(
self
.
full_name
(),
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
elif
layers
==
101
:
cardinality
=
32
reduction_ratio
=
16
depth
=
[
3
,
4
,
23
,
3
]
num_filters
=
[
128
,
256
,
512
,
1024
]
self
.
conv0
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
3
,
num_filters
=
3
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
)
self
.
pool
=
Pool2D
(
self
.
full_name
(),
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
elif
layers
==
152
:
cardinality
=
64
reduction_ratio
=
16
depth
=
[
3
,
8
,
36
,
3
]
num_filters
=
[
128
,
256
,
512
,
1024
]
self
.
conv0
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
3
,
num_filters
=
3
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
)
self
.
conv1
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
64
,
num_filters
=
3
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
)
self
.
conv2
=
ConvBNLayer
(
self
.
full_name
(),
num_channels
=
64
,
num_filters
=
3
,
filter_size
=
7
,
stride
=
2
,
act
=
'relu'
)
self
.
pool
=
Pool2D
(
self
.
full_name
(),
pool_size
=
3
,
pool_stride
=
2
,
pool_padding
=
1
,
pool_type
=
'max'
)
self
.
bottleneck_block_list
=
[]
num_channels
=
64
for
block
in
range
(
len
(
depth
)):
shortcut
=
False
for
i
in
range
(
depth
[
block
]):
bottleneck_block
=
self
.
add_sublayer
(
'bb_%d_%d'
%
(
block
,
i
),
BottleneckBlock
(
self
.
full_name
(),
num_channels
=
num_channels
,
num_filters
=
num_filters
[
block
],
stride
=
2
if
i
==
0
and
block
!=
0
else
1
,
cardinality
=
cardinality
,
reduction_ratio
=
reduction_ratio
,
shortcut
=
shortcut
))
num_channels
=
bottleneck_block
.
_num_channels_out
self
.
bottleneck_block_list
.
append
(
bottleneck_block
)
shortcut
=
True
self
.
pool2d_avg
=
Pool2D
(
self
.
full_name
(),
pool_size
=
7
,
pool_type
=
'avg'
,
global_pooling
=
True
)
import
math
stdv
=
1.0
/
math
.
sqrt
(
2048
*
1.0
)
self
.
fc
=
FC
(
self
.
full_name
(),
size
=
class_dim
,
act
=
'softmax'
,
param_attr
=
fluid
.
param_attr
.
ParamAttr
(
initializer
=
fluid
.
initializer
.
Uniform
(
-
stdv
,
stdv
)))
def
forward
(
self
,
inputs
,
label
):
if
self
.
layers
==
50
or
self
.
layers
==
101
:
y
=
self
.
conv0
(
inputs
)
y
=
self
.
pool
(
y
)
elif
self
.
layers
==
152
:
y
=
self
.
conv0
(
inputs
)
y
=
self
.
conv1
(
inputs
)
y
=
self
.
conv2
(
inputs
)
y
=
self
.
pool
(
y
)
for
bottleneck_block
in
self
.
bottleneck_block_list
:
y
=
bottleneck_block
(
y
)
y
=
self
.
pool2d_avg
(
y
)
y
=
fluid
.
layers
.
dropout
(
y
,
dropout_prob
=
0.2
,
seed
=
1
)
cost
=
self
.
fc
(
y
)
loss
=
fluid
.
layers
.
cross_entropy
(
cost
,
label
)
avg_loss
=
fluid
.
layers
.
mean
(
loss
)
return
avg_loss
class
TestSeResNeXt
(
TestParallelDyGraphRunnerBase
):
def
get_model
(
self
):
model
=
SeResNeXt
(
"se-resnext"
)
train_reader
=
paddle
.
batch
(
paddle
.
dataset
.
flowers
.
test
(
use_xmap
=
False
),
batch_size
=
2
,
drop_last
=
True
)
opt
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
1e-3
)
return
model
,
train_reader
,
opt
def
run_one_loop
(
self
,
model
,
opt
,
data
):
bs
=
len
(
data
)
dy_x_data
=
np
.
array
([
x
[
0
].
reshape
(
3
,
224
,
224
)
for
x
in
data
]).
astype
(
'float32'
)
y_data
=
np
.
array
([
x
[
1
]
for
x
in
data
]).
astype
(
'int64'
).
reshape
(
bs
,
1
)
img
=
to_variable
(
dy_x_data
)
label
=
to_variable
(
y_data
)
label
.
stop_gradient
=
True
loss
=
model
(
img
,
label
)
return
loss
if
__name__
==
"__main__"
:
runtime_main
(
TestSeResNeXt
)
python/paddle/fluid/tests/unittests/test_dist_base.py
浏览文件 @
02175555
...
...
@@ -31,7 +31,7 @@ import paddle.fluid.dygraph as dygraph
from
paddle.fluid.dygraph.base
import
to_variable
from
paddle.fluid.dygraph.parallel
import
DataParallel
RUN_STEP
=
10
RUN_STEP
=
5
DEFAULT_BATCH_SIZE
=
2
...
...
@@ -200,6 +200,7 @@ class TestParallelDyGraphRunnerBase(object):
"train_one_loop should be implemented by the child classes."
)
def
run_trainer
(
self
,
args
):
seed
=
90
device_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
place
=
fluid
.
CUDAPlace
(
device_id
)
...
...
@@ -217,32 +218,35 @@ class TestParallelDyGraphRunnerBase(object):
with
fluid
.
dygraph
.
guard
(
place
):
fluid
.
default_startup_program
().
random_seed
=
seed
fluid
.
default_main_program
().
random_seed
=
seed
np
.
random
.
seed
(
seed
)
import
random
random
.
seed
=
seed
model
,
train_reader
,
opt
=
self
.
get_model
()
nranks
=
len
(
args
.
endpoints
.
split
(
","
))
if
args
.
endpoints
else
1
if
args
.
update_method
==
"nccl2"
:
sys
.
stderr
.
write
(
""
)
model
=
dygraph
.
parallel
.
DataParallel
(
model
)
strategy
=
dygraph
.
parallel
.
ParallelStrategy
()
strategy
.
nranks
=
nranks
strategy
.
local_rank
=
args
.
trainer_id
strategy
.
trainer_endpoints
=
args
.
endpoints
.
split
(
","
)
strategy
.
current_endpoint
=
args
.
current_endpoint
dygraph
.
parallel
.
prepare_context
(
strategy
)
model
=
dygraph
.
parallel
.
DataParallel
(
model
,
strategy
)
out_losses
=
[]
for
step_id
,
data
in
enumerate
(
train_reader
()):
data
=
_get_data
(
data
)
if
step_id
==
RUN_STEP
:
break
loss
=
self
.
run_one_loop
(
model
,
opt
,
data
)
out_losses
.
append
(
loss
.
numpy
())
# FIXME(Yancey1989): scale the loss inplace
loss
.
stop_gradient
=
True
loss_scale
=
to_variable
(
np
.
array
([
nranks
]).
astype
(
"float32"
))
loss
=
loss
/
loss_scale
# FIXME(Yancey1989): scale the loss inplace
if
args
.
update_method
==
"nccl2"
:
loss
=
model
.
scale_loss
(
loss
)
out_losses
.
append
(
loss
.
numpy
())
loss
.
backward
()
if
args
.
update_method
==
"nccl2"
:
model
.
apply_collective_grads
()
opt
.
minimize
(
loss
)
model
.
clear_gradients
()
...
...
@@ -663,9 +667,6 @@ class TestDistBase(unittest.TestCase):
local_loss
=
local_losses
[
step_id
]
tr0_loss
=
tr0_losses
[
step_id
]
tr1_loss
=
tr1_losses
[
step_id
]
dist_loss
=
(
np
.
array
([
tr0_loss
])
+
np
.
array
([
tr1_loss
]))
if
not
self
.
_dygraph
:
# Parallel DyGraph already scaled the loss in training
dist_loss
=
dist_loss
/
2
dist_loss
=
(
np
.
array
([
tr0_loss
])
+
np
.
array
([
tr1_loss
]))
/
2
print
(
"======="
,
local_loss
,
":"
,
dist_loss
[
0
],
"======="
)
self
.
assertAlmostEqual
(
local_loss
,
dist_loss
[
0
],
delta
=
delta
)
python/paddle/fluid/tests/unittests/test_parallel_dygraph_mnist.py
浏览文件 @
02175555
...
...
@@ -15,6 +15,7 @@
from
__future__
import
print_function
import
unittest
from
test_dist_base
import
TestDistBase
import
paddle.fluid
as
fluid
class
TestParallelDygraphMnist
(
TestDistBase
):
...
...
@@ -24,8 +25,8 @@ class TestParallelDygraphMnist(TestDistBase):
self
.
_dygraph
=
True
def
test_mnist
(
self
):
self
.
check_with_place
(
"parallel_dygraph_mnist.py"
,
delta
=
1e-5
,
check_error_log
=
True
)
if
fluid
.
core
.
is_compiled_with_cuda
():
self
.
check_with_place
(
"parallel_dygraph_mnist.py"
,
delta
=
1e-5
)
if
__name__
==
"__main__"
:
...
...
python/paddle/fluid/tests/unittests/test_parallel_dygraph_se_resnext.py
0 → 100644
浏览文件 @
02175555
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
from
test_dist_base
import
TestDistBase
import
paddle.fluid
as
fluid
class
TestParallelDygraphSeResNeXt
(
TestDistBase
):
def
_setup_config
(
self
):
self
.
_sync_mode
=
False
self
.
_nccl2_mode
=
True
self
.
_dygraph
=
True
def
test_se_resnext
(
self
):
# TODO(Yancey1989): BN and Dropout is related with batchsize, so the delta is the 1,
# try to remove the BN and Dropout in the network and using delta = 1e-5
if
fluid
.
core
.
is_compiled_with_cuda
():
self
.
check_with_place
(
"parallel_dygraph_se_resnext.py"
,
delta
=
1
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录