Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
87be3865
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
87be3865
编写于
4月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!314 GPU add kernel assign
Merge pull request !314 from VectorSL/assign
上级
a51f01f2
9e372073
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
176 addition
and
0 deletion
+176
-0
mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc
+33
-0
mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h
+93
-0
tests/st/ops/gpu/test_assign_op.py
tests/st/ops/gpu/test_assign_op.py
+50
-0
未找到文件。
mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.cc
0 → 100644
浏览文件 @
87be3865
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/other/assign_gpu_kernel.h"
namespace
mindspore
{
namespace
kernel
{
MS_REG_GPU_KERNEL_ONE
(
Assign
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat32
).
AddInputAttr
(
kNumberTypeFloat32
).
AddOutputAttr
(
kNumberTypeFloat32
),
AssignGpuKernel
,
float
)
MS_REG_GPU_KERNEL_ONE
(
Assign
,
KernelAttr
().
AddInputAttr
(
kNumberTypeFloat16
).
AddInputAttr
(
kNumberTypeFloat16
).
AddOutputAttr
(
kNumberTypeFloat16
),
AssignGpuKernel
,
half
)
MS_REG_GPU_KERNEL_ONE
(
Assign
,
KernelAttr
().
AddInputAttr
(
kNumberTypeInt32
).
AddInputAttr
(
kNumberTypeInt32
).
AddOutputAttr
(
kNumberTypeInt32
),
AssignGpuKernel
,
int
)
}
// namespace kernel
}
// namespace mindspore
mindspore/ccsrc/kernel/gpu/other/assign_gpu_kernel.h
0 → 100644
浏览文件 @
87be3865
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H
#define MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace
mindspore
{
namespace
kernel
{
template
<
typename
T
>
class
AssignGpuKernel
:
public
GpuKernel
{
public:
AssignGpuKernel
()
:
input_size_
(
0
)
{}
~
AssignGpuKernel
()
override
=
default
;
const
std
::
vector
<
size_t
>
&
GetInputSizeList
()
const
override
{
return
input_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetOutputSizeList
()
const
override
{
return
output_size_list_
;
}
const
std
::
vector
<
size_t
>
&
GetWorkspaceSizeList
()
const
override
{
return
workspace_size_list_
;
}
bool
Launch
(
const
std
::
vector
<
AddressPtr
>
&
inputs
,
const
std
::
vector
<
AddressPtr
>
&
,
const
std
::
vector
<
AddressPtr
>
&
outputs
,
uintptr_t
stream_ptr
)
override
{
T
*
var
=
GetDeviceAddress
<
T
>
(
inputs
,
0
);
T
*
value
=
GetDeviceAddress
<
T
>
(
inputs
,
1
);
T
*
output
=
GetDeviceAddress
<
T
>
(
outputs
,
0
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMemcpyAsync
(
var
,
value
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
"cudaMemxcpyAsync failed."
);
CHECK_CUDA_RET_WITH_EXCEPT
(
cudaMemcpyAsync
(
output
,
value
,
input_size_
,
cudaMemcpyDeviceToDevice
,
reinterpret_cast
<
cudaStream_t
>
(
stream_ptr
)),
"cudaMemxcpyAsync failed."
);
return
true
;
}
bool
Init
(
const
CNodePtr
&
kernel_node
)
override
{
if
(
!
CheckParam
(
kernel_node
))
{
return
false
;
}
auto
shape
=
AnfAlgo
::
GetPrevNodeOutputInferShape
(
kernel_node
,
0
);
input_size_
=
sizeof
(
T
);
for
(
size_t
x
:
shape
)
{
input_size_
=
input_size_
*
x
;
}
InitSizeLists
();
return
true
;
}
protected:
void
InitSizeLists
()
override
{
input_size_list_
.
push_back
(
input_size_
);
input_size_list_
.
push_back
(
input_size_
);
output_size_list_
.
push_back
(
input_size_
);
}
private:
bool
CheckParam
(
const
CNodePtr
&
kernel_node
)
{
size_t
input_num
=
AnfAlgo
::
GetInputTensorNum
(
kernel_node
);
if
(
input_num
!=
2
)
{
MS_LOG
(
ERROR
)
<<
"Input number is "
<<
input_num
<<
", but AssignGpuKernel needs 2 output."
;
return
false
;
}
size_t
output_num
=
AnfAlgo
::
GetOutputTensorNum
(
kernel_node
);
if
(
output_num
!=
1
)
{
MS_LOG
(
ERROR
)
<<
"Output number is "
<<
output_num
<<
", but AssignGpuKernel needs 1 output."
;
return
false
;
}
return
true
;
}
std
::
vector
<
size_t
>
input_size_list_
;
std
::
vector
<
size_t
>
output_size_list_
;
std
::
vector
<
size_t
>
workspace_size_list_
;
size_t
input_size_
;
};
}
// namespace kernel
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_ASSIGN_GPU_KERNEL_H
tests/st/ops/gpu/test_assign_op.py
0 → 100644
浏览文件 @
87be3865
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import
pytest
from
mindspore
import
Tensor
from
mindspore.ops
import
operations
as
P
import
mindspore.nn
as
nn
import
numpy
as
np
import
mindspore.context
as
context
class
Net
(
nn
.
Cell
):
def
__init__
(
self
):
super
(
Net
,
self
).
__init__
()
self
.
assign
=
P
.
Assign
()
def
construct
(
self
,
var
,
value
):
return
self
.
assign
(
var
,
value
)
x
=
np
.
array
([[
1.2
,
1
],
[
1
,
0
]]).
astype
(
np
.
float32
)
value
=
np
.
array
([[
1
,
2
],
[
3
,
4.0
]]).
astype
(
np
.
float32
)
@
pytest
.
mark
.
level0
@
pytest
.
mark
.
platform_x86_gpu_training
@
pytest
.
mark
.
env_onecard
def
test_assign
():
context
.
set_context
(
mode
=
context
.
GRAPH_MODE
,
device_target
=
"GPU"
)
assign
=
Net
()
var
=
Tensor
(
x
)
output
=
assign
(
var
,
Tensor
(
value
))
error
=
np
.
ones
(
shape
=
[
2
,
2
])
*
1.0e-6
diff1
=
output
.
asnumpy
()
-
value
diff2
=
var
.
asnumpy
()
-
value
assert
np
.
all
(
diff1
<
error
)
assert
np
.
all
(
-
diff1
<
error
)
assert
np
.
all
(
diff2
<
error
)
assert
np
.
all
(
-
diff2
<
error
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录