Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
67a0cc3b
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
67a0cc3b
编写于
5月 07, 2020
作者:
W
wilfChen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
gpu queue support unary
上级
afe04847
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
106 addition
and
1 deletion
+106
-1
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu
+47
-0
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh
+4
-0
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc
+4
-0
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
+13
-1
tests/st/ops/gpu/test_sqrt_op.py
tests/st/ops/gpu/test_sqrt_op.py
+38
-0
未找到文件。
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu
浏览文件 @
67a0cc3b
...
...
@@ -60,6 +60,34 @@ __global__ void SquareKernel(T *input, T *output, size_t count) {
return
;
}
template
<
typename
T
>
__global__
void
SqrtKernel
(
T
*
input
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
sqrt
(
input
[
i
]);
}
return
;
}
template
<
>
__global__
void
SqrtKernel
(
half
*
input
,
half
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
hsqrt
(
input
[
i
]);
}
return
;
}
template
<
typename
T
>
__global__
void
RsqrtKernel
(
T
*
input
,
T
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
rsqrt
(
input
[
i
]);
}
return
;
}
template
<
>
__global__
void
RsqrtKernel
(
half
*
input
,
half
*
output
,
size_t
count
)
{
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
i
]
=
hrsqrt
(
input
[
i
]);
}
return
;
}
template
<
typename
T
>
__global__
void
ZeroslikeKernel
(
T
*
output
,
size_t
count
)
{
T
zero
=
0.0
;
for
(
size_t
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
(
count
);
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
...
...
@@ -93,6 +121,21 @@ void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
return
;
}
template
<
typename
T
>
void
Pow
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
PowKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
return
;
}
template
<
typename
T
>
void
Sqrt
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
SqrtKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
return
;
}
template
<
typename
T
>
void
Rsqrt
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
RsqrtKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
input
,
output
,
count
);
return
;
}
template
<
typename
T
>
void
Zeroslike
(
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
)
{
ZeroslikeKernel
<<<
GET_BLOCKS
(
count
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
output
,
count
);
return
;
...
...
@@ -103,10 +146,14 @@ template void Logarithm<float>(float *input, float *output, size_t count, cudaSt
template
void
Negative
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Reciprocal
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Square
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Sqrt
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Rsqrt
<
float
>(
float
*
input
,
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Zeroslike
<
float
>(
float
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Exponential
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Logarithm
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Negative
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Reciprocal
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Square
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Sqrt
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Rsqrt
<
half
>(
half
*
input
,
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
void
Zeroslike
<
half
>(
half
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh
浏览文件 @
67a0cc3b
...
...
@@ -29,6 +29,10 @@ void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream);
template
<
typename
T
>
void
Square
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
void
Sqrt
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
void
Rsqrt
(
T
*
input
,
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
template
<
typename
T
>
void
Zeroslike
(
T
*
output
,
size_t
count
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc
浏览文件 @
67a0cc3b
...
...
@@ -42,5 +42,9 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO
UnaryOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Square
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
UnaryOpGpuKernel
,
half
)
MS_REG_GPU_KERNEL_ONE
(
Sqrt
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
UnaryOpGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Rsqrt
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
UnaryOpGpuKernel
,
float
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
浏览文件 @
67a0cc3b
...
...
@@ -34,6 +34,8 @@ enum UnaryOptype {
UNARY_OP_RECIPROCAL
,
UNARY_OP_ZEROSLIKE
,
UNARY_OP_SQUARE
,
UNARY_OP_SQRT
,
UNARY_OP_RSQRT
,
UNARY_OP_INVALID_TYPE
=
255
};
static
const
std
::
map
<
std
::
string
,
UnaryOptype
>
kUnaryOpTypeMap
=
{{
"Exp"
,
UNARY_OP_EXP
},
...
...
@@ -41,7 +43,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
{
"Neg"
,
UNARY_OP_NEG
},
{
"Reciprocal"
,
UNARY_OP_RECIPROCAL
},
{
"ZerosLike"
,
UNARY_OP_ZEROSLIKE
},
{
"Square"
,
UNARY_OP_SQUARE
}};
{
"Square"
,
UNARY_OP_SQUARE
},
{
"Sqrt"
,
UNARY_OP_SQRT
},
{
"Rsqrt"
,
UNARY_OP_RSQRT
}};
template
<
typename
T
>
class
UnaryOpGpuKernel
:
public
GpuKernel
{
public:
...
...
@@ -80,6 +84,14 @@ class UnaryOpGpuKernel : public GpuKernel {
Square
(
input_addr
,
output_addr
,
inputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
}
case
UNARY_OP_SQRT
:
{
Sqrt
(
input_addr
,
output_addr
,
inputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
}
case
UNARY_OP_RSQRT
:
{
Rsqrt
(
input_addr
,
output_addr
,
inputs
[
0
]
->
size
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
break
;
}
case
UNARY_OP_ZEROSLIKE
:
{
Zeroslike
(
output_addr
,
output_size_
/
sizeof
(
T
),
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
...
...
tests/st/ops/gpu/test_sqrt_op.py
0 → 100644
浏览文件 @
67a0cc3b
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
pytest
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
import
numpy
as
np
import
mindspore.context
as
context
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_sqrt
():
x_np
=
np
.
random
.
rand
(
2
,
3
,
4
,
4
).
astype
(
np
.
float32
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"GPU"
)
output_ms
=
P
.
Sqrt
()(
Tensor
(
x_np
))
output_np
=
np
.
sqrt
(
x_np
)
assert
np
.
allclose
(
output_ms
.
asnumpy
(),
output_np
)
output_ms
=
P
.
Rsqrt
()(
Tensor
(
x_np
))
output_np
=
1
/
np
.
sqrt
(
x_np
)
assert
np
.
allclose
(
output_ms
.
asnumpy
(),
output_np
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录