Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
0a144ca1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a144ca1
编写于
9月 22, 2022
作者:
L
Leo Chen
提交者:
GitHub
9月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert grad_merge_all_reduce in graph to program (#46353)
上级
173b39bb
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
132 addition
and
39 deletion
+132
-39
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h
...fluid/framework/details/grad_merge_all_reduce_op_handle.h
+2
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-1
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
+23
-33
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+15
-4
paddle/fluid/framework/program_utils.cc
paddle/fluid/framework/program_utils.cc
+3
-1
paddle/fluid/operators/collective/c_allreduce_op.h
paddle/fluid/operators/collective/c_allreduce_op.h
+83
-0
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
+4
-0
未找到文件。
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h
浏览文件 @
0a144ca1
...
@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
...
@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
std
::
string
GradMergeCondName
()
{
return
grad_merge_cond_name_
;
}
protected:
protected:
void
RunImpl
()
override
;
void
RunImpl
()
override
;
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
0a144ca1
...
@@ -19,7 +19,8 @@ cc_library(
...
@@ -19,7 +19,8 @@ cc_library(
cc_library
(
cc_library
(
graph_helper
graph_helper
SRCS graph_helper.cc
SRCS graph_helper.cc
DEPS graph program_utils scale_loss_grad_op_handle
)
DEPS graph program_utils scale_loss_grad_op_handle
grad_merge_all_reduce_op_handle
)
cc_library
(
cc_library
(
pass
pass
SRCS pass.cc
SRCS pass.cc
...
...
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
浏览文件 @
0a144ca1
...
@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass {
...
@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass {
all_persistable
=
false
;
all_persistable
=
false
;
}
}
}
}
VLOG
(
4
)
<<
"all_persistable:"
<<
all_persistable
;
if
(
all_persistable
)
{
VLOG
(
4
)
<<
"any_persistable:"
<<
all_persistable
;
// All grads are persistable, only need to be executed once at the
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of
// beginning.
// DropScope(), non persistable vars will be Erase or Clear. So
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kStartupProgramDescs
)
// coalesce_tensor op needs to be executed again after the execution
.
emplace_back
();
// of DropScope().
ProgramDesc
&
program_desc
=
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kStartupProgramDescs
)
// we can make fused_output persistable, so the memeory is not cleared
.
back
();
// and coalesce_tensor op do nothing if the inputs are already continue.
auto
*
global_block
=
program_desc
.
MutableBlock
(
0
);
AppendAllocSpaceForVarsOp
(
params_name
,
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
emplace_back
();
grads_name
,
ProgramDesc
&
program_desc
=
fused_var_name
,
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
back
();
dtype
,
auto
*
global_block
=
program_desc
.
MutableBlock
(
0
);
all_persistable
,
AppendAllocSpaceForVarsOp
(
params_name
,
global_block
);
grads_name
,
}
else
{
fused_var_name
,
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of
dtype
,
// DropScope(), non persistable vars will be Erase or Clear. So
any_persistable
,
// coalesce_tensor op needs to be executed again after the execution
global_block
);
// of DropScope().
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
emplace_back
();
ProgramDesc
&
program_desc
=
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
back
();
auto
*
global_block
=
program_desc
.
MutableBlock
(
0
);
AppendAllocSpaceForVarsOp
(
params_name
,
grads_name
,
fused_var_name
,
dtype
,
any_persistable
,
global_block
);
}
}
}
void
AppendAllocSpaceForVarsOp
(
const
std
::
vector
<
std
::
string
>
&
params_name
,
void
AppendAllocSpaceForVarsOp
(
const
std
::
vector
<
std
::
string
>
&
params_name
,
...
@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass {
...
@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass {
const
proto
::
VarType
::
Type
&
dtype
,
const
proto
::
VarType
::
Type
&
dtype
,
bool
persistable
,
bool
persistable
,
BlockDesc
*
global_block
)
const
{
BlockDesc
*
global_block
)
const
{
auto
fused_out_var
=
global_block
->
Var
(
fused_var_name
);
fused_out_var
->
SetPersistable
(
persistable
);
auto
op_desc
=
global_block
->
AppendOp
();
auto
op_desc
=
global_block
->
AppendOp
();
op_desc
->
SetType
(
"coalesce_tensor"
);
op_desc
->
SetType
(
"coalesce_tensor"
);
op_desc
->
SetInput
(
"Input"
,
params_name
);
op_desc
->
SetInput
(
"Input"
,
params_name
);
op_desc
->
SetOutput
(
"Output"
,
grads_name
);
op_desc
->
SetOutput
(
"Output"
,
grads_name
);
op_desc
->
SetOutput
(
"FusedOutput"
,
{
fused_var_name
});
op_desc
->
SetOutput
(
"FusedOutput"
,
{
fused_var_name
});
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
dtype
));
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
dtype
));
op_desc
->
SetAttr
(
"persist_output"
,
persistable
);
op_desc
->
SetAttr
(
"persist_output"
,
persistable
);
}
}
};
};
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
0a144ca1
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <queue>
#include <queue>
#include <stack>
#include <stack>
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
...
@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node,
...
@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node,
desc2
.
SetType
(
"c_allreduce_sum"
);
desc2
.
SetType
(
"c_allreduce_sum"
);
if
(
node
.
IsWrappedBy
<
details
::
OpHandleBase
>
())
{
if
(
node
.
IsWrappedBy
<
details
::
OpHandleBase
>
())
{
details
::
OpHandleBase
&
op_hander
=
details
::
OpHandleBase
&
op_hand
l
er
=
const_cast
<
Node
*>
(
&
node
)
->
Wrapper
<
details
::
OpHandleBase
>
();
const_cast
<
Node
*>
(
&
node
)
->
Wrapper
<
details
::
OpHandleBase
>
();
// set inputs
// set inputs
auto
in_var_handles
=
op_hander
.
Inputs
();
auto
in_var_handles
=
op_hand
l
er
.
Inputs
();
std
::
vector
<
std
::
string
>
in_names
;
std
::
vector
<
std
::
string
>
in_names
;
for
(
const
auto
&
in
:
in_var_handles
)
{
for
(
const
auto
&
in
:
in_var_handles
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
in
)
!=
nullptr
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
in
)
!=
nullptr
)
{
...
@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node,
...
@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node,
desc2
.
SetInput
(
"X"
,
{
name
});
desc2
.
SetInput
(
"X"
,
{
name
});
// set outputs
// set outputs
auto
out_var_handles
=
op_hander
.
Outputs
();
auto
out_var_handles
=
op_hand
l
er
.
Outputs
();
std
::
vector
<
std
::
string
>
out_names
;
std
::
vector
<
std
::
string
>
out_names
;
for
(
const
auto
&
out
:
out_var_handles
)
{
for
(
const
auto
&
out
:
out_var_handles
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
out
)
!=
nullptr
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
out
)
!=
nullptr
)
{
...
@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node,
...
@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node,
desc2
.
SetOutput
(
"Out"
,
{
name
});
desc2
.
SetOutput
(
"Out"
,
{
name
});
int
ring_id
=
platform
::
NCCLCommContext
::
Instance
().
GetRingId
(
int
ring_id
=
platform
::
NCCLCommContext
::
Instance
().
GetRingId
(
dynamic_cast
<
details
::
NCCLOpHandleBase
*>
(
&
op_hander
)
->
GetComm
());
dynamic_cast
<
details
::
NCCLOpHandleBase
*>
(
&
op_hand
l
er
)
->
GetComm
());
desc2
.
SetAttr
(
"ring_id"
,
ring_id
);
desc2
.
SetAttr
(
"ring_id"
,
ring_id
);
desc2
.
SetAttr
(
"use_calc_stream"
,
true
);
desc2
.
SetAttr
(
"use_calc_stream"
,
true
);
// handle grad merge
if
(
dynamic_cast
<
details
::
FusedGradMergeAllReduceOpHandle
*>
(
&
op_handler
))
{
VLOG
(
4
)
<<
"FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum"
;
auto
cond_name
=
dynamic_cast
<
details
::
FusedGradMergeAllReduceOpHandle
*>
(
&
op_handler
)
->
GradMergeCondName
();
desc2
.
SetInput
(
"Cond"
,
{
cond_name
});
}
}
}
desc1
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
desc1
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
...
@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph,
...
@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph,
VLOG
(
8
)
<<
"Merge main programs"
;
VLOG
(
8
)
<<
"Merge main programs"
;
MergePrograms
(
program
,
program_descs
,
/*append=*/
false
);
MergePrograms
(
program
,
program_descs
,
/*append=*/
false
);
}
}
// handle startup program
}
}
static
std
::
vector
<
std
::
vector
<
ir
::
Node
::
Dep
>>
GetOpDependencies
(
static
std
::
vector
<
std
::
vector
<
ir
::
Node
::
Dep
>>
GetOpDependencies
(
...
...
paddle/fluid/framework/program_utils.cc
浏览文件 @
0a144ca1
...
@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst,
...
@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst,
if
(
dst_block
->
FindVar
(
src_new_var
->
Name
()))
continue
;
if
(
dst_block
->
FindVar
(
src_new_var
->
Name
()))
continue
;
auto
*
dst_new_var
=
dst_block
->
Var
(
src_new_var
->
Name
());
auto
*
dst_new_var
=
dst_block
->
Var
(
src_new_var
->
Name
());
*
dst_new_var
=
*
src_new_var
;
*
dst_new_var
=
*
src_new_var
;
VLOG
(
10
)
<<
"Create new variable "
<<
dst_new_var
->
Name
();
VLOG
(
10
)
<<
"Create new variable "
<<
dst_new_var
->
Name
()
<<
", persistable:"
<<
dst_new_var
->
Persistable
();
}
}
};
};
VisitAllElements
(
srcs
,
create_var_visitor
,
reverse
);
VisitAllElements
(
srcs
,
create_var_visitor
,
reverse
);
auto
create_op_visitor
=
[
dst
,
reverse
](
const
ProgramDesc
&
src
)
{
auto
create_op_visitor
=
[
dst
,
reverse
](
const
ProgramDesc
&
src
)
{
...
...
paddle/fluid/operators/collective/c_allreduce_op.h
浏览文件 @
0a144ca1
...
@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel {
...
@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel {
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
if
(
var_name
==
"Cond"
)
{
return
expected_kernel_type
;
}
else
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
}
};
};
template
<
ReduceType
red_type
,
typename
T
>
template
<
ReduceType
red_type
,
typename
T
>
...
@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
...
@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_GLOO)
#if defined(PADDLE_WITH_GLOO)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
...
@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_ASCEND_CL)
#if defined(PADDLE_WITH_ASCEND_CL)
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
...
@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
...
@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_XPU_BKCL)
#if defined(PADDLE_WITH_XPU_BKCL)
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T>
...
@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T>
class
CAllReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CAllReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
...
@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
cnclDataType_t
dtype
=
cnclDataType_t
dtype
=
platform
::
ToCNCLDataType
(
framework
::
TransToProtoVarType
(
in
->
dtype
()));
platform
::
ToCNCLDataType
(
framework
::
TransToProtoVarType
(
in
->
dtype
()));
...
@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
...
@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
)DOC"
,
)DOC"
,
GetName
(),
GetName
(),
GetName
()));
GetName
()));
ExtraMake
();
}
}
protected:
protected:
virtual
std
::
string
GetName
()
const
=
0
;
virtual
std
::
string
GetName
()
const
=
0
;
virtual
void
ExtraMake
()
{}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
浏览文件 @
0a144ca1
...
@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
class
CAllReduceSumOpMaker
:
public
CAllReduceOpMaker
{
class
CAllReduceSumOpMaker
:
public
CAllReduceOpMaker
{
protected:
protected:
void
ExtraMake
()
override
{
AddInput
(
"Cond"
,
"(Tensor), whether to do all reduce or not."
)
.
AsDispensable
();
}
std
::
string
GetName
()
const
override
{
return
"Sum"
;
}
std
::
string
GetName
()
const
override
{
return
"Sum"
;
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录