Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
795d7121
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
795d7121
编写于
4月 11, 2022
作者:
S
sneaxiy
提交者:
GitHub
4月 11, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix some ops (#41577)
上级
c1394c6a
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
16 addition
and
11 deletion
+16
-11
paddle/phi/kernels/cpu/size_kernel.cc
paddle/phi/kernels/cpu/size_kernel.cc
+1
-0
paddle/phi/kernels/gpu/cumsum_kernel.cu
paddle/phi/kernels/gpu/cumsum_kernel.cu
+13
-10
paddle/phi/kernels/gpu/size_kernel.cu
paddle/phi/kernels/gpu/size_kernel.cu
+1
-0
python/paddle/nn/functional/loss.py
python/paddle/nn/functional/loss.py
+1
-1
未找到文件。
paddle/phi/kernels/cpu/size_kernel.cc
浏览文件 @
795d7121
...
...
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size,
CPU
,
ALL_LAYOUT
,
phi
::
SizeKernel
,
int16_t
,
int
,
int64_t
,
phi
::
dtype
::
float16
,
...
...
paddle/phi/kernels/gpu/cumsum_kernel.cu
浏览文件 @
795d7121
...
...
@@ -222,25 +222,28 @@ void CumsumKernel(const Context& dev_ctx,
// Use thrust for parallel acceleration when the input size is equal to the
// length of the ‘axis’ dimension.
if
(
size
==
out_dims
[
axis
])
{
#ifdef __HIPCC__
const
auto
&
policy
=
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
());
#else
const
auto
&
policy
=
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
());
#endif
if
(
reverse
)
{
thrust
::
device_ptr
<
const
T
>
dev_ptr
=
thrust
::
device_pointer_cast
(
in_data
);
thrust
::
device_vector
<
T
>
vec
(
dev_ptr
,
dev_ptr
+
size
);
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
const
T
>>
reversed_in
(
thrust
::
device_pointer_cast
(
in_data
)
+
size
);
thrust
::
reverse_iterator
<
thrust
::
device_ptr
<
T
>>
reversed_out
(
thrust
::
device_pointer_cast
(
out_data
)
+
size
);
if
(
exclusive
)
{
thrust
::
exclusive_scan
(
thrust
::
device
,
vec
.
rbegin
(),
vec
.
rend
(),
out_data
);
policy
,
reversed_in
,
reversed_in
+
size
,
reversed_out
);
}
else
{
thrust
::
inclusive_scan
(
thrust
::
device
,
vec
.
rbegin
(),
vec
.
rend
(),
out_data
);
policy
,
reversed_in
,
reversed_in
+
size
,
reversed_out
);
}
thrust
::
reverse
(
thrust
::
device
,
out_data
,
out_data
+
size
);
}
else
{
if
(
exclusive
)
{
thrust
::
exclusive_scan
(
thrust
::
device
,
in_data
,
in_data
+
size
,
out_data
);
thrust
::
exclusive_scan
(
policy
,
in_data
,
in_data
+
size
,
out_data
);
}
else
{
thrust
::
inclusive_scan
(
thrust
::
device
,
in_data
,
in_data
+
size
,
out_data
);
thrust
::
inclusive_scan
(
policy
,
in_data
,
in_data
+
size
,
out_data
);
}
}
return
;
...
...
paddle/phi/kernels/gpu/size_kernel.cu
浏览文件 @
795d7121
...
...
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(size,
GPU
,
ALL_LAYOUT
,
phi
::
SizeKernel
,
int16_t
,
int
,
int64_t
,
phi
::
dtype
::
float16
,
...
...
python/paddle/nn/functional/loss.py
浏览文件 @
795d7121
...
...
@@ -1795,7 +1795,7 @@ def cross_entropy(input,
# 2. else
# numerator: loss's weighted sum
# denominator: cal the sum of weight where the sample's class_index!=ignore_index
if
ignore_index
!=
-
10
0
:
if
ignore_index
>=
0
:
out_sum
=
_C_ops
.
reduce_sum
(
out
,
'reduce_all'
,
True
)
# for each label[i],set 1 or 0, according to ignore_index
# mask[i]=0, if label[i]==ignore_index
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录