Repository: magicwindyyd/mindspore (fork of MindSpore / mindspore)

Commit 13d1738f
Authored on Aug 20, 2020 by mindspore-ci-bot; committed via Gitee on Aug 20, 2020

!4706 fix SmoothL1Loss gpu kernel

Merge pull request !4706 from Peilin/smoothL1Loss-fix

Parents: 9a3baf4f, 0d5220d3
8 changed files with 90 additions and 52 deletions (+90 -52)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu     +17 -17
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh    +2  -2
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h       +4  -4
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h  +4  -4
mindspore/ops/_grad/grad_nn_ops.py                                               +1  -1
mindspore/ops/operations/_grad_ops.py                                            +1  -1
mindspore/ops/operations/nn_ops.py                                               +6  -6
tests/st/ops/gpu/test_smoothl1loss_op.py                                         +55 -17
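
In short, the change renames the operator attribute `sigma` to `beta` across the GPU kernels, the Python primitives, and the tests, and corrects the element-wise formula so the quadratic branch is divided by `beta` (the usual Smooth L1 / Huber form). As a rough reference only, and not code from this commit, the forward computation after the fix can be sketched in NumPy:

    import numpy as np

    def smooth_l1_loss_reference(prediction, target, beta=1.0):
        # |x| = |prediction - target|: quadratic below beta, linear above it
        diff = np.abs(prediction - target)
        return np.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)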
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cu

@@ -18,47 +18,47 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-__global__ void SmoothL1LossKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+__global__ void SmoothL1LossKernel(const int input_size, const float beta, const T *prediction, const T *target,
                                    T *loss) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
-    T value = (prediction[i] - target[i]) > 0 ? (prediction[i] - target[i]) : (target[i] - prediction[i]);
-    if (value < sigma) {
-      loss[i] = static_cast<T>(0.5) * value * value;
+    T value = fabsf(prediction[i] - target[i]);
+    if (value < beta) {
+      loss[i] = 0.5 * value * value / beta;
     } else {
-      loss[i] = value - static_cast<T>(0.5);
+      loss[i] = value - (0.5 * beta);
     }
   }
 }

 template <typename T>
-void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
                   cudaStream_t stream) {
-  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target, loss);
+  SmoothL1LossKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target, loss);
 }

 template <typename T>
-__global__ void SmoothL1LossGradKernel(const int input_size, const float sigma, const T *prediction, const T *target,
+__global__ void SmoothL1LossGradKernel(const int input_size, const float beta, const T *prediction, const T *target,
                                        const T *dloss, T *dx) {
   for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < input_size; i += blockDim.x * gridDim.x) {
     T value = prediction[i] - target[i];
-    if (value > static_cast<T>(sigma)) {
+    if (value > beta) {
       dx[i] = dloss[i];
-    } else if (value < static_cast<T>(-sigma)) {
+    } else if (value < -beta) {
       dx[i] = -dloss[i];
     } else {
-      dx[i] = value * dloss[i];
+      dx[i] = (value / beta) * dloss[i];
     }
   }
 }

 template <typename T>
-void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
                       T *dx, cudaStream_t stream) {
-  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, sigma, prediction, target,
+  SmoothL1LossGradKernel<<<GET_BLOCKS(input_size), GET_THREADS, 0, stream>>>(input_size, beta, prediction, target,
                                                                              dloss, dx);
 }

-template void SmoothL1Loss(const int &input_size, const float &sigma, const float *prediction, const float *target,
-                           float *loss, cudaStream_t stream);
-template void SmoothL1LossGrad(const int &input_size, const float &sigma, const float *prediction, const float *target,
-                               const float *dloss, float *dx, cudaStream_t stream);
+template void SmoothL1Loss<float>(const int &input_size, const float &beta, const float *prediction,
+                                  const float *target, float *loss, cudaStream_t stream);
+template void SmoothL1LossGrad<float>(const int &input_size, const float &beta, const float *prediction,
+                                      const float *target, const float *dloss, float *dx, cudaStream_t stream);
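
The backward kernel's branch structure changes in the same way: `dloss` now passes through unscaled outside the `[-beta, beta]` band and is scaled by `value / beta` inside it. A minimal NumPy sketch of that logic, for checking the math rather than as code from the repository:

    import numpy as np

    def smooth_l1_loss_grad_reference(prediction, target, dloss, beta=1.0):
        value = prediction - target
        # mirrors the three branches of SmoothL1LossGradKernel above
        return np.where(value > beta, dloss,
                        np.where(value < -beta, -dloss, (value / beta) * dloss))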
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/smooth_l1_loss_impl.cuh

@@ -17,9 +17,9 @@
 #ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
 #define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
 template <typename T>
-void SmoothL1Loss(const int &input_size, const float &sigma, const T *prediction, const T *target, T *loss,
+void SmoothL1Loss(const int &input_size, const float &beta, const T *prediction, const T *target, T *loss,
                   cudaStream_t stream);
 template <typename T>
-void SmoothL1LossGrad(const int &input_size, const float &sigma, const T *prediction, const T *target, const T *dloss,
+void SmoothL1LossGrad(const int &input_size, const float &beta, const T *prediction, const T *target, const T *dloss,
                       T *dx, cudaStream_t stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SMOOTH_L1_LOSS_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_gpu_kernel.h

@@ -26,7 +26,7 @@ namespace kernel {
 template <typename T>
 class SmoothL1LossGpuKernel : public GpuKernel {
  public:
-  SmoothL1LossGpuKernel() : input_size_(1), sigma_(1.0) {}
+  SmoothL1LossGpuKernel() : input_size_(1), beta_(1.0) {}
   ~SmoothL1LossGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -39,7 +39,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
     T *target = GetDeviceAddress<T>(inputs, 1);
     T *loss = GetDeviceAddress<T>(outputs, 0);

-    SmoothL1Loss(input_size_, sigma_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
+    SmoothL1Loss(input_size_, beta_, prediction, target, loss, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -49,7 +49,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }

-    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    beta_ = GetAttr<float>(kernel_node, "beta");
     InitSizeLists();
     return true;
   }
@@ -63,7 +63,7 @@ class SmoothL1LossGpuKernel : public GpuKernel {
  private:
   size_t input_size_;
-  float sigma_;
+  float beta_;

   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
mindspore/ccsrc/backend/kernel_compiler/gpu/nn/smooth_l1_loss_grad_gpu_kernel.h

@@ -26,7 +26,7 @@ namespace kernel {
 template <typename T>
 class SmoothL1LossGradGpuKernel : public GpuKernel {
  public:
-  SmoothL1LossGradGpuKernel() : input_size_(1), sigma_(1.0) {}
+  SmoothL1LossGradGpuKernel() : input_size_(1), beta_(1.0) {}
   ~SmoothL1LossGradGpuKernel() override = default;

   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -40,7 +40,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
     T *dloss = GetDeviceAddress<T>(inputs, 2);
     T *dx = GetDeviceAddress<T>(outputs, 0);

-    SmoothL1LossGrad(input_size_, sigma_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
+    SmoothL1LossGrad(input_size_, beta_, prediction, target, dloss, dx, reinterpret_cast<cudaStream_t>(stream_ptr));
     return true;
   }
@@ -50,7 +50,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
       input_size_ *= input_shape[i];
     }

-    sigma_ = GetAttr<float>(kernel_node, "sigma");
+    beta_ = GetAttr<float>(kernel_node, "beta");
     InitSizeLists();
     return true;
   }
@@ -64,7 +64,7 @@ class SmoothL1LossGradGpuKernel : public GpuKernel {
  private:
   size_t input_size_;
-  float sigma_;
+  float beta_;

   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
mindspore/ops/_grad/grad_nn_ops.py

@@ -713,7 +713,7 @@ def get_bprop_top_kv2(self):
 @bprop_getters.register(P.SmoothL1Loss)
 def get_bprop_smooth_l1_loss(self):
     """Grad definition for `SmoothL1Loss` operation."""
-    grad = G.SmoothL1LossGrad(self.sigma)
+    grad = G.SmoothL1LossGrad(self.beta)

     def bprop(prediction, target, out, dout):
         dx = grad(prediction, target, dout)
mindspore/ops/operations/_grad_ops.py

@@ -1274,7 +1274,7 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
     """Computes gradient for prediction on SmoothL1Loss."""

     @prim_attr_register
-    def __init__(self, sigma=1.0):
+    def __init__(self, beta=1.0):
         pass

     def infer_shape(self, prediction, target, dloss):
mindspore/ops/operations/nn_ops.py

@@ -1725,11 +1725,11 @@ class SmoothL1Loss(PrimitiveWithInfer):
     Sets input prediction as `X`, input target as `Y`, output as `loss`. Then,

     .. math::
-        \text{SmoothL1Loss} = \begin{cases}0.5x^{2}, &if \left |x \right |\leq \text{sigma} \cr
-        \left |x \right|-0.5, &\text{otherwise}\end{cases}
+        \text{SmoothL1Loss} = \begin{cases}\frac{0.5 x^{2}}{\text{beta}}, &if \left |x \right | < \text{beta} \cr
+        \left |x \right|-0.5 \text{beta}, &\text{otherwise}\end{cases}

     Args:
-        sigma (float): A parameter used to control the point where the function will change from
+        beta (float): A parameter used to control the point where the function will change from
             quadratic to linear. Default: 1.0.

     Inputs:
@@ -1748,9 +1748,9 @@ class SmoothL1Loss(PrimitiveWithInfer):
     """

     @prim_attr_register
-    def __init__(self, sigma=1.0):
-        validator.check_value_type('sigma', sigma, [float], self.name)
-        validator.check('sigma', sigma, '', 0, Rel.GT, self.name)
+    def __init__(self, beta=1.0):
+        validator.check_value_type('beta', beta, [float], self.name)
+        validator.check('beta', beta, '', 0, Rel.GT, self.name)
         self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])

     def infer_shape(self, prediction, target):
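
With this change the primitive is constructed with `beta` instead of `sigma`. A minimal usage sketch, assuming a MindSpore build with GPU support and mirroring the updated test below:

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, context

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = nn.SmoothL1Loss(1.0 / 9)   # beta replaces the old sigma argument
    prediction = Tensor(np.random.randn(20).astype(np.float32))
    target = Tensor(np.random.randn(20).astype(np.float32))
    loss = net(prediction, target)   # element-wise loss, same shape as the inputs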
tests/st/ops/gpu/test_smoothl1loss_op.py

@@ -21,25 +21,39 @@ import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.ops import composite as C

-context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+
+def smoothl1loss(beta):
+    np.random.seed(42)
+    prediction = np.random.randn(20).astype(np.float32)
+    target = np.random.randn(20).astype(np.float32)
+
+    net = nn.SmoothL1Loss(beta)
+    return net(Tensor(prediction), Tensor(target))
+

 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
 def test_smoothl1loss():
-    np.random.seed(42)
-    prediction = np.random.randn(20).astype(np.float32)
-    target = np.random.randn(20).astype(np.float32)
-    sigma = 1.0
-    net = nn.SmoothL1Loss(sigma)
-    loss = net(Tensor(prediction), Tensor(target))
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+    epsilon = 1e-6
+
+    beta = 1.0
+    loss = smoothl1loss(beta)
     expect = [0.46941718, 0.00382918, 0.16829303, 2.447778, 0.04812113, 0.05953304,
               2.2302065, 0.07672881, 0.00860204, 0.34798968, 0.00956192, 1.818008,
               0.03262977, 0.36599946, 2.047463, 0.2168481, 0.7216947, 1.7739174,
               0.08826803, 1.109165]
-    assert np.allclose(loss.asnumpy(), expect)
+    diff = np.absolute(loss.asnumpy() - np.array(expect))
+    assert (diff < epsilon).all()
+
+    beta = 1/9
+    loss = smoothl1loss(beta)
+    expect = [0.9133791, 0.03446258, 0.5246048, 2.8922224, 0.2546738, 0.289504,
+              2.674651, 0.33618113, 0.07560876, 0.7786982, 0.08273339, 2.2624524,
+              0.19990394, 0.8000138, 2.4919074, 0.6030006, 1.1661391, 2.2183619,
+              0.3646064, 1.5536094]
+    diff = np.absolute(loss.asnumpy() - np.array(expect))
+    assert (diff < epsilon).all()


 class Grad(nn.Cell):
@@ -53,20 +67,26 @@ class Grad(nn.Cell):
         return gout

-@pytest.mark.level0
-@pytest.mark.platform_x86_gpu_training
-@pytest.mark.env_onecard
-def test_smoothl1loss_grad():
+def smoothl1loss_grad(beta):
     np.random.seed(42)
     prediction = np.random.randn(20).astype(np.float32)
     target = np.random.randn(20).astype(np.float32)
     sens = np.random.randn(20).astype(np.float32)
-    sigma = 1.0
-    net = nn.SmoothL1Loss(sigma)
+
+    net = nn.SmoothL1Loss(beta)
     grad = Grad(net)
-    dx = grad(Tensor(prediction), Tensor(target), Tensor(sens))
+    return grad(Tensor(prediction), Tensor(target), Tensor(sens))
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_smoothl1loss_grad():
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU", save_graphs=True)
+    epsilon = 1e-6
+
+    beta = 1.0
+    dx = smoothl1loss_grad(beta)
     dx1_expect = [-0.71552587, 0.01499678, -0.06709455, -0.30110368, -0.45868093,
                   0.24838912, -0.46063876, 0.41411355, 0.04507046, -1.4708229,
                   0.04481723, 0.38508227, -0.17292616, -0.52333146, -1.0309995,
@@ -77,5 +97,23 @@ def test_smoothl1loss_grad():
                   -0.04481723, -0.38508227, 0.17292616, 0.52333146, 1.0309995,
                   -0.61330026, -0.83921754, 0.3092124, -0.1391843, 0.9755451]
-    assert np.allclose(dx[0].asnumpy(), dx1_expect)
-    assert np.allclose(dx[1].asnumpy(), dx2_expect)
+    diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
+    diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
+    assert (diff1 < epsilon).all()
+    assert (diff2 < epsilon).all()
+
+    beta = 1/9
+    dx = smoothl1loss_grad(beta)
+    dx1_expect = [-0.73846656, 0.13497104, -0.11564828, -0.30110368, -1.478522,
+                  0.7198442, -0.46063876, 1.0571222, 0.3436183, -1.7630402,
+                  0.32408398, 0.38508227, -0.676922, -0.6116763, -1.0309995,
+                  0.93128014, 0.83921754, -0.3092124, 0.33126342, -0.9755451]
+    dx2_expect = [0.73846656, -0.13497104, 0.11564828, 0.30110368, 1.478522,
+                  -0.7198442, 0.46063876, -1.0571222, -0.3436183, 1.7630402,
+                  -0.32408398, -0.38508227, 0.676922, 0.6116763, 1.0309995,
+                  -0.93128014, -0.83921754, 0.3092124, -0.33126342, 0.9755451]
+    diff1 = np.absolute(dx[0].asnumpy() - np.array(dx1_expect))
+    diff2 = np.absolute(dx[1].asnumpy() - np.array(dx2_expect))
+    assert (diff1 < epsilon).all()
+    assert (diff2 < epsilon).all()
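
The hard-coded expectations above can be sanity-checked against a NumPy reference built from the same seed and the corrected formula. This is only an illustrative sketch: it assumes the same NumPy RNG sequence, and float32 rounding may leave tiny differences from the listed values.

    import numpy as np

    np.random.seed(42)
    prediction = np.random.randn(20).astype(np.float32)
    target = np.random.randn(20).astype(np.float32)

    beta = 1.0 / 9
    diff = np.abs(prediction - target)
    # element-wise Smooth L1 loss, matching the updated GPU kernel
    expect = np.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)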