Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
f4ac6a2c
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f4ac6a2c
编写于
8月 03, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 03, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3600 Modifying SGD operator bug and adding operator check.
Merge pull request !3600 from linqingke/new_ops
上级
6e21618f
d54fe5a6
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
20 addition
and
14 deletion
+20
-14
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
...ackend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
+6
-6
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
...e/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
+3
-3
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
...ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
+4
-2
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu
...e/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu
+3
-3
mindspore/ops/operations/other_ops.py
mindspore/ops/operations/other_ops.py
+4
-0
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/scatter_nd_gpu_kernel.h
浏览文件 @
f4ac6a2c
...
...
@@ -69,7 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
memcpy_flag_
=
true
;
}
ScatterNd
(
indices
,
update
,
output
,
block_size_
,
input_size_
,
output_size_
,
indices_dim_0_
,
indices_dim_1_
,
const
size_t
input_size
=
input_size_
/
sizeof
(
T
);
const
size_t
output_size
=
output_size_
/
sizeof
(
T
);
ScatterNd
(
indices
,
update
,
output
,
block_size_
,
input_size
,
output_size
,
indices_dim_0_
,
indices_dim_1_
,
indices_stride_
,
work_shape_
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
...
...
@@ -138,7 +141,7 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
// calculate indices dim 0/1
indices_dim_0_
=
indices_shapes_
[
0
];
indices_dim_1_
=
indices_shapes_
[
1
];
indices_dim_1_
=
indices_shapes_
[
indices_shapes_
.
size
()
-
1
];
// calculate block_size
for
(
size_t
i
=
indices_dim_1_
;
i
<
output_shapes_
.
size
();
i
++
)
{
...
...
@@ -146,10 +149,7 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
}
// calculate indices_stride
for
(
size_t
i
=
0
;
i
<
indices_dim_1_
;
i
++
)
{
vec_indices_stride_
.
push_back
(
0
);
}
vec_indices_stride_
.
resize
(
indices_dim_1_
,
0
);
vec_indices_stride_
[
indices_dim_1_
-
1
]
=
block_size_
;
for
(
size_t
i
=
indices_dim_1_
-
1
;
i
>
0
;
--
i
)
{
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/iou_impl.cu
浏览文件 @
f4ac6a2c
...
...
@@ -50,12 +50,12 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
T
area1
=
(
location_coordinate
[
0
][
2
]
-
location_coordinate
[
0
][
0
]
+
1
)
*
(
location_coordinate
[
0
][
3
]
-
location_coordinate
[
0
][
1
]
+
1
);
T
area2
=
(
location_coordinate
[
1
][
2
]
-
location_coordinate
[
1
][
0
]
+
1
)
*
(
location_coordinate
[
1
][
3
]
-
location_coordinate
[
1
][
1
]
+
1
);
if
(
mode
==
0
)
{
T
area2
=
(
location_coordinate
[
1
][
2
]
-
location_coordinate
[
1
][
0
]
+
1
)
*
(
location_coordinate
[
1
][
3
]
-
location_coordinate
[
1
][
1
]
+
1
);
iou_results
[
i
]
=
overlaps
/
(
area1
+
area2
-
overlaps
+
epsilon
);
}
else
{
iou_results
[
i
]
=
overlaps
/
(
area
1
+
epsilon
);
iou_results
[
i
]
=
overlaps
/
(
area
2
+
epsilon
);
}
}
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cu
浏览文件 @
f4ac6a2c
...
...
@@ -15,7 +15,9 @@
*/
#include "backend/kernel_compiler/gpu/cuda_impl/scatter_nd.cuh"
#include "backend/kernel_compiler/gpu/cuda_impl/util.cuh"
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
,
typename
S
>
__global__
void
ScatterNdKernel
(
S
*
indices
,
T
*
update
,
T
*
output
,
const
size_t
block_size
,
const
size_t
input_size
,
const
size_t
output_size
,
const
size_t
indices_dim_0
,
const
size_t
indices_dim_1
,
...
...
@@ -39,7 +41,7 @@ __global__ void ScatterNdKernel(S *indices, T *update, T *output, const size_t b
out_bound
|=
write_index
>=
output_size
;
if
(
!
out_bound
)
{
output
[
write_index
]
=
update
[
read_index
]
;
ms_atomic_add
(
&
output
[
write_index
],
update
[
read_index
])
;
}
}
}
...
...
@@ -48,7 +50,7 @@ template <typename T, typename S>
void
ScatterNd
(
S
*
indices
,
T
*
update
,
T
*
output
,
const
size_t
&
block_size
,
const
size_t
&
input_size
,
const
size_t
&
output_size
,
const
size_t
&
indices_dim_0
,
const
size_t
&
indices_dim_1
,
S
*
indices_stride
,
S
*
work_shape
,
cudaStream_t
stream
)
{
ScatterNdKernel
<<<
GET_BLOCKS
(
in
put_size
),
GET_THREADS
,
0
,
stream
>>>
(
indices
,
update
,
output
,
block_size
,
input_size
,
ScatterNdKernel
<<<
GET_BLOCKS
(
out
put_size
),
GET_THREADS
,
0
,
stream
>>>
(
indices
,
update
,
output
,
block_size
,
input_size
,
output_size
,
indices_dim_0
,
indices_dim_1
,
indices_stride
,
work_shape
);
return
;
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/sgd_impl.cu
浏览文件 @
f4ac6a2c
...
...
@@ -22,12 +22,12 @@ __global__ void SGDKernel(const int size, const T dampening, const T weight_deca
const
T
*
momentum
,
const
T
*
lr
,
T
*
param
,
T
*
accum
,
T
*
stat
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
size
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
T
grad_new
=
grad
[
i
];
if
(
weight_decay
!=
static_cast
<
T
>
(
0
))
{
if
(
weight_decay
>
static_cast
<
T
>
(
0
))
{
grad_new
+=
param
[
i
]
*
weight_decay
;
}
if
(
momentum
[
0
]
!=
static_cast
<
T
>
(
0
))
{
if
(
stat
[
i
]
==
static_cast
<
T
>
(
0
))
{
if
(
momentum
[
0
]
>
static_cast
<
T
>
(
0
))
{
if
(
stat
[
i
]
>
static_cast
<
T
>
(
0
))
{
accum
[
i
]
=
grad_new
;
stat
[
i
]
=
0
;
}
else
{
...
...
mindspore/ops/operations/other_ops.py
浏览文件 @
f4ac6a2c
...
...
@@ -101,6 +101,8 @@ class BoundingBoxEncode(PrimitiveWithInfer):
def
infer_shape
(
self
,
anchor_box
,
groundtruth_box
):
validator
.
check
(
'anchor_box shape[0]'
,
anchor_box
[
0
],
'groundtruth_box shape[0]'
,
groundtruth_box
[
0
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"anchor_box rank"
,
len
(
anchor_box
),
""
,
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"groundtruth_box rank"
,
len
(
groundtruth_box
),
""
,
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'anchor_box shape[1]'
,
anchor_box
[
1
],
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'groundtruth_box shape[1]'
,
groundtruth_box
[
1
],
4
,
Rel
.
EQ
,
self
.
name
)
return
anchor_box
...
...
@@ -152,6 +154,8 @@ class BoundingBoxDecode(PrimitiveWithInfer):
def
infer_shape
(
self
,
anchor_box
,
deltas
):
validator
.
check
(
'anchor_box shape[0]'
,
anchor_box
[
0
],
'deltas shape[0]'
,
deltas
[
0
],
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"anchor_box rank"
,
len
(
anchor_box
),
""
,
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check
(
"deltas rank"
,
len
(
deltas
),
""
,
2
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'anchor_box shape[1]'
,
anchor_box
[
1
],
4
,
Rel
.
EQ
,
self
.
name
)
validator
.
check_integer
(
'deltas shape[1]'
,
deltas
[
1
],
4
,
Rel
.
EQ
,
self
.
name
)
return
anchor_box
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录