Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
49ba473b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
49ba473b
编写于
8月 03, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3803 add gpu klDivLoss op
Merge pull request !3803 from baihuawei/loss
上级
12102ae3
9eca5663
变更
5
隐藏空白更改
内联
并排
Showing
5 changed files
with
316 additions
and
0 deletions
+316
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc
.../backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc
+26
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
...c/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
+86
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc
...backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc
+30
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
.../backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
+88
-0
tests/st/ops/gpu/test_kl_div_op.py
tests/st/ops/gpu/test_kl_div_op.py
+86
-0
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.cc
0 → 100644
浏览文件 @
49ba473b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Register the forward KLDivLoss GPU kernel for float32:
// two inputs (x, target) and one output (loss), implemented by
// KLDivLossGpuKernel<float>.
MS_REG_GPU_KERNEL_ONE(
  KLDivLoss,
  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
  KLDivLossGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_gpu_kernel.h
0 → 100644
浏览文件 @
49ba473b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
// Forward GPU kernel for the Kullback-Leibler divergence loss.
// Consumes two equally-shaped float tensors (x and target) and writes either
// an element-wise loss (reduction == "none") or a single reduced value.
template <typename T>
class KLDivLossGpuKernel : public GpuKernel {
 public:
  // input_size_ starts at 1 so shape dimensions can be multiplied in;
  // reduction_ defaults to 1, the "mean" mode.
  KLDivLossGpuKernel() : input_size_(1), reduction_(1) {}
  ~KLDivLossGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Dispatches the CUDA implementation on the given stream. The unused middle
  // parameter is the workspace list, which this kernel does not need.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *x_addr = GetDeviceAddress<T>(inputs, 0);
    T *target_addr = GetDeviceAddress<T>(inputs, 1);
    T *loss_addr = GetDeviceAddress<T>(outputs, 0);
    KLDivLoss(input_size_, reduction_, x_addr, target_addr, loss_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  // Caches the flattened element count and decodes the "reduction" attribute
  // from the kernel node, then sizes the input/output buffers.
  bool Init(const CNodePtr &kernel_node) override {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (auto dim : input_shape) {
      input_size_ *= dim;
    }
    // Integer encoding used by the CUDA impl: 0 = none, 1 = mean (default), 2 = sum.
    string reduction = GetAttr<string>(kernel_node, "reduction");
    if (reduction == "none") {
      reduction_ = 0;
    } else if (reduction == "sum") {
      reduction_ = 2;
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    // x and target are both full-sized inputs.
    input_size_list_.push_back(input_size_ * sizeof(T));
    input_size_list_.push_back(input_size_ * sizeof(T));
    // "none" keeps the element-wise shape; mean/sum reduce to a scalar.
    output_size_list_.push_back(reduction_ == 0 ? input_size_ * sizeof(T) : sizeof(T));
  }

 private:
  size_t input_size_;
  int reduction_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_GPU_KERNEL_H
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.cc
0 → 100644
浏览文件 @
49ba473b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h"
namespace mindspore {
namespace kernel {
// Register the backward KLDivLoss GPU kernel for float32:
// three inputs (x, target, dloss) and two outputs (dx, dy), implemented by
// KLDivLossGradGpuKernel<float>.
MS_REG_GPU_KERNEL_ONE(KLDivLossGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      KLDivLossGradGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/kl_div_loss_grad_kernel.h
0 → 100644
浏览文件 @
49ba473b
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
#include <vector>
#include <string>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/loss_with_reduction_impl.cuh"
namespace mindspore {
namespace kernel {
// Backward GPU kernel for KLDivLoss: given x, target and the incoming
// gradient dloss, computes dx and dy on the provided CUDA stream.
template <typename T>
class KLDivLossGradGpuKernel : public GpuKernel {
 public:
  // input_size_ starts at 1 so shape dimensions can be multiplied in;
  // reduction_ defaults to 1, the "mean" mode.
  KLDivLossGradGpuKernel() : input_size_(1), reduction_(1) {}
  ~KLDivLossGradGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Dispatches the CUDA gradient implementation on the given stream. The
  // unused middle parameter is the workspace list, which this kernel ignores.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *input_x = GetDeviceAddress<T>(inputs, 0);
    T *input_y = GetDeviceAddress<T>(inputs, 1);
    T *dloss = GetDeviceAddress<T>(inputs, 2);
    T *dx = GetDeviceAddress<T>(outputs, 0);
    T *dy = GetDeviceAddress<T>(outputs, 1);
    KLDivLossGrad(input_size_, reduction_, input_x, input_y, dloss, dx, dy,
                  reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  // Caches the flattened element count and decodes the "reduction" attribute
  // from the kernel node, then sizes the input/output buffers.
  bool Init(const CNodePtr &kernel_node) override {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (size_t i = 0; i < input_shape.size(); i++) {
      input_size_ *= input_shape[i];
    }
    // Integer encoding used by the CUDA impl: 0 = none, 1 = mean (default), 2 = sum.
    string reduction = GetAttr<string>(kernel_node, "reduction");
    if (reduction == "none") {
      reduction_ = 0;
    } else if (reduction == "sum") {
      reduction_ = 2;
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    // x and target are both full-sized inputs.
    input_size_list_.push_back(input_size_ * sizeof(T));
    input_size_list_.push_back(input_size_ * sizeof(T));
    // dx and dy always match the input shape.
    output_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(input_size_ * sizeof(T));
    // dloss is element-wise for "none" but a scalar for the reduced modes.
    if (reduction_ == 0) {
      input_size_list_.push_back(input_size_ * sizeof(T));
    } else {
      input_size_list_.push_back(sizeof(T));
    }
  }

 private:
  size_t input_size_;
  int reduction_;
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_NN_KL_DIV_LOSS_GRAD_KERNEL_H
tests/st/ops/gpu/test_kl_div_op.py
0 → 100644
浏览文件 @
49ba473b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
composite
as
C
from
mindspore.ops
import
operations
as
P
# Run all tests in this file in graph mode on the GPU backend.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
    """Thin test wrapper around the P.KLDivLoss primitive.

    Args:
        reduction (str): reduction mode forwarded to KLDivLoss
            ("none", "mean" or "sum"). Defaults to "none", matching the
            previous hard-coded behavior.
    """

    def __init__(self, reduction="none"):
        super(Net, self).__init__()
        # Fix: the `reduction` argument was previously ignored and "none" was
        # hard-coded into the primitive; honor the caller's choice instead.
        self.KLDivLoss = P.KLDivLoss(reduction)

    def construct(self, x, y):
        # x: log-probabilities (prediction), y: target distribution.
        return self.KLDivLoss(x, y)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_binary_cross_entropy_loss():
    # NOTE(review): the name appears to be copied from a binary-cross-entropy
    # test; this actually exercises KLDivLoss with reduction="none".
    # Fixed seed makes prediction/target (and hence `expect`) reproducible.
    np.random.seed(42)
    prediction = np.random.rand(20).astype(np.float32)
    target = np.random.rand(20).astype(np.float32)
    net = Net()
    loss = net(Tensor(prediction), Tensor(target))
    # Golden element-wise loss values for seed 42.
    expect = [-0.5297444, -0.40738472, -0.5733339, -0.58720195, -0.42922008,
              -0.31237593, -0.3332863, -0.78742254, -0.6662671, -0.17546377,
              -0.31526336, -0.46702948, -0.23191005, -0.2512708, -0.20934652,
              -0.32021108, -0.45477402, -0.278453, -0.5551879, -0.48938933]
    assert np.allclose(loss.asnumpy(), expect)
class Grad(nn.Cell):
    """Wraps a network and returns the gradients w.r.t. all of its inputs.

    `sens_param=True` means the caller supplies the output sensitivity
    (incoming gradient) as the last argument of `construct`.
    """

    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, x1, x2, sens):
        # Differentiate the wrapped network and evaluate at (x1, x2) with the
        # provided sensitivity.
        return self.grad(self.network)(x1, x2, sens)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_binary_cross_entropy_loss_grad():
    # NOTE(review): the name appears to be copied from a binary-cross-entropy
    # test; this actually checks the KLDivLoss backward kernel (dx and dy).
    # Fixed seed makes all random inputs (and the golden values) reproducible.
    np.random.seed(42)
    prediction = np.random.rand(20).astype(np.float32)
    target = np.random.rand(20).astype(np.float32)
    # `sens` is the incoming gradient fed through sens_param=True.
    sens = np.random.rand(20).astype(np.float32)
    grad = Grad(Net())
    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
    # Golden gradient w.r.t. the prediction input for seed 42.
    dx1_expect = [-0.07466945, -0.06907414, -0.01004642, -0.3331403, -0.11802178,
                  -0.52019656, -0.06224053, -0.2674369, -0.32387912, -0.00858657,
                  -0.58906615, -0.13217884, -0.06111591, -0.8490888, -0.57735133,
                  -0.7452407, -0.02695603, -0.01914206, -0.03094601, -0.14319494]
    # Golden gradient w.r.t. the target input for seed 42.
    dx2_expect = [0.0163771, -0.950962, -0.03309895, -0.5481312, 0.01523498,
                  0.39894313, -0.20858267, -0.27628726, -0.06815486, -0.5134226,
                  0.46645382, -1.3477919, -2.409831, 0.65787154, 0.4682768,
                  0.55671424, -0.04362264, -0.36274382, 0.00852979, -0.03639247]
    assert np.allclose(dx[0].asnumpy(), dx1_expect)
    assert np.allclose(dx[1].asnumpy(), dx2_expect)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录