Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
946dcfa0
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
946dcfa0
编写于
7月 17, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 17, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3142 add GPU operator: abs and floor
Merge pull request !3142 from caojian05/ms_master_dev
上级
6dd99ee3
f3f9fc95
变更
4
显示空白变更内容
内联
并排
Showing
4 changed file
with
68 addition
and
1 deletion
+68
-1
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
...rc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
+43
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
...c/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
+4
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
...c/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
+8
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
...rc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
+13
-1
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cu
浏览文件 @
946dcfa0
...
@@ -103,6 +103,35 @@ __global__ void ZeroslikeKernel(T *output, size_t count) {
...
@@ -103,6 +103,35 @@ __global__ void ZeroslikeKernel(T *output, size_t count) {
return
;
return
;
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
AbsKernel
(
T
*
input
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
abs
(
input
[
i
]);
}
return
;
}
template
<
>
__global__
void
AbsKernel
(
half
*
input
,
half
*
output
,
size_t
count
)
{
half
zero
=
0.0
;
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
input
[
i
]
<
zero
?
-
input
[
i
]
:
input
[
i
];
}
return
;
}
template
<
typename
T
>
__global__
void
FloorKernel
(
T
*
input
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
floor
(
input
[
i
]);
}
return
;
}
template
<
>
__global__
void
FloorKernel
(
half
*
input
,
half
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
hfloor
(
input
[
i
]);
}
return
;
}
template
<
typename
T
>
void
Exponential
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
void
Exponential
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
ExponentialKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
ExponentialKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
return
;
return
;
...
@@ -147,6 +176,16 @@ void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) {
...
@@ -147,6 +176,16 @@ void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) {
ZeroslikeKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
output
,
count
);
ZeroslikeKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
output
,
count
);
return
;
return
;
}
}
template
<
typename
T
>
void
Abs
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
AbsKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
return
;
}
template
<
typename
T
>
void
Floor
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
FloorKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
return
;
}
template
void
Exponential
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Exponential
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Logarithm
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Logarithm
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
...
@@ -156,6 +195,8 @@ template void Square<float>(float *input, float *output, size_t count, cudaStrea
...
@@ -156,6 +195,8 @@ template void Square<float>(float *input, float *output, size_t count, cudaStrea
template
void
Sqrt
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Sqrt
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Rsqrt
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Rsqrt
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Zeroslike
<
float
>(
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Zeroslike
<
float
>(
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Abs
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Floor
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Exponential
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Exponential
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Logarithm
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Logarithm
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Negative
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Negative
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
...
@@ -164,3 +205,5 @@ template void Square<half>(half *input, half *output, size_t count, cudaStream_t
...
@@ -164,3 +205,5 @@ template void Square<half>(half *input, half *output, size_t count, cudaStream_t
template
void
Sqrt
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Sqrt
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Rsqrt
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Rsqrt
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Zeroslike
<
half
>(
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Zeroslike
<
half
>(
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Abs
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Floor
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/unary_op_impl.cuh
浏览文件 @
946dcfa0
...
@@ -34,5 +34,9 @@ template <typename T>
...
@@ -34,5 +34,9 @@ template <typename T>
void
Rsqrt
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
void
Rsqrt
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
template
<
typename
T
>
void
Zeroslike
(
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
void
Zeroslike
(
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
void
Abs
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
void
Floor
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.cc
浏览文件 @
946dcfa0
...
@@ -46,5 +46,13 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
...
@@ -46,5 +46,13 @@ MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOut
UnaryOpGpuKernel
,
float
)
UnaryOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Rsqrt
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
MS_REG_GPU_KERNEL_ONE
(
Rsqrt
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
UnaryOpGpuKernel
,
float
)
UnaryOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Abs
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
UnaryOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Abs
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
UnaryOpGpuKernel
,
half
)
MS_REG_GPU_KERNEL_ONE
(
Floor
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
UnaryOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Floor
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
UnaryOpGpuKernel
,
half
)
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/math/unary_op_gpu_kernel.h
浏览文件 @
946dcfa0
...
@@ -36,6 +36,8 @@ enum UnaryOptype {
...
@@ -36,6 +36,8 @@ enum UnaryOptype {
UNARY_OP_SQUARE
,
UNARY_OP_SQUARE
,
UNARY_OP_SQRT
,
UNARY_OP_SQRT
,
UNARY_OP_RSQRT
,
UNARY_OP_RSQRT
,
UNARY_OP_ABS
,
UNARY_OP_FLOOR
,
UNARY_OP_INVALID_TYPE
=
255
UNARY_OP_INVALID_TYPE
=
255
};
};
static
const
std
::
map
<
std
::
string
,
UnaryOptype
>
kUnaryOpTypeMap
=
{{
"Exp"
,
UNARY_OP_EXP
},
static
const
std
::
map
<
std
::
string
,
UnaryOptype
>
kUnaryOpTypeMap
=
{{
"Exp"
,
UNARY_OP_EXP
},
...
@@ -45,7 +47,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
...
@@ -45,7 +47,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
{
"ZerosLike"
,
UNARY_OP_ZEROSLIKE
},
{
"ZerosLike"
,
UNARY_OP_ZEROSLIKE
},
{
"Square"
,
UNARY_OP_SQUARE
},
{
"Square"
,
UNARY_OP_SQUARE
},
{
"Sqrt"
,
UNARY_OP_SQRT
},
{
"Sqrt"
,
UNARY_OP_SQRT
},
{
"Rsqrt"
,
UNARY_OP_RSQRT
}};
{
"Rsqrt"
,
UNARY_OP_RSQRT
},
{
"Abs"
,
UNARY_OP_ABS
},
{
"Floor"
,
UNARY_OP_FLOOR
}};
template
<
typename
T
>
template
<
typename
T
>
class
UnaryOpGpuKernel
:
public
GpuKernel
{
class
UnaryOpGpuKernel
:
public
GpuKernel
{
public:
public:
...
@@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel {
...
@@ -100,6 +104,14 @@ class UnaryOpGpuKernel : public GpuKernel {
Zeroslike
(
output_addr
,
output_size_
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
Zeroslike
(
output_addr
,
output_size_
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
return
true
;
}
}
case
UNARY_OP_ABS
:
{
Abs
(
input_addr
,
output_addr
,
inputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
}
case
UNARY_OP_FLOOR
:
{
Floor
(
input_addr
,
output_addr
,
inputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
}
default:
{
default:
{
MS_LOG
(
EXCEPTION
)
<<
"Unary operation "
<<
unary_op_type_
<<
" is not supported."
;
MS_LOG
(
EXCEPTION
)
<<
"Unary operation "
<<
unary_op_type_
<<
" is not supported."
;
}
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录