Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
636b8e2b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
636b8e2b
编写于
6月 17, 2020
作者:
L
lizhenyu
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add SigmoidCrossEntropyWithLogitsGrad op
上级
4642df20
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
253 addition
and
0 deletion
+253
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu
.../cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu
+41
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh
...cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh
+25
-0
mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc
...u/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc
+29
-0
mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h
...pu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h
+96
-0
tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py
...ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py
+62
-0
未找到文件。
mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cu
0 → 100644
浏览文件 @
636b8e2b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh"
template
<
typename
T
,
typename
S
>
__global__
void
SigmoidCrossEntropyWithLogitsGradKernel
(
const
size_t
size
,
const
T
*
logits
,
const
S
*
labels
,
T
*
outputs
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
size
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
if
(
logits
[
i
]
>=
0
)
{
outputs
[
i
]
=
1.
/
(
1.
+
exp
(
-
logits
[
i
]))
-
labels
[
i
];
}
else
{
const
T
exp_val
=
exp
(
logits
[
i
]);
outputs
[
i
]
=
exp_val
/
(
1.
+
exp_val
)
-
labels
[
i
];
}
}
}
template
<
typename
T
,
typename
S
>
void
SigmoidCrossEntropyWithLogitsGrad
(
const
size_t
size
,
const
T
*
logits
,
const
S
*
labels
,
T
*
outputs
,
cudaStream_t
cuda_stream
)
{
SigmoidCrossEntropyWithLogitsGradKernel
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
size
,
logits
,
labels
,
outputs
);
}
template
void
SigmoidCrossEntropyWithLogitsGrad
<
float
,
float
>(
const
size_t
size
,
const
float
*
logits
,
const
float
*
labels
,
float
*
outputs
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh
0 → 100644
浏览文件 @
636b8e2b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_
#include "device/gpu/cuda_common.h"
template
<
typename
T
,
typename
S
>
void
SigmoidCrossEntropyWithLogitsGrad
(
const
size_t
size
,
const
T
*
logits
,
const
S
*
labels
,
T
*
outputs
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_IMPL_H_
mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.cc
0 → 100644
浏览文件 @
636b8e2b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h"
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_TWO
(
SigmoidCrossEntropyWithLogitsGrad
,
KernelAttr
()
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
SigmoidCrossEntropyWithLogitsGradGpuKernel
,
float
,
float
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/sigmoid_cross_entropy_with_logits_grad_gpu_kernel.h
0 → 100644
浏览文件 @
636b8e2b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/sigmoid_cross_entropy_with_logits_grad_impl.cuh"
namespace
mindspore
{
namespace
kernel
{
template
<
typename
T
,
typename
S
>
class
SigmoidCrossEntropyWithLogitsGradGpuKernel
:
public
GpuKernel
{
public:
SigmoidCrossEntropyWithLogitsGradGpuKernel
()
:
logits_size_
(
0
),
labels_size_
(
0
),
outputs_size_
(
0
)
{}
~
SigmoidCrossEntropyWithLogitsGradGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
{
T
*
logits_addr
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
S
*
labels_addr
=
GetDeviceAddress
<
S
>
(
inputs
,
1
);
T
*
outputs_addr
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
SigmoidCrossEntropyWithLogitsGrad
(
inputs
[
0
]
->
size
/
sizeof
(
T
),
logits_addr
,
labels_addr
,
outputs_addr
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but SigmoidCrossEntropyWithLogitsGrad needs 3 inputs."
;
return
false
;
}
logits_size_
=
sizeof
(
T
);
labels_size_
=
sizeof
(
S
);
outputs_size_
=
sizeof
(
T
);
auto
logits_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
logits_shape
.
size
();
i
++
)
{
logits_size_
*=
logits_shape
[
i
];
}
auto
labels_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
1
);
for
(
size_t
i
=
0
;
i
<
labels_shape
.
size
();
i
++
)
{
labels_size_
*=
labels_shape
[
i
];
}
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
output_shape
.
size
();
i
++
)
{
outputs_size_
*=
output_shape
[
i
];
}
InitSizeLists
();
return
true
;
}
protected:
void
InitSizeLists
()
override
{
input_size_list_
.
push_back
(
logits_size_
);
input_size_list_
.
push_back
(
labels_size_
);
output_size_list_
.
push_back
(
outputs_size_
);
}
private:
size_t
logits_size_
;
size_t
labels_size_
;
size_t
outputs_size_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SIGMOID_CROSS_ENTROPY_WITH_LOGITS_GRAD_GPU_KERNEL_H_
tests/st/ops/gpu/test_sigmoid_cross_entropy_with_logits_grad_op.py
0 → 100644
浏览文件 @
636b8e2b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops.operations
import
_grad_ops
as
G
class
NetSigmoidCrossEntropyWithLogits
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetSigmoidCrossEntropyWithLogits
,
self
).
__init__
()
self
.
sigmoid_cross_entropy_with_logits_grad
=
G
.
SigmoidCrossEntropyWithLogitsGrad
()
def
construct
(
self
,
logits
,
labels
,
dout
):
return
self
.
sigmoid_cross_entropy_with_logits_grad
(
logits
,
labels
,
dout
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_sigmoid_cross_entropy_with_logits
():
logits
=
Tensor
(
np
.
array
([[
1
,
1
,
2
],
[
1
,
2
,
1
],
[
2
,
1
,
1
]]).
astype
(
np
.
float32
))
labels
=
Tensor
(
np
.
array
([[
0
,
0
,
1
],
[
0
,
1
,
0
],
[
1
,
0
,
0
]]).
astype
(
np
.
float32
))
dout
=
Tensor
(
np
.
ones
(
shape
=
[
3
,
3
]).
astype
(
np
.
float32
))
expect
=
np
.
array
([[
0.731059
,
0.731059
,
-
0.119203
],
[
0.731059
,
-
0.119203
,
0.731059
],
[
-
0.119203
,
0.731059
,
0.731059
]]).
astype
(
np
.
float32
)
error
=
np
.
ones
(
shape
=
[
3
,
3
])
*
1.0e-6
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
'GPU'
)
sigmoid_cross_entropy_with_logits
=
NetSigmoidCrossEntropyWithLogits
()
output
=
sigmoid_cross_entropy_with_logits
(
logits
,
labels
,
dout
)
diff
=
output
.
asnumpy
()
-
expect
assert
np
.
all
(
abs
(
diff
)
<
error
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
'GPU'
)
sigmoid_cross_entropy_with_logits
=
NetSigmoidCrossEntropyWithLogits
()
output
=
sigmoid_cross_entropy_with_logits
(
logits
,
labels
,
dout
)
diff
=
output
.
asnumpy
()
-
expect
assert
np
.
all
(
abs
(
diff
)
<
error
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录