Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
cf5a27e9
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
cf5a27e9
编写于
7月 10, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 10, 2020
浏览文件
操作
浏览文件
下载
差异文件
!2962 gpu support SmoothL1Loss kernel
Merge pull request !2962 from chenweifeng/smoothl1loss
上级
03ef509e
0fdc304a
变更
7
隐藏空白更改
内联
并排
Showing
7 changed files
with
376 additions
and
0 deletions
+376
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cu
+64
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh
mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh
+25
-0
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.cc
+26
-0
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h
+75
-0
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc
...ore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc
+29
-0
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
...pore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
+76
-0
tests/st/ops/gpu/test_smoothl1loss_op.py
tests/st/ops/gpu/test_smoothl1loss_op.py
+81
-0
未找到文件。
mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cu
0 → 100644
浏览文件 @
cf5a27e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "smooth_l1_loss_impl.cuh"
#include "device/gpu/cuda_common.h"
// Element-wise SmoothL1 loss:
//   loss[i] = 0.5 * |d|^2   if |d| < sigma
//   loss[i] = |d| - 0.5     otherwise,   where d = prediction[i] - target[i].
// NOTE(review): the two pieces are only continuous at the boundary when
// sigma == 1.0 (the linear branch subtracts a fixed 0.5, not 0.5*sigma) —
// confirm this matches the framework-level contract for the "sigma" attr.
// Grid-stride loop: valid for any launch configuration, all pointers are
// device memory with input_size elements.
template <typename T>
__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target,
                                   T *loss) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
    // Load each operand once: the original expression re-read global memory
    // up to four times per element to compute the absolute difference.
    T diff = prediction[i] - target[i];
    T value = diff > 0 ? diff : -diff;  // |prediction[i] - target[i]|
    // Cast sigma to T for a like-typed comparison (SmoothL1LossGradKernel
    // already does this); identical behavior for the float instantiation.
    if (value < static_cast<T>(sigma)) {
      loss[i] = static_cast<T>(0.5) * value * value;
    } else {
      loss[i] = value - static_cast<T>(0.5);
    }
  }
}
// Host-side launcher: runs SmoothL1LossKernel asynchronously on `stream`.
// GET_BLOCKS / GET_THREADS come from device/gpu/cuda_common.h; no
// synchronization or error check is performed here (framework convention).
template <typename T>
void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
                  cudaStream_t stream) {
  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss);
}
// Element-wise gradient of SmoothL1 loss w.r.t. prediction, scaled by the
// incoming gradient dloss:
//   dx[i] =  dloss[i]       if d >  sigma
//   dx[i] = -dloss[i]       if d < -sigma
//   dx[i] =  d * dloss[i]   otherwise,   where d = prediction[i] - target[i].
// Grid-stride loop; all pointers are device memory with input_size elements.
template <typename T>
__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target,
                                       const T *dloss, T *dx) {
  // Hoist the loop-invariant threshold casts out of the loop.
  const T upper = static_cast<T>(sigma);
  const T lower = static_cast<T>(-sigma);
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
    T diff = prediction[i] - target[i];
    if (diff > upper) {
      dx[i] = dloss[i];
    } else if (diff < lower) {
      dx[i] = -dloss[i];
    } else {
      dx[i] = diff * dloss[i];
    }
  }
}
// Host-side launcher: runs SmoothL1LossGradKernel asynchronously on `stream`.
// GET_BLOCKS / GET_THREADS come from device/gpu/cuda_common.h.
template <typename T>
void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
                      T *dx, cudaStream_t stream) {
  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target,
                                                                             dloss, dx);
}
// Explicit instantiations for float32 — the only dtype registered by the
// corresponding GPU kernels (see smooth_l1_loss_gpu_kernel.cc).
template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target,
                           float *loss, cudaStream_t stream);
template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction,
                               const float *target, const float *dloss, float *dx, cudaStream_t stream);
mindspore/ccsrc/kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh
0 → 100644
浏览文件 @
cf5a27e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_

// Host-side launchers for the SmoothL1Loss forward/backward CUDA kernels.
// NOTE(review): cudaStream_t is not declared in this header — it assumes the
// including translation unit has already pulled in the CUDA runtime headers;
// confirm every includer does so.

// Launches the forward kernel asynchronously on `stream`:
//   loss[i] = 0.5*d^2 if |d| < sigma else |d| - 0.5, d = prediction[i] - target[i].
// All pointers are device buffers with input_size elements.
template <typename T>
void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
                  cudaStream_t stream);

// Launches the backward kernel asynchronously on `stream`:
//   dx[i] = dloss[i], -dloss[i], or d*dloss[i] depending on d vs +/- sigma.
// All pointers are device buffers with input_size elements.
template <typename T>
void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
                      T *dx, cudaStream_t stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.cc
0 → 100644
浏览文件 @
cf5a27e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Register the float32 SmoothL1Loss GPU kernel:
// inputs (prediction: f32, target: f32) -> output (loss: f32).
MS_REG_GPU_KERNEL_ONE(SmoothL1Loss,
                      KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(
                        kNumberTypeFloat32),
                      SmoothL1LossGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_gpu_kernel.h
0 → 100644
浏览文件 @
cf5a27e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh"
namespace
mindspore
{
namespace
kernel
{
template
<
typename
T
>
class
SmoothL1LossGpuKernel
:
public
GpuKernel
{
public:
SmoothL1LossGpuKernel
()
:
input_size_
(
1
),
sigma_
(
1.0
)
{}
~
SmoothL1LossGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
{
T
*
prediction
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
target
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
T
*
loss
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
SmoothL1Loss
(
input_size_
,
sigma_
,
prediction
,
target
,
loss
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
input_shape
.
size
();
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
sigma_
=
GetAttr
<
float
>
(
kernel_node
,
"sigma"
);
InitSizeLists
();
return
true
;
}
protected:
void
InitSizeLists
()
override
{
input_size_list_
.
push_back
(
input_size_
*
sizeof
(
T
));
input_size_list_
.
push_back
(
input_size_
*
sizeof
(
T
));
output_size_list_
.
push_back
(
input_size_
*
sizeof
(
T
));
}
private:
size_t
input_size_
;
float
sigma_
;
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GPU_KERNEL_H_
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.cc
0 → 100644
浏览文件 @
cf5a27e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h"
namespace mindspore {
namespace kernel {
// Register the float32 SmoothL1LossGrad GPU kernel:
// inputs (prediction: f32, target: f32, dloss: f32) -> output (dx: f32).
MS_REG_GPU_KERNEL_ONE(SmoothL1LossGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      SmoothL1LossGradGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h
0 → 100644
浏览文件 @
cf5a27e9
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/smooth_l1_loss_impl.cuh"
namespace mindspore {
namespace kernel {
// GPU kernel wrapper for the SmoothL1Loss backward op (float32 is registered
// in smooth_l1_loss_grad_gpu_kernel.cc).
// Inputs:  prediction, target, dloss — same shape.
// Output:  dx — gradient w.r.t. prediction, same shape.
// Attr:    "sigma" — threshold between the quadratic and linear regions.
template <typename T>
class SmoothL1LossGradGpuKernel : public GpuKernel {
 public:
  SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {}
  ~SmoothL1LossGradGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Launches the CUDA implementation on the stream supplied by the framework.
  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
    T *prediction = GetDeviceAddress<T>(inputs, 0);
    T *target = GetDeviceAddress<T>(inputs, 1);
    T *dloss = GetDeviceAddress<T>(inputs, 2);
    T *dx = GetDeviceAddress<T>(outputs, 0);
    // NOTE(review): input_size_ is size_t but SmoothL1LossGrad takes const int& —
    // tensors with more than INT_MAX elements would overflow; confirm upstream limits.
    SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx,
                     reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  // Caches the flattened element count of input 0 and the "sigma" attribute.
  bool Init(const CNodePtr &kernel_node) override {
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (const auto &dim : input_shape) {
      input_size_ *= dim;
    }
    sigma_ = GetAttr<float>(kernel_node, "sigma");
    InitSizeLists();
    return true;
  }

 protected:
  // NOTE(review): only two of the three inputs are sized here (the original
  // does the same) — confirm the framework tolerates a short input size list.
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_ * sizeof(T));
    input_size_list_.push_back(input_size_ * sizeof(T));
    output_size_list_.push_back(input_size_ * sizeof(T));
  }

 private:
  size_t input_size_;  // flattened element count of the inputs
  float sigma_;        // value of the "sigma" attribute

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_SMOOTH_L1_LOSS_GRAD_GPU_KERNEL_H_
tests/st/ops/gpu/test_smoothl1loss_op.py
0 → 100644
浏览文件 @
cf5a27e9
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
composite
as
C
# Run the tests on GPU in graph mode. `save_graphs=True` was dropped: dumping
# graph IR files is a debug aid only and pollutes the working directory on
# every CI run.
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_smoothl1loss():
    """Forward SmoothL1Loss on 20 seeded random float32 values vs. precomputed expectations."""
    np.random.seed(42)  # fixed seed: the `expect` list below depends on it
    prediction = np.random.randn(20).astype(np.float32)
    target = np.random.randn(20).astype(np.float32)

    net = nn.SmoothL1Loss(1.0)  # sigma = 1.0
    loss = net(Tensor(prediction), Tensor(target))

    expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113,
              0.05953304, 2.2302065, 0.07672881, 0.00860204, 0.34798968,
              0.00956192, 1.818008, 0.03262977, 0.36599946, 2.047463,
              0.2168481, 0.7216947, 1.7739174, 0.08826803, 1.109165]
    assert np.allclose(loss.asnumpy(), expect)
class Grad(nn.Cell):
    """Wraps `network` so construct() returns its gradients w.r.t. all inputs.

    `sens_param=True` means the caller supplies the output sensitivity
    (incoming gradient) as the last argument of construct().
    """

    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, x1, x2, sens):
        # Returns a tuple: (d/dx1, d/dx2) of network(x1, x2), scaled by sens.
        return self.grad(self.network)(x1, x2, sens)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_smoothl1loss_grad():
    """Backward SmoothL1Loss: gradients w.r.t. both inputs vs. precomputed expectations."""
    np.random.seed(42)  # fixed seed: the expected gradients below depend on it
    prediction = np.random.randn(20).astype(np.float32)
    target = np.random.randn(20).astype(np.float32)
    sens = np.random.randn(20).astype(np.float32)

    grad = Grad(nn.SmoothL1Loss(1.0))  # sigma = 1.0
    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))

    # d(loss)/d(prediction); d(loss)/d(target) is its element-wise negation.
    dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093,
                  0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229,
                  0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995,
                  0.61330026, 0.83921754, -0.3092124, 0.1391843, -0.9755451]
    dx2_expect = [-v for v in dx1_expect]
    assert np.allclose(dx[0].asnumpy(), dx1_expect)
    assert np.allclose(dx[1].asnumpy(), dx2_expect)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录