Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
正统之独孤求败
mindspore
提交
04f4be48
M
mindspore
项目概览
正统之独孤求败
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
04f4be48
编写于
8月 19, 2020
作者:
B
baihuawei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix gpu loss grad
上级
7a8fbbbb
变更
1
显示空白变更内容
内联
并排
Showing
1 changed file
with
11 addition
and
3 deletion
+11
-3
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
...kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
+11
-3
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cu
浏览文件 @
04f4be48
...
...
@@ -97,10 +97,14 @@ __global__ void KLDivLossGradKernel(const int input_size, const int reduction, c
dy
[
i
]
=
(
logf
(
denominator
)
+
1
-
input_x
[
i
])
*
dloss
[
i
];
}
}
else
{
T
dloss1
=
dloss
[
0
];
if
(
reduction
==
1
)
{
dloss1
=
dloss
[
0
]
/
input_size
;
}
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
input_size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
denominator
=
max
(
input_y
[
i
],
epsilon
);
dx
[
i
]
=
-
input_y
[
i
]
*
dloss
[
0
]
;
dy
[
i
]
=
(
logf
(
denominator
)
+
1
-
input_x
[
i
])
*
dloss
[
0
]
;
dx
[
i
]
=
-
input_y
[
i
]
*
dloss
1
;
dy
[
i
]
=
(
logf
(
denominator
)
+
1
-
input_x
[
i
])
*
dloss
1
;
}
}
}
...
...
@@ -169,10 +173,14 @@ __global__ void BinaryCrossEntropyLossGradKernel(const int input_size, const int
dx
[
i
]
=
value
*
dloss
[
i
];
}
}
else
{
T
dloss1
=
dloss
[
0
];
if
(
reduction
==
1
)
{
dloss1
=
dloss
[
0
]
/
input_size
;
}
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
input_size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
denominator
=
max
(
input_x
[
i
]
*
(
1
-
input_x
[
i
]),
epsilon
);
T
value
=
weight
[
i
]
*
(
input_x
[
i
]
-
input_y
[
i
])
/
denominator
;
dx
[
i
]
=
value
*
dloss
[
0
]
;
dx
[
i
]
=
value
*
dloss
1
;
}
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录