Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
15fac5e7
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
15fac5e7
编写于
1月 07, 2021
作者:
L
liuyuhui
提交者:
GitHub
1月 07, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix assign_op_xpu concat_op_xpu warining (#30120)
上级
f5428eca
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
20 addition
and
9 deletion
+20
-9
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
...rk/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
+1
-1
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
...k/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
+1
-1
paddle/fluid/operators/concat_op_xpu.cc
paddle/fluid/operators/concat_op_xpu.cc
+18
-7
未找到文件。
paddle/fluid/framework/ir/multi_devices_graph_pass/fuse_all_reduce_op_pass.cc
浏览文件 @
15fac5e7
...
@@ -276,7 +276,7 @@ class FuseAllReduceOpPass : public ir::Pass {
...
@@ -276,7 +276,7 @@ class FuseAllReduceOpPass : public ir::Pass {
ir
::
Node
::
Type
::
kOperation
),
ir
::
Node
::
Type
::
kOperation
),
local_scopes
,
places
,
num_of_all_reduce
,
multi_nccl_ctxs
);
local_scopes
,
places
,
num_of_all_reduce
,
multi_nccl_ctxs
);
#elif defined(PADDLE_WITH_XPU_BKCL)
#elif defined(PADDLE_WITH_XPU_BKCL)
auto
*
op_handle
=
new
details
::
FusedAllReduceOpHandle
(
op_handle
=
new
details
::
FusedAllReduceOpHandle
(
result
->
CreateEmptyNode
(
"fused_all_reduce"
,
result
->
CreateEmptyNode
(
"fused_all_reduce"
,
ir
::
Node
::
Type
::
kOperation
),
ir
::
Node
::
Type
::
kOperation
),
local_scopes
,
places
,
num_of_all_reduce
,
multi_bkcl_ctxs
);
local_scopes
,
places
,
num_of_all_reduce
,
multi_bkcl_ctxs
);
...
...
paddle/fluid/framework/ir/multi_devices_graph_pass/multi_devices_graph_pass.cc
浏览文件 @
15fac5e7
...
@@ -522,7 +522,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
...
@@ -522,7 +522,7 @@ void MultiDevSSAGraphBuilderBase::CreateAllReduceOp(ir::Graph *result,
scopes
,
places
,
grad_merge_cond_name
,
multi_nccl_ctxs_
));
scopes
,
places
,
grad_merge_cond_name
,
multi_nccl_ctxs_
));
#elif defined(PADDLE_WITH_XPU_BKCL)
#elif defined(PADDLE_WITH_XPU_BKCL)
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
result
->
Get
<
GraphOps
>
(
kGraphOps
).
emplace_back
(
new
d
a
tails
::
GradMergeAllReduceOpHandle
(
new
d
e
tails
::
GradMergeAllReduceOpHandle
(
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
result
->
CreateEmptyNode
(
"allreduce"
,
ir
::
Node
::
Type
::
kOperation
),
scopes
,
places
,
grad_merge_cond_name
,
multi_bkcl_ctxs_
));
scopes
,
places
,
grad_merge_cond_name
,
multi_bkcl_ctxs_
));
#else
#else
...
...
paddle/fluid/operators/concat_op_xpu.cc
浏览文件 @
15fac5e7
...
@@ -36,11 +36,16 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
...
@@ -36,11 +36,16 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
"XPU donot surpport AxisTensor for now"
));
"XPU donot surpport AxisTensor for now"
));
axis
=
ComputeAxis
(
static_cast
<
int64_t
>
(
axis
),
axis
=
ComputeAxis
(
static_cast
<
int64_t
>
(
axis
),
static_cast
<
int64_t
>
(
ins
[
0
]
->
dims
().
size
()));
static_cast
<
int64_t
>
(
ins
[
0
]
->
dims
().
size
()));
PADDLE_ENFORCE_GE
(
PADDLE_ENFORCE_GE
(
axis
,
0
,
platform
::
errors
::
InvalidArgument
(
axis
,
0
,
platform
::
errors
::
InvalidArgument
(
"concat: axis shoud >= 0!"
));
"concat: axis should be larger than or "
"equal to 0, but received axis is %d."
,
axis
));
PADDLE_ENFORCE_LT
(
axis
,
ins
[
0
]
->
dims
().
size
(),
PADDLE_ENFORCE_LT
(
axis
,
ins
[
0
]
->
dims
().
size
(),
platform
::
errors
::
InvalidArgument
(
platform
::
errors
::
InvalidArgument
(
"concat: axis shoud < ins[0]->dims()!"
));
"concat: axis should be less than ins[0]->dims()!"
"But received axis is %d, while ins[0]->dims()"
"size is %d."
,
axis
,
ins
[
0
]
->
dims
().
size
()));
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
out
->
mutable_data
<
T
>
(
place
);
...
@@ -151,10 +156,16 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
...
@@ -151,10 +156,16 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
}
}
}
}
PADDLE_ENFORCE_GE
(
axis
,
0
,
platform
::
errors
::
InvalidArgument
(
PADDLE_ENFORCE_GE
(
axis
,
0
,
platform
::
errors
::
InvalidArgument
(
"concat_grad: axis shoud >= 0!"
));
"concat_grad: axis should be larger than or "
PADDLE_ENFORCE_LT
(
axis
,
out_grad
->
dims
().
size
(),
"equal to 0, but received axis is %d."
,
platform
::
errors
::
InvalidArgument
(
axis
));
"concat_grad: axis shoud < ins[0]->dims()!"
));
PADDLE_ENFORCE_LT
(
axis
,
out_grad
->
dims
().
size
(),
platform
::
errors
::
InvalidArgument
(
"concat_grad: axis should be less than ins[0]->dims()!"
"But received axis is %d, while ins[0]->dims()"
"size is %d."
,
axis
,
out_grad
->
dims
().
size
()));
auto
input_dims
=
ins
[
0
]
->
dims
();
auto
input_dims
=
ins
[
0
]
->
dims
();
std
::
vector
<
int
>
split_list
(
n
);
std
::
vector
<
int
>
split_list
(
n
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录