Commit 633756ad (unverified)
Authored on Feb 20, 2018 by helinwang; committed via GitHub on Feb 20, 2018
Merge pull request #8361 from tonyyang-svail/backward_on_parallel_do
Backward on parallel do using nccl
Parents: a040239d, 4b957af2
Showing 11 changed files with 191 additions and 41 deletions:

paddle/fluid/framework/executor.cc  +5 −4
paddle/fluid/framework/framework.proto  +1 −0
paddle/fluid/operators/nccl_op.cc  +37 −9
paddle/fluid/operators/parallel_do_op.cc  +24 −2
paddle/fluid/pybind/protobuf.cc  +2 −1
python/paddle/v2/fluid/backward.py  +94 −14
python/paddle/v2/fluid/framework.py  +1 −1
python/paddle/v2/fluid/layers/control_flow.py  +4 −2
python/paddle/v2/fluid/optimizer.py  +1 −1
python/paddle/v2/fluid/tests/test_error_clip.py  +1 −1
python/paddle/v2/fluid/tests/unittests/test_parallel_op.py  +21 −6
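Taken together, these changes let the backward pass of parallel_do reduce parameter gradients across GPUs with NCCL (ncclInit, ncclAllReduce and assign ops) instead of the explicit AccumulateGrad path in ParallelDoGradOp. A minimal usage sketch in the fluid API of this commit, modeled on the updated test_parallel_op.py below; the toy network, the data names, and the read_input/write_output calls are illustrative assumptions, not part of this diff:

```python
import paddle.v2.fluid as fluid

# A toy forward program; each place returned by get_places() runs one slice
# of the batch through the parallel_do body.
x = fluid.layers.data(name='x', shape=[13], dtype='float32')

places = fluid.layers.get_places()
pd = fluid.layers.ParallelDo(places, use_nccl=True)  # use_nccl is the new flag
with pd.do():
    x_ = pd.read_input(x)
    hidden = fluid.layers.fc(input=x_, size=10)
    cost = fluid.layers.mean(x=hidden)
    pd.write_output(cost)

avg_cost = fluid.layers.mean(x=pd())

# With use_nccl=True, the parallel_do callback registered by backward.py emits
# ncclInit / ncclAllReduce / assign ops for each parameter gradient, and
# ParallelDoGradOp skips its explicit gradient accumulation. The NCCL path
# requires CUDA places at run time (enforced in ParallelDoGradOp).
params_grads = fluid.backward.append_backward(loss=avg_cost)
```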
paddle/fluid/framework/executor.cc

@@ -55,11 +55,13 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<platform::PlaceList>();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
+  } else if (var_type == proto::VarType::NCCL_COM) {
+    // GetMutable will be called in ncclInit
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
         var_type);
   }
 }

@@ -120,14 +122,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
   for (auto& op_desc : block.AllOps()) {
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
+    VLOG(4) << op->DebugStringEx(local_scope);

     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     platform::RecordEvent record_event(op->Type(), pool.Get(place_));

+    VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
     op->Run(*local_scope, place_);
-    // Wait current device context.
-    VLOG(3) << op->DebugStringEx(local_scope);
     if (FLAGS_benchmark) {
       VLOG(2) << "Memory used after operator " + op->Type() + " running: "
               << memory::memory_usage(place_);
paddle/fluid/framework/framework.proto

@@ -113,6 +113,7 @@ message VarType {
     PLACE_LIST = 14;
     READER = 15;
     CHANNEL = 16;
+    NCCL_COM = 17;
   }

   required Type type = 1;
paddle/fluid/operators/nccl_op.cc

@@ -14,10 +14,13 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"

 namespace paddle {
 namespace operators {

+static constexpr char kParallelScopes[] = "parallel_scopes";
+
 // NCCLinitOp
 class NCCLInitOp : public framework::OperatorBase {
  public:

@@ -29,11 +32,22 @@ class NCCLInitOp : public framework::OperatorBase {
  private:
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
+    PADDLE_ENFORCE_NOT_NULL(scope.FindVar(Input(kParallelScopes)),
+                            "Can not find variable '%s' in the scope.",
+                            kParallelScopes);
     const auto &name = Output("Communicator");
     PADDLE_ENFORCE_NOT_NULL(scope.FindVar(name),
                             "Can not find variable '%s' in the scope.", name);
-    std::vector<int> gpus = Attr<std::vector<int>>("gpus");
-    PADDLE_ENFORCE(!gpus.empty(), "Attr(gpus) should not be empty.");
+    // A parallel do may not use all the gpus. For example, the batch size is 7
+    // in the last batch while we have 8 gpu. In this case, parallel_do will
+    // create 7 parallel scopes, so should ncclInitOp create 7 gpu peers
+    auto &parallel_scopes = scope.FindVar(Input(kParallelScopes))
+                                ->Get<std::vector<framework::Scope *>>();
+    std::vector<int> gpus(parallel_scopes.size());
+    for (int i = 0; i < static_cast<int>(parallel_scopes.size()); ++i) {
+      gpus[i] = i;
+    }
+    PADDLE_ENFORCE(!gpus.empty(), "NCCL init with 0 gpus.");
     if (scope.FindVar(name) == nullptr) {
       PADDLE_THROW("Output(Communicator) is needed for ncclInit operator.");

@@ -45,17 +59,29 @@ class NCCLInitOp : public framework::OperatorBase {
   }
 };

+class NCCLInitOpVarTypeInference : public framework::VarTypeInference {
+ public:
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
+    auto out_var_name = op_desc.Output("Communicator").front();
+    auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
+    auto var_type = framework::proto::VarType::NCCL_COM;
+    out_var.SetType(var_type);
+  }
+};
+
+class NCCLInitOpShapeInference : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {}
+};
+
 class NCCLInitOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   NCCLInitOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput(kParallelScopes, "The working place of parallel do.");
     AddOutput("Communicator",
               "Create Communicator for communicating between gpus");
-    AddAttr<std::vector<int>>("gpus", "(vector<int>) GPU id lists");
-    AddAttr<int>("dtype",
-                 "(int, default 5 (FP32)) "
-                 "Output data type")
-        .SetDefault(framework::proto::VarType::FP32);
     AddComment(R"DOC(
 NCCLInit Operator.

@@ -78,7 +104,7 @@ class NCCLAllReduceOp : public framework::OperatorWithKernel {
                    ctx->HasInput("Communicator"),
                    " Input(Communicator) of AllReduce op input should not be NULL");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   " Input(X) of AllReduce op input should not be NULL");
+                   " Output(Out) of AllReduce op output should not be NULL");
     auto x_dims = ctx->GetInputsDim("X");

@@ -215,7 +241,9 @@ Bcast the tensors.
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(ncclInit, ops::NCCLInitOp,
-                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker);
+                  paddle::framework::EmptyGradOpMaker, ops::NCCLInitOpMaker,
+                  ops::NCCLInitOpVarTypeInference,
+                  ops::NCCLInitOpShapeInference);
 REGISTER_OP_WITHOUT_GRADIENT(ncclAllReduce, ops::NCCLAllReduceOp,
                              ops::NCCLAllReduceOpMaker);
paddle/fluid/operators/parallel_do_op.cc

@@ -30,6 +30,7 @@ static constexpr char kOutputs[] = "outputs";
 static constexpr char kParallelScopes[] = "parallel_scopes";
 static constexpr char kParallelBlock[] = "sub_block";
+static constexpr char kUseNCCL[] = "use_nccl";

 using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;

@@ -194,6 +195,8 @@ class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput(kOutputs, "").AsDuplicable();
     AddOutput(kParallelScopes, "");
     AddAttr<framework::BlockDesc *>(kParallelBlock, "");
+    AddAttr<bool>(kUseNCCL, "true if we use nccl on backward")
+        .SetDefault(false);
     AddComment(R"DOC(
 ParallelDo Operator.
 )DOC");

@@ -216,7 +219,6 @@ class ParallelDoGradOp : public framework::OperatorBase {
     auto &sub_scopes = scope.FindVar(Input(kParallelScopes))
                            ->Get<std::vector<framework::Scope *>>();
     auto &places = scope.FindVar(Input(kPlaces))->Get<platform::PlaceList>();
-
     // feed output@grad

@@ -243,7 +245,24 @@ class ParallelDoGradOp : public framework::OperatorBase {
     }
     WaitOnPlaces(places);

-    AccumulateGrad(scope, place, sub_scopes, places);
+    // NCCL allreduce op will be added by backward,
+    // so no need to explicitly accumulate grad
+    if (!(Attr<bool>(kUseNCCL))) {
+      AccumulateGrad(scope, place, sub_scopes, places);
+    } else {
+      for (auto &place : places) {
+        PADDLE_ENFORCE(platform::is_gpu_place(place),
+                       "NCCL only supports cuda place");
+      }
+    }
+    for (auto &s : Outputs(framework::GradVarName(kParameters))) {
+      if (s == "@EMPTY@") {
+        continue;
+      }
+      VLOG(3) << "Moving " << s;
+      CopyOrShare(*sub_scopes[0]->FindVar(s), place, scope.FindVar(s));
+    }
+    WaitOnPlaces(places);
   }

   void AccumulateGrad(const framework::Scope &scope,

@@ -251,6 +270,9 @@ class ParallelDoGradOp : public framework::OperatorBase {
                       const std::vector<framework::Scope *> &sub_scopes,
                       const platform::PlaceList &places) const {
     for (auto &s : Outputs(framework::GradVarName(kParameters))) {
+      if (s == "@EMPTY@") {
+        continue;
+      }
       VLOG(3) << "Accumulating " << s;
       if (s == framework::kEmptyVarName) continue;
       std::string tmp_name;
paddle/fluid/pybind/protobuf.cc

@@ -239,7 +239,8 @@ void BindVarDsec(py::module &m) {
       .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
       .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
       .value("PLACE_LIST", proto::VarType::PLACE_LIST)
-      .value("READER", proto::VarType::READER);
+      .value("READER", proto::VarType::READER)
+      .value("NCCL_COM", proto::VarType::NCCL_COM);
 }

 void BindOpDesc(py::module &m) {
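With the enum member bound in BindVarDsec, the new variable type is also visible from Python. A tiny hedged check, not part of the diff (the import path follows the rest of this commit):

```python
# Sketch only: confirm the newly bound enum member is reachable from Python.
import paddle.v2.fluid.core as core

nccl_com = core.VarDesc.VarType.NCCL_COM
print(nccl_com)  # enum value 17, matching framework.proto above
```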
python/paddle/v2/fluid/backward.py

@@ -199,12 +199,76 @@ def _remove_no_grad_branch_(op_descs, no_grad_set):
     return op_descs


+import proto.framework_pb2 as framework_pb2
+
+
+def serialize_op_decs(op_desc):
+    protostr = op_desc.serialize_to_string()
+    proto = framework_pb2.OpDesc.FromString(str(protostr))
+    return proto.__str__()
+
+
+def _callback_lookup_(op):
+    """
+    Only used in _append_backward_ops_
+    Build and returns a callback function for certain op. For example
+
+    parallel_do: AllReduce
+
+    :param op:
+    :return: callback function
+    """
+    if op.type == 'parallel_do' and op.attr('use_nccl'):
+        param_names = set(op.input('parameters'))
+        param_grad_names = [n + "@GRAD" for n in param_names]
+
+        class ParallelDoCallBack(object):
+            def __init__(self, param_grad_names, parallel_scopes_name):
+                self.has_inserted_nccl_init = False
+                self.param_grad_names = param_grad_names
+                self.parallel_scopes_name = parallel_scopes_name
+
+            def __call__(self, block, context):
+                if not self.has_inserted_nccl_init:
+                    op_desc = _create_op_desc_(
+                        "ncclInit",
+                        {"parallel_scopes": self.parallel_scopes_name},
+                        {"Communicator": ['nccl_com__do_not_change_']}, {})
+                    block.program.global_block().desc.append_op().copy_from(
+                        op_desc)
+                    self.has_inserted_nccl_init = True
+
+                current_op_desc = context["__current_op_desc__"]
+                for o_param in current_op_desc.output_names():
+                    for o_argu in current_op_desc.output(o_param):
+                        if o_argu in self.param_grad_names:
+                            allreduce_out_name = o_argu + "__nccl_all_reduce__"
+                            op_desc = _create_op_desc_(
+                                "ncclAllReduce", {
+                                    "X": [o_argu],
+                                    "Communicator":
+                                    ['nccl_com__do_not_change_']
+                                }, {"Out": [allreduce_out_name]},
+                                {"reduction": "ncclSum"})
+                            block.desc.append_op().copy_from(op_desc)
+
+                            op_desc = _create_op_desc_(
+                                "assign", {"X": [allreduce_out_name]},
+                                {"Out": [o_argu]}, {})
+                            block.desc.append_op().copy_from(op_desc)
+
+        return ParallelDoCallBack(param_grad_names,
+                                  op.output("parallel_scopes"))
+    else:
+        return None
+
+
 def _append_backward_ops_(block,
                           ops,
                           target_block,
                           no_grad_dict,
                           grad_to_var,
-                          callback=None):
+                          callbacks=None):
     """
     Create all grad ops, and insert them into given block

@@ -220,14 +284,11 @@ def _append_backward_ops_(block,
             val(str): corresponding forward variable name
         callback(callable object): a callable object used to decorate new generated grad ops
     """
-    if callback is None:
-
-        def empty_callback(block, context):
-            pass
-
-        callback = empty_callback
-    elif not hasattr(callback, '__call__'):
-        raise ValueError("'callback' must be a callable object.")
+    if callbacks is not None:
+        assert (isinstance(callbacks, list))
+        for cb in callbacks:
+            if not hasattr(cb, '__call__'):
+                raise ValueError("'callback' must be a callable object.")

     # grad_op_descs holds created grad_op, and will be appended to target_block
     grad_op_descs = []

@@ -238,8 +299,17 @@ def _append_backward_ops_(block,
         if op.has_attr("sub_block"):
             sub_block = program.block(op.block_attr("sub_block"))
             grad_sub_block = program.create_block(parent_idx=sub_block.idx)
-            _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
-                                  no_grad_dict, grad_to_var)
+            cb = _callback_lookup_(op)
+            if cb is not None:
+                if callbacks is None:
+                    new_callbacks = [cb]
+                else:
+                    new_callbacks = callbacks + [_callback_lookup_(op)]
+                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                      no_grad_dict, grad_to_var, new_callbacks)
+            else:
+                _append_backward_ops_(sub_block, sub_block.ops, grad_sub_block,
+                                      no_grad_dict, grad_to_var, callbacks)
             grad_sub_block_list.append(grad_sub_block.desc)

         # Getting op's corresponding grad_op

@@ -258,7 +328,11 @@ def _append_backward_ops_(block,
     for op_desc in grad_op_descs:
         new_op_desc = target_block.desc.append_op()
         new_op_desc.copy_from(op_desc)
-        callback(block=target_block, context=grad_to_var)
+        grad_to_var["__current_op_desc__"] = new_op_desc
+        if callbacks is not None:
+            assert (isinstance(callbacks, list))
+            for cb in callbacks:
+                cb(block=target_block, context=grad_to_var)


 def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):

@@ -296,6 +370,9 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
         # infer_shape and infer_type
        op_desc.infer_var_type(block.desc)
         op_desc.infer_shape(block.desc)
+        # ncclInit dones't need to set data_type
+        if op_desc.type() == 'ncclInit':
+            continue
         for arg in op_desc.output_arg_names():
             if arg in new_vars:
                 _infer_var_data_type_(arg, block)

@@ -335,7 +412,8 @@ def _get_stop_gradients_(program):
     return no_grad_dict


-def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
+def append_backward(loss, parameter_list=None, no_grad_set=None,
+                    callbacks=None):
     """
     Append backward part to main_program

@@ -351,6 +429,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
         (list[(Variable,Variable)]): list of (parameter, gradient) pair.
     """
     assert isinstance(loss, framework.Variable)
+    if callbacks is not None:
+        isinstance(callbacks, list)

     program = loss.block.program
     if no_grad_set is None:

@@ -378,7 +458,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None, callback=None):
     no_grad_dict[0].update(map(_append_grad_suffix_, block_no_grad_set))

     _append_backward_ops_(root_block, op_path, root_block, no_grad_dict,
-                          grad_to_var, callbacks)

     # Because calc_gradient may be called multiple times,
     # we need rename the internal gradient variables so that they have
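The move from a single callback to a callbacks list also pins down the callback contract: every callable is invoked as cb(block=target_block, context=grad_to_var) right after a grad op desc is appended, and the freshly appended desc is exposed under context["__current_op_desc__"], which is how ParallelDoCallBack finds the gradients to all-reduce. A hedged illustration of that contract (the logging callback and the toy network are not part of this commit):

```python
import paddle.v2.fluid as fluid

# A tiny network so append_backward has something to differentiate.
x = fluid.layers.data(name='x', shape=[4], dtype='float32')
avg_cost = fluid.layers.mean(x=fluid.layers.fc(input=x, size=1))


def log_grad_ops(block, context):
    # Called once per appended grad op; _append_backward_ops_ stores the desc
    # it just copied into the target block under "__current_op_desc__".
    print("appended grad op:", context["__current_op_desc__"].type())


params_grads = fluid.backward.append_backward(
    loss=avg_cost, callbacks=[log_grad_ops])
```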
python/paddle/v2/fluid/framework.py

@@ -490,7 +490,7 @@ class Operator(object):
                 'feed', 'fetch', 'save', 'load', 'recurrent',
                 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
                 'recv', 'listen_and_serv', 'parallel_do', 'save_combine',
-                'load_combine'
+                'load_combine', 'ncclInit'
             }
             if type not in no_kernel_op_set:
                 self.desc.infer_var_type(self.block.desc)
python/paddle/v2/fluid/layers/control_flow.py

@@ -237,12 +237,13 @@ class ParallelDo(object):
     ParallelDo class is used to create a ParallelDo.
     """

-    def __init__(self, places, name=None):
+    def __init__(self, places, use_nccl=False, name=None):
         self.helper = LayerHelper("parallel_do", name=name)
         self.inputs = []
         self.places = places
         self.outputs = []
         self.status = StaticRNN.BEFORE_RNN_BLOCK
+        self.use_nccl = use_nccl

     def do(self):
         return BlockGuardWithCompletion(self)

@@ -325,7 +326,8 @@ class ParallelDo(object):
             },
             outputs={'outputs': outputs,
                      'parallel_scopes': [step_scope]},
-            attrs={'sub_block': current_block})
+            attrs={'sub_block': current_block,
+                   'use_nccl': self.use_nccl})


 class BlockGuardWithCompletion(BlockGuard):
python/paddle/v2/fluid/optimizer.py

@@ -225,7 +225,7 @@ class Optimizer(object):
         `create_optimization_pass()` into one.
         """
         params_grads = append_backward(loss, parameter_list, no_grad_set,
-                                       error_clip_callback)
+                                       [error_clip_callback])

         params_grads = append_gradient_clip_ops(params_grads)
python/paddle/v2/fluid/tests/test_error_clip.py

@@ -43,7 +43,7 @@ prog_clip.block(0).var(hidden1.name).set_error_clip(
 avg_cost_clip = prog_clip.block(0).var(avg_cost.name)
 fluid.backward.append_backward(loss=avg_cost)
 fluid.backward.append_backward(
-    loss=avg_cost_clip, callback=fluid.clip.error_clip_callback)
+    loss=avg_cost_clip, callbacks=[fluid.clip.error_clip_callback])

 hidden1_grad = prog.block(0).var(hidden1.name + "@GRAD")
 hidden1_grad_clip = prog_clip.block(0).var(hidden1.name + "@GRAD")
python/paddle/v2/fluid/tests/unittests/test_parallel_op.py

@@ -67,12 +67,25 @@ class BaseParallelForTest(unittest.TestCase):
                 fetch=fetch,
                 place=gpu,
                 use_parallel=True)
+            result_gpu_nccl = self._run_test_impl_(
+                callback=callback,
+                feed=feed,
+                fetch=fetch,
+                place=gpu,
+                use_parallel=True,
+                use_nccl=True)
             self._assert_same_(fetch, result_cpu, result_cpu_parallel,
-                               result_gpu, result_gpu_parallel)
+                               result_gpu, result_gpu_parallel, result_gpu_nccl)
         else:
             self._assert_same_(fetch, result_cpu, result_cpu_parallel)

-    def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False):
+    def _run_test_impl_(self,
+                        callback,
+                        feed,
+                        fetch,
+                        place,
+                        use_parallel=False,
+                        use_nccl=False):
         """
         Run a single test, returns the fetch values
         Args:

@@ -96,7 +109,7 @@ class BaseParallelForTest(unittest.TestCase):
         # Automatically insert parallel do if use_parallel = True
         if use_parallel:
             places = fluid.layers.get_places()
-            pd = fluid.layers.ParallelDo(places)
+            pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
             data = next(generator)

             if isinstance(data, fluid.Variable):

@@ -137,7 +150,9 @@ class BaseParallelForTest(unittest.TestCase):
         """

         def _impl_(a, b, fetch_id, item_id):
-            item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU']
+            item_str = [
+                'CPU', 'ParallelCPU', 'GPU', 'ParallelGPU', 'ParallelGPUNCCL'
+            ]
             flag = numpy.allclose(a, b, rtol=0.1, atol=1e-3)
             self.assertTrue(flag,
                             "The {0} are different in {1}, {2} vs {3}".format(

@@ -198,5 +213,5 @@ class ParallelOpTestMultipleInput(BaseParallelForTest):
             fetch=['fc1.w@GRAD', 'fc2.w@GRAD', 'fc3.w@GRAD'])


-# if __name__ == '__main__':
-# unittest.main()
+if __name__ == '__main__':
+    unittest.main()