Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2b4febb4
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2b4febb4
编写于
8月 19, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 19, 2020
浏览文件
操作
浏览文件
下载
差异文件
!4436 Refactor uniform ops in GPU context
Merge pull request !4436 from peixu_ren/custom_gpu
上级
58523a41
5dd49333
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
83 addition
and
43 deletion
+83
-43
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
...c/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
+51
-14
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
.../backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
+6
-3
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
.../backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
+6
-4
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
...c/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
+20
-22
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
浏览文件 @
2b4febb4
...
@@ -19,19 +19,26 @@ template <typename T>
...
@@ -19,19 +19,26 @@ template <typename T>
__global__
void
NormalKernel
(
int
seed
,
curandState
*
globalState
,
T
*
output
,
size_t
count
)
{
__global__
void
NormalKernel
(
int
seed
,
curandState
*
globalState
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
curand_init
(
seed
,
i
,
0
,
&
globalState
[
i
]);
curand_init
(
seed
,
i
,
0
,
&
globalState
[
i
]);
output
[
i
]
=
curand_normal
(
&
globalState
[
i
]);
output
[
i
]
=
(
T
)
curand_normal
(
&
globalState
[
i
]);
}
}
return
;
return
;
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
UniformKernel
(
int
seed
,
curandState
*
globalState
,
T
*
input1
,
size_t
input_size_1
,
__global__
void
Uniform
Int
Kernel
(
int
seed
,
curandState
*
globalState
,
T
*
input1
,
size_t
input_size_1
,
T
*
input2
,
size_t
input_size_2
,
T
*
output
,
size_t
count
)
{
T
*
input2
,
size_t
input_size_2
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
input1
[
i
]
=
(
input_size_1
==
1
?
input1
[
0
]
:
input1
[
i
]);
input2
[
i
]
=
(
input_size_2
==
1
?
input2
[
0
]
:
input2
[
i
]);
curand_init
(
seed
,
i
,
0
,
&
globalState
[
i
]);
curand_init
(
seed
,
i
,
0
,
&
globalState
[
i
]);
output
[
i
]
=
curand_uniform
(
&
globalState
[
i
])
*
(
input2
[
i
]
-
input1
[
i
])
+
input1
[
i
];
output
[
i
]
=
(
T
)(
curand_uniform
(
&
globalState
[
i
]))
*
(
input2
[
0
]
-
input1
[
0
])
+
input1
[
0
];
}
return
;
}
template
<
typename
T
>
__global__
void
UniformRealKernel
(
int
seed
,
curandState
*
globalState
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
curand_init
(
seed
,
i
,
0
,
&
globalState
[
i
]);
output
[
i
]
=
(
T
)
curand_uniform
(
&
globalState
[
i
]);
}
}
return
;
return
;
}
}
...
@@ -51,16 +58,46 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
...
@@ -51,16 +58,46 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
}
}
template
<
typename
T
>
template
<
typename
T
>
void
UniformReal
(
int
seed
,
curandState
*
globalState
,
T
*
input1
,
size_t
input_size_1
,
void
UniformInt
(
int
seed
,
int
seed2
,
curandState
*
globalState
,
T
*
input1
,
size_t
input_size_1
,
T
*
input2
,
size_t
input_size_2
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
T
*
input2
,
size_t
input_size_2
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
seed
=
(
seed
==
0
?
time
(
NULL
)
:
seed
);
int
RNG_seed
=
0
;
UniformKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
if
(
seed2
!=
0
)
{
(
seed
,
globalState
,
input1
,
input_size_1
,
input2
,
input_size_2
,
output
,
count
);
RNG_seed
=
seed2
;
}
else
if
(
seed
!=
0
)
{
RNG_seed
=
seed
;
}
else
{
RNG_seed
=
time
(
NULL
);
}
UniformIntKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
RNG_seed
,
globalState
,
input1
,
input_size_1
,
input2
,
input_size_2
,
output
,
count
);
return
;
}
template
<
typename
T
>
void
UniformReal
(
int
seed
,
int
seed2
,
curandState
*
globalState
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
int
RNG_seed
=
0
;
if
(
seed2
!=
0
)
{
RNG_seed
=
seed2
;
}
else
if
(
seed
!=
0
)
{
RNG_seed
=
seed
;
}
else
{
RNG_seed
=
time
(
NULL
);
}
UniformRealKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
RNG_seed
,
globalState
,
output
,
count
);
return
;
return
;
}
}
template
void
StandardNormal
<
float
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
template
void
StandardNormal
<
float
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
UniformReal
<
float
>(
int
seed
,
curandState
*
globalState
,
float
*
input1
,
size_t
input_size_1
,
template
void
StandardNormal
<
int
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
float
*
input2
,
size_t
input_size_2
,
float
*
output
,
size_t
count
,
int
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
cudaStream_t
cuda_stream
);
template
void
UniformInt
<
float
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
float
*
input1
,
size_t
input_size_1
,
float
*
input2
,
size_t
input_size_2
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
UniformInt
<
int
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
int
*
input1
,
size_t
input_size_1
,
int
*
input2
,
size_t
input_size_2
,
int
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
UniformReal
<
float
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
UniformReal
<
int
>(
int
seed
,
int
seed2
,
curandState
*
globalState
,
int
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
浏览文件 @
2b4febb4
...
@@ -24,7 +24,10 @@ template <typename T>
...
@@ -24,7 +24,10 @@ template <typename T>
void
StandardNormal
(
int
seed
,
int
seed2
,
curandState
*
globalState
,
void
StandardNormal
(
int
seed
,
int
seed2
,
curandState
*
globalState
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
template
<
typename
T
>
void
UniformReal
(
int
seed
,
curandState
*
globalState
,
void
UniformInt
(
int
seed
,
int
seed2
,
curandState
*
globalState
,
T
*
input1
,
size_t
input_size_1
,
T
*
input2
,
size_t
input_size_2
,
T
*
input1
,
size_t
input_size_1
,
T
*
input2
,
size_t
input_size_2
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
void
UniformReal
(
int
seed
,
int
seed2
,
curandState
*
globalState
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
浏览文件 @
2b4febb4
...
@@ -20,12 +20,14 @@ namespace mindspore {
...
@@ -20,12 +20,14 @@ namespace mindspore {
namespace
kernel
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
StandardNormal
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeFloat32
),
MS_REG_GPU_KERNEL_ONE
(
StandardNormal
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeFloat32
),
RandomOpGpuKernel
,
float
)
RandomOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Uniform
Real
,
MS_REG_GPU_KERNEL_ONE
(
Uniform
Int
,
KernelAttr
()
KernelAttr
()
.
AddInputAttr
(
kNumberTypeInt32
)
.
AddInputAttr
(
kNumberTypeInt32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeInt32
)
.
AddInputAttr
(
kNumberTypeFloat32
)
.
AddInputAttr
(
kNumberTypeInt32
)
.
AddOutputAttr
(
kNumberTypeFloat32
),
.
AddOutputAttr
(
kNumberTypeInt32
),
RandomOpGpuKernel
,
int
)
MS_REG_GPU_KERNEL_ONE
(
UniformReal
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeFloat32
),
RandomOpGpuKernel
,
float
)
RandomOpGpuKernel
,
float
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
浏览文件 @
2b4febb4
...
@@ -28,16 +28,17 @@
...
@@ -28,16 +28,17 @@
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
enum
RandomOptype
{
RANDOM_OP_NORMAL
=
0
,
RANDOM_OP_UNIFORM_REAL
,
RANDOM_OP_INVALID_TYPE
=
255
};
enum
RandomOptype
{
RANDOM_OP_NORMAL
=
0
,
RANDOM_OP_UNIFORM_INT
,
RANDOM_OP_UNIFORM_REAL
,
RANDOM_OP_INVALID_TYPE
=
255
};
const
std
::
map
<
std
::
string
,
RandomOptype
>
kRandomOpTypeMap
=
{
{
"StandardNormal"
,
RANDOM_OP_NORMAL
},
{
"UniformInt"
,
RANDOM_OP_UNIFORM_INT
},
{
"UniformReal"
,
RANDOM_OP_UNIFORM_REAL
}};
const
std
::
map
<
std
::
string
,
RandomOptype
>
kRandomOpTypeMap
=
{{
"StandardNormal"
,
RANDOM_OP_NORMAL
},
{
"UniformReal"
,
RANDOM_OP_UNIFORM_REAL
}};
template
<
typename
T
>
template
<
typename
T
>
class
RandomOpGpuKernel
:
public
GpuKernel
{
class
RandomOpGpuKernel
:
public
GpuKernel
{
public:
public:
RandomOpGpuKernel
()
RandomOpGpuKernel
()
:
random_op_type_
(
RANDOM_OP_INVALID_TYPE
),
:
random_op_type_
(
RANDOM_OP_INVALID_TYPE
),
input_size_0_
(
sizeof
(
int
)),
input_size_0_
(
sizeof
(
0
)),
input_size_1_
(
sizeof
(
T
)),
input_size_1_
(
sizeof
(
T
)),
input_size_2_
(
sizeof
(
T
)),
input_size_2_
(
sizeof
(
T
)),
output_size_
(
sizeof
(
T
)),
output_size_
(
sizeof
(
T
)),
...
@@ -62,11 +63,16 @@ class RandomOpGpuKernel : public GpuKernel {
...
@@ -62,11 +63,16 @@ class RandomOpGpuKernel : public GpuKernel {
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
break
;
}
}
case
RANDOM_OP_UNIFORM_
REAL
:
{
case
RANDOM_OP_UNIFORM_
INT
:
{
T
*
input_addr_1
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
T
*
input_addr_1
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
T
*
input_addr_2
=
GetDeviceAddress
<
T
>
(
inputs
,
2
);
T
*
input_addr_2
=
GetDeviceAddress
<
T
>
(
inputs
,
2
);
UniformReal
(
seed_
,
devStates
,
input_addr_1
,
inputs
[
1
]
->
size
/
sizeof
(
T
),
input_addr_2
,
UniformInt
(
seed_
,
seed2_
,
devStates
,
input_addr_1
,
inputs
[
1
]
->
size
/
sizeof
(
T
),
input_addr_2
,
inputs
[
2
]
->
size
/
sizeof
(
T
),
output_addr
,
outputs
[
0
]
->
size
/
sizeof
(
T
),
inputs
[
2
]
->
size
/
sizeof
(
T
),
output_addr
,
outputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
}
case
RANDOM_OP_UNIFORM_REAL
:
{
UniformReal
(
seed_
,
seed2_
,
devStates
,
output_addr
,
outputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
break
;
}
}
...
@@ -86,11 +92,11 @@ class RandomOpGpuKernel : public GpuKernel {
...
@@ -86,11 +92,11 @@ class RandomOpGpuKernel : public GpuKernel {
random_op_type_
=
iter
->
second
;
random_op_type_
=
iter
->
second
;
}
}
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
random_op_type_
==
RANDOM_OP_NORMAL
&&
input_num
!=
1
)
{
if
(
(
random_op_type_
==
RANDOM_OP_NORMAL
||
random_op_type_
==
RANDOM_OP_UNIFORM_REAL
)
&&
input_num
!=
1
)
{
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but random op needs 1 input."
;
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but random op needs 1 input."
;
return
false
;
return
false
;
}
}
if
(
random_op_type_
==
RANDOM_OP_UNIFORM_
REAL
&&
input_num
!=
3
)
{
if
(
random_op_type_
==
RANDOM_OP_UNIFORM_
INT
&&
input_num
!=
3
)
{
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but random op needs 3 inputs."
;
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but random op needs 3 inputs."
;
return
false
;
return
false
;
}
}
...
@@ -104,15 +110,9 @@ class RandomOpGpuKernel : public GpuKernel {
...
@@ -104,15 +110,9 @@ class RandomOpGpuKernel : public GpuKernel {
input_size_0_
+=
input_shape_0
[
i
];
input_size_0_
+=
input_shape_0
[
i
];
}
}
input_size_0_
*=
sizeof
(
int
);
input_size_0_
*=
sizeof
(
int
);
if
(
random_op_type_
==
RANDOM_OP_UNIFORM_REAL
)
{
if
(
random_op_type_
==
RANDOM_OP_UNIFORM_INT
)
{
auto
input_shape_1
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
1
);
input_size_1_
*=
1
;
for
(
size_t
i
=
0
;
i
<
input_shape_1
.
size
();
i
++
)
{
input_size_2_
*=
1
;
input_size_1_
*=
input_shape_1
[
i
];
}
auto
input_shape_2
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
2
);
for
(
size_t
i
=
0
;
i
<
input_shape_2
.
size
();
i
++
)
{
input_size_2_
*=
input_shape_2
[
i
];
}
}
}
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
0
);
auto
output_shape
=
AnfAlgo
::
GetOutputInferShape
(
kernel_node
,
0
);
for
(
size_t
i
=
0
;
i
<
output_shape
.
size
();
i
++
)
{
for
(
size_t
i
=
0
;
i
<
output_shape
.
size
();
i
++
)
{
...
@@ -120,9 +120,7 @@ class RandomOpGpuKernel : public GpuKernel {
...
@@ -120,9 +120,7 @@ class RandomOpGpuKernel : public GpuKernel {
workspace_size_
*=
output_shape
[
i
];
workspace_size_
*=
output_shape
[
i
];
}
}
seed_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"seed"
));
seed_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"seed"
));
if
(
random_op_type_
==
RANDOM_OP_NORMAL
)
{
seed2_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"seed2"
));
seed2_
=
GetValue
<
int
>
(
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"seed2"
));
}
InitSizeLists
();
InitSizeLists
();
return
true
;
return
true
;
}
}
...
@@ -130,7 +128,7 @@ class RandomOpGpuKernel : public GpuKernel {
...
@@ -130,7 +128,7 @@ class RandomOpGpuKernel : public GpuKernel {
protected:
protected:
void
InitSizeLists
()
override
{
void
InitSizeLists
()
override
{
input_size_list_
.
push_back
(
input_size_0_
);
input_size_list_
.
push_back
(
input_size_0_
);
if
(
random_op_type_
==
RANDOM_OP_UNIFORM_
REAL
)
{
if
(
random_op_type_
==
RANDOM_OP_UNIFORM_
INT
)
{
input_size_list_
.
push_back
(
input_size_1_
);
input_size_list_
.
push_back
(
input_size_1_
);
input_size_list_
.
push_back
(
input_size_2_
);
input_size_list_
.
push_back
(
input_size_2_
);
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录