Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0a144ca1
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0a144ca1
编写于
9月 22, 2022
作者:
L
Leo Chen
提交者:
GitHub
9月 22, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
convert grad_merge_all_reduce in graph to program (#46353)
上级
173b39bb
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
132 addition
and
39 deletion
+132
-39
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h
...fluid/framework/details/grad_merge_all_reduce_op_handle.h
+2
-0
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+2
-1
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
+23
-33
paddle/fluid/framework/ir/graph_helper.cc
paddle/fluid/framework/ir/graph_helper.cc
+15
-4
paddle/fluid/framework/program_utils.cc
paddle/fluid/framework/program_utils.cc
+3
-1
paddle/fluid/operators/collective/c_allreduce_op.h
paddle/fluid/operators/collective/c_allreduce_op.h
+83
-0
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
+4
-0
未找到文件。
paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h
浏览文件 @
0a144ca1
...
@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
...
@@ -99,6 +99,8 @@ class FusedGradMergeAllReduceOpHandle : public FusedAllReduceOpHandle {
std
::
string
Name
()
const
override
;
std
::
string
Name
()
const
override
;
std
::
string
GradMergeCondName
()
{
return
grad_merge_cond_name_
;
}
protected:
protected:
void
RunImpl
()
override
;
void
RunImpl
()
override
;
...
...
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
0a144ca1
...
@@ -19,7 +19,8 @@ cc_library(
...
@@ -19,7 +19,8 @@ cc_library(
cc_library
(
cc_library
(
graph_helper
graph_helper
SRCS graph_helper.cc
SRCS graph_helper.cc
DEPS graph program_utils scale_loss_grad_op_handle
)
DEPS graph program_utils scale_loss_grad_op_handle
grad_merge_all_reduce_op_handle
)
cc_library
(
cc_library
(
pass
pass
SRCS pass.cc
SRCS pass.cc
...
...
paddle/fluid/framework/ir/coalesce_grad_tensor_pass.cc
浏览文件 @
0a144ca1
...
@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass {
...
@@ -559,38 +559,26 @@ class CoalesceGradTensorPass : public ir::Pass {
all_persistable
=
false
;
all_persistable
=
false
;
}
}
}
}
VLOG
(
4
)
<<
"all_persistable:"
<<
all_persistable
;
if
(
all_persistable
)
{
VLOG
(
4
)
<<
"any_persistable:"
<<
all_persistable
;
// All grads are persistable, only need to be executed once at the
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of
// beginning.
// DropScope(), non persistable vars will be Erase or Clear. So
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kStartupProgramDescs
)
// coalesce_tensor op needs to be executed again after the execution
.
emplace_back
();
// of DropScope().
ProgramDesc
&
program_desc
=
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kStartupProgramDescs
)
// we can make fused_output persistable, so the memeory is not cleared
.
back
();
// and coalesce_tensor op do nothing if the inputs are already continue.
auto
*
global_block
=
program_desc
.
MutableBlock
(
0
);
AppendAllocSpaceForVarsOp
(
params_name
,
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
emplace_back
();
grads_name
,
ProgramDesc
&
program_desc
=
fused_var_name
,
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
back
();
dtype
,
auto
*
global_block
=
program_desc
.
MutableBlock
(
0
);
all_persistable
,
AppendAllocSpaceForVarsOp
(
params_name
,
global_block
);
grads_name
,
}
else
{
fused_var_name
,
// NOTE. In scope_buffered_ssa_graph_executor, after each execution of
dtype
,
// DropScope(), non persistable vars will be Erase or Clear. So
any_persistable
,
// coalesce_tensor op needs to be executed again after the execution
global_block
);
// of DropScope().
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
emplace_back
();
ProgramDesc
&
program_desc
=
result
->
Get
<
details
::
ProgramDescs
>
(
details
::
kProgramDescs
).
back
();
auto
*
global_block
=
program_desc
.
MutableBlock
(
0
);
AppendAllocSpaceForVarsOp
(
params_name
,
grads_name
,
fused_var_name
,
dtype
,
any_persistable
,
global_block
);
}
}
}
void
AppendAllocSpaceForVarsOp
(
const
std
::
vector
<
std
::
string
>
&
params_name
,
void
AppendAllocSpaceForVarsOp
(
const
std
::
vector
<
std
::
string
>
&
params_name
,
...
@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass {
...
@@ -599,13 +587,15 @@ class CoalesceGradTensorPass : public ir::Pass {
const
proto
::
VarType
::
Type
&
dtype
,
const
proto
::
VarType
::
Type
&
dtype
,
bool
persistable
,
bool
persistable
,
BlockDesc
*
global_block
)
const
{
BlockDesc
*
global_block
)
const
{
auto
fused_out_var
=
global_block
->
Var
(
fused_var_name
);
fused_out_var
->
SetPersistable
(
persistable
);
auto
op_desc
=
global_block
->
AppendOp
();
auto
op_desc
=
global_block
->
AppendOp
();
op_desc
->
SetType
(
"coalesce_tensor"
);
op_desc
->
SetType
(
"coalesce_tensor"
);
op_desc
->
SetInput
(
"Input"
,
params_name
);
op_desc
->
SetInput
(
"Input"
,
params_name
);
op_desc
->
SetOutput
(
"Output"
,
grads_name
);
op_desc
->
SetOutput
(
"Output"
,
grads_name
);
op_desc
->
SetOutput
(
"FusedOutput"
,
{
fused_var_name
});
op_desc
->
SetOutput
(
"FusedOutput"
,
{
fused_var_name
});
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
dtype
));
op_desc
->
SetAttr
(
"dtype"
,
static_cast
<
int
>
(
dtype
));
op_desc
->
SetAttr
(
"persist_output"
,
persistable
);
op_desc
->
SetAttr
(
"persist_output"
,
persistable
);
}
}
};
};
...
...
paddle/fluid/framework/ir/graph_helper.cc
浏览文件 @
0a144ca1
...
@@ -17,6 +17,7 @@ limitations under the License. */
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <queue>
#include <queue>
#include <stack>
#include <stack>
#include "paddle/fluid/framework/details/grad_merge_all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass.h"
...
@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node,
...
@@ -519,11 +520,11 @@ static void ReplaceAllReduceOp(const Node &node,
desc2
.
SetType
(
"c_allreduce_sum"
);
desc2
.
SetType
(
"c_allreduce_sum"
);
if
(
node
.
IsWrappedBy
<
details
::
OpHandleBase
>
())
{
if
(
node
.
IsWrappedBy
<
details
::
OpHandleBase
>
())
{
details
::
OpHandleBase
&
op_hander
=
details
::
OpHandleBase
&
op_hand
l
er
=
const_cast
<
Node
*>
(
&
node
)
->
Wrapper
<
details
::
OpHandleBase
>
();
const_cast
<
Node
*>
(
&
node
)
->
Wrapper
<
details
::
OpHandleBase
>
();
// set inputs
// set inputs
auto
in_var_handles
=
op_hander
.
Inputs
();
auto
in_var_handles
=
op_hand
l
er
.
Inputs
();
std
::
vector
<
std
::
string
>
in_names
;
std
::
vector
<
std
::
string
>
in_names
;
for
(
const
auto
&
in
:
in_var_handles
)
{
for
(
const
auto
&
in
:
in_var_handles
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
in
)
!=
nullptr
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
in
)
!=
nullptr
)
{
...
@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node,
...
@@ -543,7 +544,7 @@ static void ReplaceAllReduceOp(const Node &node,
desc2
.
SetInput
(
"X"
,
{
name
});
desc2
.
SetInput
(
"X"
,
{
name
});
// set outputs
// set outputs
auto
out_var_handles
=
op_hander
.
Outputs
();
auto
out_var_handles
=
op_hand
l
er
.
Outputs
();
std
::
vector
<
std
::
string
>
out_names
;
std
::
vector
<
std
::
string
>
out_names
;
for
(
const
auto
&
out
:
out_var_handles
)
{
for
(
const
auto
&
out
:
out_var_handles
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
out
)
!=
nullptr
)
{
if
(
dynamic_cast
<
details
::
DummyVarHandle
*>
(
out
)
!=
nullptr
)
{
...
@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node,
...
@@ -554,9 +555,18 @@ static void ReplaceAllReduceOp(const Node &node,
desc2
.
SetOutput
(
"Out"
,
{
name
});
desc2
.
SetOutput
(
"Out"
,
{
name
});
int
ring_id
=
platform
::
NCCLCommContext
::
Instance
().
GetRingId
(
int
ring_id
=
platform
::
NCCLCommContext
::
Instance
().
GetRingId
(
dynamic_cast
<
details
::
NCCLOpHandleBase
*>
(
&
op_hander
)
->
GetComm
());
dynamic_cast
<
details
::
NCCLOpHandleBase
*>
(
&
op_hand
l
er
)
->
GetComm
());
desc2
.
SetAttr
(
"ring_id"
,
ring_id
);
desc2
.
SetAttr
(
"ring_id"
,
ring_id
);
desc2
.
SetAttr
(
"use_calc_stream"
,
true
);
desc2
.
SetAttr
(
"use_calc_stream"
,
true
);
// handle grad merge
if
(
dynamic_cast
<
details
::
FusedGradMergeAllReduceOpHandle
*>
(
&
op_handler
))
{
VLOG
(
4
)
<<
"FusedGradMergeAllReduceOpHandle: add cond to c_allreduce_sum"
;
auto
cond_name
=
dynamic_cast
<
details
::
FusedGradMergeAllReduceOpHandle
*>
(
&
op_handler
)
->
GradMergeCondName
();
desc2
.
SetInput
(
"Cond"
,
{
cond_name
});
}
}
}
desc1
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
desc1
.
SetAttr
(
OpProtoAndCheckerMaker
::
OpRoleAttrName
(),
...
@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph,
...
@@ -780,6 +790,7 @@ void GraphToProgram(const Graph &graph,
VLOG
(
8
)
<<
"Merge main programs"
;
VLOG
(
8
)
<<
"Merge main programs"
;
MergePrograms
(
program
,
program_descs
,
/*append=*/
false
);
MergePrograms
(
program
,
program_descs
,
/*append=*/
false
);
}
}
// handle startup program
}
}
static
std
::
vector
<
std
::
vector
<
ir
::
Node
::
Dep
>>
GetOpDependencies
(
static
std
::
vector
<
std
::
vector
<
ir
::
Node
::
Dep
>>
GetOpDependencies
(
...
...
paddle/fluid/framework/program_utils.cc
浏览文件 @
0a144ca1
...
@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst,
...
@@ -49,9 +49,11 @@ void MergePrograms(ProgramDesc *dst,
if
(
dst_block
->
FindVar
(
src_new_var
->
Name
()))
continue
;
if
(
dst_block
->
FindVar
(
src_new_var
->
Name
()))
continue
;
auto
*
dst_new_var
=
dst_block
->
Var
(
src_new_var
->
Name
());
auto
*
dst_new_var
=
dst_block
->
Var
(
src_new_var
->
Name
());
*
dst_new_var
=
*
src_new_var
;
*
dst_new_var
=
*
src_new_var
;
VLOG
(
10
)
<<
"Create new variable "
<<
dst_new_var
->
Name
();
VLOG
(
10
)
<<
"Create new variable "
<<
dst_new_var
->
Name
()
<<
", persistable:"
<<
dst_new_var
->
Persistable
();
}
}
};
};
VisitAllElements
(
srcs
,
create_var_visitor
,
reverse
);
VisitAllElements
(
srcs
,
create_var_visitor
,
reverse
);
auto
create_op_visitor
=
[
dst
,
reverse
](
const
ProgramDesc
&
src
)
{
auto
create_op_visitor
=
[
dst
,
reverse
](
const
ProgramDesc
&
src
)
{
...
...
paddle/fluid/operators/collective/c_allreduce_op.h
浏览文件 @
0a144ca1
...
@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel {
...
@@ -76,6 +76,18 @@ class CAllReduceOp : public framework::OperatorWithKernel {
return
framework
::
OpKernelType
(
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
GetPlace
());
}
}
framework
::
OpKernelType
GetKernelTypeForVar
(
const
std
::
string
&
var_name
,
const
framework
::
Tensor
&
tensor
,
const
framework
::
OpKernelType
&
expected_kernel_type
)
const
{
if
(
var_name
==
"Cond"
)
{
return
expected_kernel_type
;
}
else
{
return
framework
::
OpKernelType
(
expected_kernel_type
.
data_type_
,
tensor
.
place
(),
tensor
.
layout
());
}
}
};
};
template
<
ReduceType
red_type
,
typename
T
>
template
<
ReduceType
red_type
,
typename
T
>
...
@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
...
@@ -83,6 +95,7 @@ class CAllReduceOpCPUKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_GLOO)
#if defined(PADDLE_WITH_GLOO)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
...
@@ -180,6 +193,23 @@ class CAllReduceOpASCENDKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_ASCEND_CL)
#if defined(PADDLE_WITH_ASCEND_CL)
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
...
@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
...
@@ -296,6 +326,23 @@ class CAllReduceOpXPUKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_XPU_BKCL)
#if defined(PADDLE_WITH_XPU_BKCL)
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T>
...
@@ -364,6 +411,23 @@ template <ReduceType red_type, typename T>
class
CAllReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
class
CAllReduceOpCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
...
@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
...
@@ -468,6 +532,23 @@ class CAllReduceOpMLUKernel : public framework::OpKernel<T> {
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
in
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
if
(
ctx
.
HasInput
(
"Cond"
))
{
auto
cond
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Cond"
);
auto
place
=
cond
->
place
();
PADDLE_ENFORCE_EQ
(
platform
::
is_cpu_place
(
place
),
true
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` tensor should be on cpu place"
));
PADDLE_ENFORCE_EQ
(
cond
->
numel
(),
1
,
platform
::
errors
::
PreconditionNotMet
(
"The input `cond` should be shape [1]"
));
if
(
!
cond
->
data
<
bool
>
()[
0
])
{
VLOG
(
4
)
<<
"Skip all reduce Op since cond is 0"
;
return
;
}
}
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
cnclDataType_t
dtype
=
cnclDataType_t
dtype
=
platform
::
ToCNCLDataType
(
framework
::
TransToProtoVarType
(
in
->
dtype
()));
platform
::
ToCNCLDataType
(
framework
::
TransToProtoVarType
(
in
->
dtype
()));
...
@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
...
@@ -549,10 +630,12 @@ Reference: https://docs.nvidia.com/deeplearning/sdk/nccl-developer-guide/docs/us
)DOC"
,
)DOC"
,
GetName
(),
GetName
(),
GetName
()));
GetName
()));
ExtraMake
();
}
}
protected:
protected:
virtual
std
::
string
GetName
()
const
=
0
;
virtual
std
::
string
GetName
()
const
=
0
;
virtual
void
ExtraMake
()
{}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/collective/c_allreduce_sum_op.cc
浏览文件 @
0a144ca1
...
@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
...
@@ -47,6 +47,10 @@ class CAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
class
CAllReduceSumOpMaker
:
public
CAllReduceOpMaker
{
class
CAllReduceSumOpMaker
:
public
CAllReduceOpMaker
{
protected:
protected:
void
ExtraMake
()
override
{
AddInput
(
"Cond"
,
"(Tensor), whether to do all reduce or not."
)
.
AsDispensable
();
}
std
::
string
GetName
()
const
override
{
return
"Sum"
;
}
std
::
string
GetName
()
const
override
{
return
"Sum"
;
}
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录