Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
32921ea3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
32921ea3
编写于
7月 18, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
7月 18, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3166 add gpu oneslike op
Merge pull request !3166 from qujianwei/gpu-oneslike
上级
4e0cfafc
fb2ac74d
变更
5
隐藏空白更改
内联
并排
Showing
5 changed file
with
256 addition
and
0 deletion
+256
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc
...backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc
+26
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h
.../backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h
+85
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu
...rc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu
+37
-0
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh
...c/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh
+23
-0
tests/st/ops/gpu/test_oneslike_op.py
tests/st/ops/gpu/test_oneslike_op.py
+85
-0
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc
0 → 100644
浏览文件 @
32921ea3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h"
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
OnesLike
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
OnesLikeGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
OnesLike
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
OnesLikeGpuKernel
,
half
)
MS_REG_GPU_KERNEL_ONE
(
OnesLike
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
OnesLikeGpuKernel
,
int
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h
0 → 100644
浏览文件 @
32921ea3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
#include <vector>
#include "backend/kernel_compiler/gpu/gpu_kernel.h"
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh"
namespace
mindspore
{
namespace
kernel
{
template
<
typename
T
>
class
OnesLikeGpuKernel
:
public
GpuKernel
{
public:
OnesLikeGpuKernel
()
:
input_size_
(
0
),
output_size_
(
0
)
{}
~
OnesLikeGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
void
*
stream_ptr
)
override
{
T
*
input
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
output
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
int
size
=
SizeToInt
(
input_size_
/
sizeof
(
T
));
CalOnesLike
(
size
,
input
,
output
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
));
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
1
)
{
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but oneslike needs 1 input."
;
return
false
;
}
size_t
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
if
(
output_num
!=
1
)
{
MS_LOG
(
ERROR
)
<<
"Output number is "
<<
output_num
<<
", but oneslike needs 1 output."
;
return
false
;
}
auto
input_shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
size_t
shape_size
=
input_shape
.
size
();
input_size_
=
1
;
for
(
size_t
i
=
0
;
i
<
shape_size
;
i
++
)
{
input_size_
*=
input_shape
[
i
];
}
input_size_
*=
sizeof
(
T
);
output_size_
=
input_size_
;
InitSizeLists
();
return
true
;
}
protected:
void
InitSizeLists
()
override
{
input_size_list_
.
push_back
(
input_size_
);
output_size_list_
.
push_back
(
output_size_
);
return
;
}
private:
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
size_t
input_size_
;
size_t
output_size_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu
0 → 100644
浏览文件 @
32921ea3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime.h>
#include "oneslike_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template
<
typename
T
>
__global__
void
OnesLike
(
const
int
size
,
const
T
*
input
,
T
*
output
)
{
int
one
=
1
;
T
val
=
static_cast
<
T
>
(
one
);
for
(
int
pos
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
pos
<
size
;
pos
+=
blockDim
.
x
*
gridDim
.
x
)
{
output
[
pos
]
=
val
;
}
return
;
}
template
<
typename
T
>
void
CalOnesLike
(
const
int
size
,
const
T
*
input
,
T
*
output
,
cudaStream_t
cuda_stream
)
{
OnesLike
<<<
GET_BLOCKS
(
size
),
GET_THREADS
,
0
,
cuda_stream
>>>
(
size
,
input
,
output
);
return
;
}
template
void
CalOnesLike
<
float
>(
const
int
size
,
const
float
*
input
,
float
*
output
,
cudaStream_t
cuda_stream
);
template
void
CalOnesLike
<
half
>(
const
int
size
,
const
half
*
input
,
half
*
output
,
cudaStream_t
cuda_stream
);
template
void
CalOnesLike
<
int
>(
const
int
size
,
const
int
*
input
,
int
*
output
,
cudaStream_t
cuda_stream
);
mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh
0 → 100644
浏览文件 @
32921ea3
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
template
<
typename
T
>
void
CalOnesLike
(
const
int
size
,
const
T
*
input
,
T
*
output
,
cudaStream_t
cuda_stream
);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_
tests/st/ops/gpu/test_oneslike_op.py
0 → 100644
浏览文件 @
32921ea3
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
numpy
as
np
import
pytest
import
mindspore.context
as
context
import
mindspore.nn
as
nn
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"GPU"
)
class
NetOnesLike
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
NetOnesLike
,
self
).
__init__
()
self
.
ones_like
=
P
.
OnesLike
()
def
construct
(
self
,
x
):
return
self
.
ones_like
(
x
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_OnesLike
():
x0_np
=
np
.
random
.
uniform
(
-
2
,
2
,
(
2
,
3
,
4
,
4
)).
astype
(
np
.
float32
)
x1_np
=
np
.
random
.
uniform
(
-
2
,
2
,
1
).
astype
(
np
.
float16
)
x2_np
=
np
.
zeros
([
3
,
3
,
3
],
dtype
=
np
.
int32
)
x0
=
Tensor
(
x0_np
)
x1
=
Tensor
(
x1_np
)
x2
=
Tensor
(
x2_np
)
context
.
set_context
(
mode
=
context
.
PYNATIVE_MODE
,
device_target
=
"GPU"
)
ones_like
=
NetOnesLike
()
output0
=
ones_like
(
x0
)
expect0
=
np
.
ones_like
(
x0_np
)
diff0
=
output0
.
asnumpy
()
-
expect0
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff0
<
error0
)
assert
output0
.
shape
==
expect0
.
shape
output1
=
ones_like
(
x1
)
expect1
=
np
.
ones_like
(
x1_np
)
diff1
=
output1
.
asnumpy
()
-
expect1
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff1
<
error1
)
assert
output1
.
shape
==
expect1
.
shape
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
ones_like
=
NetOnesLike
()
output0
=
ones_like
(
x0
)
expect0
=
np
.
ones_like
(
x0_np
)
diff0
=
output0
.
asnumpy
()
-
expect0
error0
=
np
.
ones
(
shape
=
expect0
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff0
<
error0
)
assert
output0
.
shape
==
expect0
.
shape
output1
=
ones_like
(
x1
)
expect1
=
np
.
ones_like
(
x1_np
)
diff1
=
output1
.
asnumpy
()
-
expect1
error1
=
np
.
ones
(
shape
=
expect1
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff1
<
error1
)
assert
output1
.
shape
==
expect1
.
shape
output2
=
ones_like
(
x2
)
expect2
=
np
.
ones_like
(
x2_np
)
diff2
=
output2
.
asnumpy
()
-
expect2
error2
=
np
.
ones
(
shape
=
expect2
.
shape
)
*
1.0e-5
assert
np
.
all
(
diff2
<
error2
)
assert
output2
.
shape
==
expect2
.
shape
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录