magicwindyyd / mindspore (forked from MindSpore / mindspore)
Commit e7a99397
Authored Jul 27, 2020 by peixu_ren

Add random uniform real op at GPU end

Parent 16079e63
Showing 5 changed files with 120 additions and 6 deletions (+120, -6)
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu    +24  -0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh    +4  -0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc    +7  -0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h    +42  -6
tests/st/ops/gpu/test_uniform_real.py                                      +43  -0
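Taken together, the change threads a new UniformReal operator through the GPU backend: a CUDA kernel and launcher (random_op_impl.cu/.cuh), the kernel registration (random_op_gpu_kernel.cc), the kernel class plumbing (random_op_gpu_kernel.h), and a system test. As a quick orientation, here is a minimal usage sketch in the spirit of the new test; the class name, shape, and seed are illustrative, and a MindSpore GPU build containing this commit is assumed.

# Minimal usage sketch (illustrative), mirroring tests/st/ops/gpu/test_uniform_real.py.
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class UniformNet(nn.Cell):
    """Wraps P.UniformReal so it can be called in graph mode."""
    def __init__(self, shape, seed=0):
        super(UniformNet, self).__init__()
        self.uniformreal = P.UniformReal(seed=seed)
        self.shape = shape

    def construct(self, a, b):
        # shape is a constant tuple; a and b are the lower/upper bounds of the interval.
        return self.uniformreal(self.shape, a, b)


net = UniformNet((2, 3), seed=1)  # hypothetical shape and seed
samples = net(Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32))
print(samples.shape)  # (2, 3)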
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
@@ -24,6 +24,18 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size_t count) {
   return;
 }
 
+template <typename T>
+__global__ void UniformKernel(int seed, curandState *globalState, T *input1, size_t input_size_1, T *input2,
+                              size_t input_size_2, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    input1[i] = (input_size_1 == 1 ? input1[0] : input1[i]);
+    input2[i] = (input_size_2 == 1 ? input2[0] : input2[i]);
+    curand_init(seed, i, 0, &globalState[i]);
+    output[i] = curand_uniform(&globalState[i]) * (input2[i] - input1[i]) + input1[i];
+  }
+  return;
+}
+
 template <typename T>
 void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
   int RNG_seed = 0;
@@ -38,5 +50,17 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
   return;
 }
 
+template <typename T>
+void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1, T *input2, size_t input_size_2,
+                 T *output, size_t count, cudaStream_t cuda_stream) {
+  seed = (seed == 0 ? time(NULL) : seed);
+  UniformKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(seed, globalState, input1, input_size_1, input2,
+                                                                    input_size_2, output, count);
+  return;
+}
+
 template void StandardNormal<float>(int seed, int seed2, curandState *globalState, float *output, size_t count,
                                     cudaStream_t cuda_stream);
+template void UniformReal<float>(int seed, curandState *globalState, float *input1, size_t input_size_1,
+                                 float *input2, size_t input_size_2, float *output, size_t count,
+                                 cudaStream_t cuda_stream);
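A note on the kernel above: for each output element it draws u = curand_uniform(...) in (0, 1] and maps it onto the requested interval with u * (input2 - input1) + input1, reusing element 0 of a bound whenever that bound was passed as a single value (input_size == 1). A NumPy sketch of the same per-element logic follows; the names are illustrative, and np.random samples [0, 1), so the open/closed endpoint differs slightly from curand_uniform.

# Illustrative NumPy reference for UniformKernel's per-element computation.
import numpy as np


def uniform_real_reference(input1, input2, count, seed=0):
    """Scale-and-shift uniform samples into the interval [a, b), broadcasting scalar bounds."""
    rng = np.random.default_rng(seed)
    out = np.empty(count, dtype=np.float32)
    for i in range(count):
        a = input1[0] if input1.size == 1 else input1[i]  # input_size_1 == 1 -> scalar lower bound
        b = input2[0] if input2.size == 1 else input2[i]  # input_size_2 == 1 -> scalar upper bound
        u = rng.random()  # stand-in for curand_uniform(&globalState[i])
        out[i] = u * (b - a) + a
    return out


low = np.array([0.0], dtype=np.float32)
high = np.array([1.0], dtype=np.float32)
print(uniform_real_reference(low, high, count=6))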
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
@@ -23,4 +23,8 @@
 template <typename T>
 void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1, T *input2, size_t input_size_2,
+                 T *output, size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
@@ -20,5 +20,12 @@ namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
                       RandomOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(UniformReal,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      RandomOpGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
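The registration just added fixes the operator's GPU signature: input 0 is the target shape (int32), inputs 1 and 2 are the lower and upper bounds (float32), and the single output is float32. A small sketch of what that implies for callers follows; the bound values and seed are hypothetical, and as in the new test the bounds must be float32 tensors.

# Sketch of the dtypes implied by the KernelAttr registration above (not from this commit).
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P

uniformreal = P.UniformReal(seed=1)   # hypothetical seed
low = Tensor(-1.0, mstype.float32)    # input 1: lower bound, float32
high = Tensor(2.0, mstype.float32)    # input 2: upper bound, float32
# uniformreal((3, 4), low, high) would then yield a float32 tensor of shape (3, 4).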
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
@@ -28,17 +28,22 @@
 namespace mindspore {
 namespace kernel {
-enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 };
+enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };
 
-const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}};
+const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
+                                                              {"UniformReal", RANDOM_OP_UNIFORM_REAL}};
 
 template <typename T>
 class RandomOpGpuKernel : public GpuKernel {
  public:
   RandomOpGpuKernel()
       : random_op_type_(RANDOM_OP_INVALID_TYPE),
-        input_size_0_(0),
+        input_size_0_(sizeof(int)),
+        input_size_1_(sizeof(T)),
+        input_size_2_(sizeof(T)),
         output_size_(sizeof(T)),
-        workspace_size_(sizeof(curandState)) {}
+        workspace_size_(sizeof(curandState)),
+        seed_(0),
+        seed2_(0) {}
   ~RandomOpGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -57,12 +62,21 @@ class RandomOpGpuKernel : public GpuKernel {
                        reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case RANDOM_OP_UNIFORM_REAL: {
+        T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
+        T *input_addr_2 = GetDeviceAddress<T>(inputs, 2);
+        UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
+                    inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported.";
       }
     }
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
     auto iter = kRandomOpTypeMap.find(kernel_name);
@@ -72,10 +86,14 @@ class RandomOpGpuKernel : public GpuKernel {
       random_op_type_ = iter->second;
     }
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
+    if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input.";
       return false;
     }
+    if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs.";
+      return false;
+    }
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
     if (output_num != 1) {
       MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output.";
@@ -86,13 +104,25 @@ class RandomOpGpuKernel : public GpuKernel {
       input_size_0_ += input_shape_0[i];
     }
     input_size_0_ *= sizeof(int);
+    if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
+      auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+      for (size_t i = 0; i < input_shape_1.size(); i++) {
+        input_size_1_ *= input_shape_1[i];
+      }
+      auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+      for (size_t i = 0; i < input_shape_2.size(); i++) {
+        input_size_2_ *= input_shape_2[i];
+      }
+    }
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < output_shape.size(); i++) {
       output_size_ *= output_shape[i];
       workspace_size_ *= output_shape[i];
     }
     seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
-    seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
+    if (random_op_type_ == RANDOM_OP_NORMAL) {
+      seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
+    }
     InitSizeLists();
     return true;
   }
@@ -100,6 +130,10 @@ class RandomOpGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_0_);
+    if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
+      input_size_list_.push_back(input_size_1_);
+      input_size_list_.push_back(input_size_2_);
+    }
     output_size_list_.push_back(output_size_);
     workspace_size_list_.push_back(workspace_size_);
   }
@@ -107,6 +141,8 @@ class RandomOpGpuKernel : public GpuKernel {
  private:
   RandomOptype random_op_type_;
   size_t input_size_0_;
+  size_t input_size_1_;
+  size_t input_size_2_;
   size_t output_size_;
   size_t workspace_size_;
   int seed_;
tests/st/ops/gpu/test_uniform_real.py
new file mode 100644
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.uniformreal = P.UniformReal(seed=seed)
        self.shape = shape

    def construct(self, a, b):
        return self.uniformreal(self.shape, a, b)


def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    a = 0.0
    b = 1.0
    net = Net(shape, seed)
    ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
    output = net(ta, tb)
    print(output.asnumpy())
    assert output.shape == (3, 2, 4)
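A natural follow-up check, not part of this commit, is that every sample actually lands inside the requested interval. Appended to this test file, it might look like the sketch below, which reuses Net, Tensor, and mstype defined above.

def test_net_1D_range():
    # Illustrative extra check (not in the commit): samples from UniformReal(a, b) stay within [a, b].
    seed = 10
    shape = (3, 2, 4)
    a, b = 0.0, 1.0
    net = Net(shape, seed)
    ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
    output = net(ta, tb).asnumpy()
    assert output.shape == (3, 2, 4)
    assert (output >= a).all() and (output <= b).all()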