Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
a2387ef2
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
a2387ef2
编写于
4月 12, 2021
作者:
T
TTerror
提交者:
GitHub
4月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix concat_grad on kunlun (#32151)
* fix concat_grad on kunlun * fix concat_grad on kunlun
上级
f8bab5b0
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
7 addition
and
14 deletion
+7
-14
cmake/external/xpu.cmake
cmake/external/xpu.cmake
+1
-1
paddle/fluid/operators/concat_op_xpu.cc
paddle/fluid/operators/concat_op_xpu.cc
+6
-13
未找到文件。
cmake/external/xpu.cmake
浏览文件 @
a2387ef2
...
...
@@ -13,7 +13,7 @@ if(NOT XPU_SDK_ROOT)
elseif
(
WITH_SUNWAY
)
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2021_01_13.tar.gz"
CACHE STRING
""
FORCE
)
else
()
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_0
3_30
.tar.gz"
CACHE STRING
""
FORCE
)
SET
(
XPU_URL
"https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_0
4_09
.tar.gz"
CACHE STRING
""
FORCE
)
endif
()
SET
(
XPU_SOURCE_DIR
"
${
THIRD_PARTY_PATH
}
/xpu"
)
...
...
paddle/fluid/operators/concat_op_xpu.cc
浏览文件 @
a2387ef2
...
...
@@ -132,16 +132,14 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
axis
=
ComputeAxis
(
static_cast
<
int64_t
>
(
axis
),
static_cast
<
int64_t
>
(
ins
[
0
]
->
dims
().
size
()));
// get output tensor that the name is not kEmptyVarName
std
::
vector
<
framework
::
Tensor
*>
outputs
;
std
::
vector
<
int
>
choose_idx
;
int
n
=
0
;
std
::
vector
<
T
*>
ptrs
(
outs
.
size
());
for
(
size_t
j
=
0
;
j
<
outs
.
size
();
++
j
)
{
if
(
out_var_names
[
j
]
!=
framework
::
kEmptyVarName
&&
outs
[
j
]
->
numel
()
!=
0UL
)
{
outs
[
j
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
outputs
.
push_back
(
outs
[
j
]
);
choose_idx
.
push_back
(
j
);
n
++
;
ptrs
[
j
]
=
outs
[
j
]
->
data
<
T
>
(
);
}
else
{
ptrs
[
j
]
=
nullptr
;
}
}
PADDLE_ENFORCE_GE
(
axis
,
0
,
platform
::
errors
::
InvalidArgument
(
...
...
@@ -157,10 +155,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
axis
,
out_grad
->
dims
().
size
()));
auto
input_dims
=
ins
[
0
]
->
dims
();
std
::
vector
<
int
>
split_list
(
n
);
std
::
vector
<
int
>
split_list
(
ins
.
size
()
);
std
::
vector
<
int
>
xdims_list
(
input_dims
.
size
());
int
total_length
=
0
;
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
ins
.
size
()
;
++
i
)
{
split_list
[
i
]
=
ins
[
i
]
->
dims
()[
axis
];
total_length
+=
ins
[
i
]
->
dims
()[
axis
];
}
...
...
@@ -172,11 +170,6 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
}
xdims_list
[
axis
]
=
total_length
;
std
::
vector
<
T
*>
ptrs
(
n
);
for
(
int
i
=
0
;
i
<
n
;
++
i
)
{
ptrs
[
i
]
=
outputs
[
i
]
->
data
<
T
>
();
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
int
r
=
xpu
::
split
<
T
>
(
dev_ctx
.
x_context
(),
out_grad
->
data
<
T
>
(),
ptrs
,
xdims_list
,
split_list
,
axis
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录