magicwindyyd / mindspore (forked from MindSpore / mindspore)

Commit f746caf5

!316 Edit GPU ops

Merge pull request !316 from VectorSL/edit

Authored 4 years ago by mindspore-ci-bot, committed 4 years ago by Gitee.
Parent commits: 2bcff36e, 930c9101
Branches: master, r0.2, r0.3, r0.5, r0.6, r0.7
Tags: v0.7.0-beta, v0.6.0-beta, v0.5.0-beta, v0.3.1-alpha, v0.3.0-alpha, v0.2.0-alpha
No related merge requests.
Showing 9 changed files with 69 additions and 21 deletions (+69 / -21).
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu     +14  -1
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh     +2  -1
mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc           +4  -2
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc    +8  -0
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h    +17  -5
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc     +4  -0
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h     +13  -7
mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h         +1  -5
mindspore/ccsrc/vm/backend.cc                              +6  -0
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu (+14 / -1)

```diff
@@ -53,6 +53,13 @@ __global__ void ReciprocalKernel(T *input, T *output, size_t count) {
   return;
 }
+template <typename T>
+__global__ void SquareKernel(T *input, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = input[i] * input[i];
+  }
+  return;
+}
 template <typename T>
 void Exponential(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
   ExponentialKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
@@ -72,12 +79,18 @@ void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
   ReciprocalKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
   return;
 }
+template <typename T>
+void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
+  SquareKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
 template void Exponential<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Logarithm<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Negative<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
+template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
+template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
```
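The added SquareKernel follows the same grid-stride loop as the other unary kernels, and the Square wrapper launches it through the shared GET_BLOCKS/GET_THREADS helpers. A minimal standalone sketch of that launch pattern, with hypothetical GetBlocks/kThreadsPerBlock stand-ins for the real helpers:

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical stand-ins for MindSpore's GET_THREADS / GET_BLOCKS helpers.
constexpr int kThreadsPerBlock = 256;
inline int GetBlocks(size_t count) {
  return static_cast<int>((count + kThreadsPerBlock - 1) / kThreadsPerBlock);
}

// Grid-stride elementwise square, same shape as the SquareKernel added above.
template <typename T>
__global__ void SquareSketch(const T *input, T *output, size_t count) {
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < count; i += blockDim.x * gridDim.x) {
    output[i] = input[i] * input[i];
  }
}

int main() {
  const size_t n = 1 << 20;
  float *in = nullptr, *out = nullptr;
  cudaMallocManaged(&in, n * sizeof(float));
  cudaMallocManaged(&out, n * sizeof(float));
  for (size_t i = 0; i < n; ++i) in[i] = 0.5f * i;

  SquareSketch<<<GetBlocks(n), kThreadsPerBlock>>>(in, out, n);
  cudaDeviceSynchronize();

  printf("out[3] = %f (expected %f)\n", out[3], in[3] * in[3]);
  cudaFree(in);
  cudaFree(out);
  return 0;
}
```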
mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh (+2 / -1)

```diff
@@ -26,5 +26,6 @@ template <typename T>
 void Negative(T *input, T *output, size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
```
mindspore/ccsrc/kernel/gpu/gpu_kernel_factory.cc (+4 / -2)

```diff
@@ -41,8 +41,9 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel
                                     size_t attr_index) {
   if (kernel_info->GetInputNum() != iter_second->at(attr_index).first.GetInputSize()) {
     if (iter_second->at(attr_index).first.GetAllSame()) {
+      auto dtype = iter_second->at(attr_index).first.GetInputAttr(0).first;
       for (size_t attr = 1; attr < kernel_info->GetInputNum(); ++attr) {
-        (void)iter_second->at(attr_index).first.AddInputAttr(kernel_info->GetInputDeviceType(0));
+        (void)iter_second->at(attr_index).first.AddInputAttr(dtype);
       }
     } else {
       MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Input size is mismatching!";
@@ -50,8 +51,9 @@ void GpuKernelFactory::CheckIOParam(const std::string &kernel_name, const Kernel
   }
   if (kernel_info->GetOutputNum() != iter_second->at(attr_index).first.GetOutputSize()) {
     if (iter_second->at(attr_index).first.GetAllSame()) {
+      auto dtype = iter_second->at(attr_index).first.GetOutputAttr(0).first;
       for (size_t attr = 1; attr < kernel_info->GetOutputNum(); ++attr) {
-        (void)iter_second->at(attr_index).first.AddOutputAttr(kernel_info->GetOutputDeviceType(0));
+        (void)iter_second->at(attr_index).first.AddOutputAttr(dtype);
       }
     } else {
       MS_LOG(EXCEPTION) << "op[" << kernel_name << "] Output size is mismatching!";
```
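For kernels registered with a single "all same" attribute, CheckIOParam now replicates the dtype declared in the first attribute across every runtime input or output, rather than the runtime device type, which is what lets one fp16 registration cover a variadic argument list. A rough sketch of that expansion idea, using hypothetical AttrSketch/ExpandAllSame names rather than the real KernelAttr/GpuKernelFactory API:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Hypothetical, simplified stand-ins for KernelAttr / TypeId.
enum class TypeId { kFloat16, kFloat32 };

struct AttrSketch {
  bool all_same = false;
  std::vector<TypeId> input_types;  // declared input dtypes
};

// If the attr is marked "all same" but declares fewer inputs than the node
// actually has, replicate the first declared dtype until the counts match —
// the same idea as the AddInputAttr loop in CheckIOParam.
void ExpandAllSame(AttrSketch *attr, size_t runtime_input_num) {
  if (!attr->all_same || attr->input_types.empty()) return;
  const TypeId dtype = attr->input_types.front();
  while (attr->input_types.size() < runtime_input_num) {
    attr->input_types.push_back(dtype);
  }
}

int main() {
  AttrSketch attr{true, {TypeId::kFloat16}};  // registered once for an fp16 "all same" op
  ExpandAllSame(&attr, 4);                    // node arrives with 4 inputs
  std::cout << "expanded to " << attr.input_types.size() << " fp16 inputs\n";
  return 0;
}
```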
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.cc (+8 / -0)

```diff
@@ -38,5 +38,13 @@ MS_REG_GPU_KERNEL_ONE(
 MS_REG_GPU_KERNEL_ONE(
   Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   BinaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  Maximum,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  BinaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  Maximum,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  BinaryOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
```
mindspore/ccsrc/kernel/gpu/math/binary_op_gpu_kernel.h (+17 / -5)

```diff
@@ -27,12 +27,16 @@
 #include "kernel/gpu/kernel_constants.h"
 namespace mindspore {
 namespace kernel {
-enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_INVALID_TYPE = 255 };
-const std::map<std::string, BinaryOpType> kBinaryOpTypeMap = {
-  {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}};
+enum BinaryOpType { BINARY_OP_ADD = 0, BINARY_OP_SUB, BINARY_OP_MUL, BINARY_OP_DIV, BINARY_OP_MAX, BINARY_OP_INVALID_TYPE = 255 };
+static const std::map<std::string, BinaryOpType> kBinaryOpTypeMap = {
+  {"Sub", BINARY_OP_SUB}, {"Mul", BINARY_OP_MUL}, {"RealDiv", BINARY_OP_DIV}, {"Maximum", BINARY_OP_MAX}};
 template <typename T>
 class BinaryOpGpuKernel : public GpuKernel {
  public:
@@ -84,6 +88,10 @@ class BinaryOpGpuKernel : public GpuKernel {
         inputB_addr = workspace_addr;
         break;
       }
+      case BINARY_OP_MAX: {
+        inputB_addr = input_addr2;
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
       }
@@ -201,6 +209,10 @@ class BinaryOpGpuKernel : public GpuKernel {
         tensor_op_ = CUDNN_OP_TENSOR_ADD;
         break;
       }
+      case BINARY_OP_MAX: {
+        tensor_op_ = CUDNN_OP_TENSOR_MAX;
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Binary operation " << binary_op_type_ << " is not supported.";
       }
```
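The new BINARY_OP_MAX case routes the Maximum op through the same cuDNN OpTensor path as Add/Sub/Mul, just with CUDNN_OP_TENSOR_MAX as the descriptor op. A minimal, self-contained sketch of that cuDNN call on two equal-shaped float vectors (error checking omitted; this is illustrative, not the kernel's actual descriptor setup):

```cpp
#include <cstdio>
#include <cuda_runtime.h>
#include <cudnn.h>

// Elementwise max of two N-element float vectors via cudnnOpTensor(CUDNN_OP_TENSOR_MAX):
// C = max(alpha1 * A, alpha2 * B) + beta * C, with alpha1 = alpha2 = 1 and beta = 0.
int main() {
  const int n = 8;
  float ha[n], hb[n];
  for (int i = 0; i < n; ++i) { ha[i] = i; hb[i] = n - i; }

  float *da, *db, *dc;
  cudaMalloc(&da, n * sizeof(float));
  cudaMalloc(&db, n * sizeof(float));
  cudaMalloc(&dc, n * sizeof(float));
  cudaMemcpy(da, ha, n * sizeof(float), cudaMemcpyHostToDevice);
  cudaMemcpy(db, hb, n * sizeof(float), cudaMemcpyHostToDevice);

  cudnnHandle_t handle;
  cudnnCreate(&handle);

  cudnnTensorDescriptor_t desc;
  cudnnCreateTensorDescriptor(&desc);
  // Describe the data as a 1 x 1 x 1 x n NCHW tensor; all three operands share it.
  cudnnSetTensor4dDescriptor(desc, CUDNN_TENSOR_NCHW, CUDNN_DATA_FLOAT, 1, 1, 1, n);

  cudnnOpTensorDescriptor_t op_desc;
  cudnnCreateOpTensorDescriptor(&op_desc);
  cudnnSetOpTensorDescriptor(op_desc, CUDNN_OP_TENSOR_MAX, CUDNN_DATA_FLOAT, CUDNN_PROPAGATE_NAN);

  const float alpha = 1.0f, beta = 0.0f;
  cudnnOpTensor(handle, op_desc, &alpha, desc, da, &alpha, desc, db, &beta, desc, dc);

  float hc[n];
  cudaMemcpy(hc, dc, n * sizeof(float), cudaMemcpyDeviceToHost);
  printf("max(%g, %g) = %g\n", ha[2], hb[2], hc[2]);

  cudnnDestroyOpTensorDescriptor(op_desc);
  cudnnDestroyTensorDescriptor(desc);
  cudnnDestroy(handle);
  cudaFree(da); cudaFree(db); cudaFree(dc);
  return 0;
}
```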
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc (+4 / -0)

```diff
@@ -38,5 +38,9 @@ MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).A
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(ZerosLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+                      UnaryOpGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
```
mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h (+13 / -7)

```diff
@@ -33,13 +33,15 @@ enum UnaryOptype {
   UNARY_OP_NEG,
   UNARY_OP_RECIPROCAL,
   UNARY_OP_ZEROSLIKE,
+  UNARY_OP_SQUARE,
   UNARY_OP_INVALID_TYPE = 255 };
-const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
-                                                            {"Log", UNARY_OP_LOG},
-                                                            {"Neg", UNARY_OP_NEG},
-                                                            {"Reciprocal", UNARY_OP_RECIPROCAL},
-                                                            {"ZerosLike", UNARY_OP_ZEROSLIKE}};
+static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
+                                                                   {"Log", UNARY_OP_LOG},
+                                                                   {"Neg", UNARY_OP_NEG},
+                                                                   {"Reciprocal", UNARY_OP_RECIPROCAL},
+                                                                   {"ZerosLike", UNARY_OP_ZEROSLIKE},
+                                                                   {"Square", UNARY_OP_SQUARE}};
 template <typename T>
 class UnaryOpGpuKernel : public GpuKernel {
  public:
@@ -74,6 +76,10 @@ class UnaryOpGpuKernel : public GpuKernel {
         Reciprocal(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_SQUARE: {
+        Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       case UNARY_OP_ZEROSLIKE: {
         return true;
       }
@@ -93,12 +99,12 @@ class UnaryOpGpuKernel : public GpuKernel {
     }
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
     if (input_num != 1) {
-      MS_LOG(ERROR) << "Input number is " << input_num << ", but negative op needs 1 inputs.";
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but unary op needs 1 inputs.";
       return false;
     }
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
     if (output_num != 1) {
-      MS_LOG(ERROR) << "Output number is " << output_num << ", but negative op needs 1 output.";
+      MS_LOG(ERROR) << "Output number is " << output_num << ", but unary op needs 1 output.";
       return false;
     }
     auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
```
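UnaryOpGpuKernel resolves the primitive name through kUnaryOpTypeMap and then switches on the enum in Launch, so adding Square only needs the new enum value, map entry, and case. A stripped-down, host-side sketch of that name-to-enum dispatch pattern (hypothetical names, not the real kernel class):

```cpp
#include <cmath>
#include <iostream>
#include <map>
#include <string>

// Simplified mirror of the op-name -> enum lookup used by UnaryOpGpuKernel.
enum UnaryOptypeSketch { UNARY_EXP, UNARY_LOG, UNARY_NEG, UNARY_RECIPROCAL, UNARY_SQUARE, UNARY_INVALID = 255 };

static const std::map<std::string, UnaryOptypeSketch> kOpMap = {
  {"Exp", UNARY_EXP}, {"Log", UNARY_LOG}, {"Neg", UNARY_NEG},
  {"Reciprocal", UNARY_RECIPROCAL}, {"Square", UNARY_SQUARE}};

// Dispatch on the resolved enum, as Launch() does with the CUDA wrappers.
double Apply(const std::string &op_name, double x) {
  auto iter = kOpMap.find(op_name);
  UnaryOptypeSketch op = (iter == kOpMap.end()) ? UNARY_INVALID : iter->second;
  switch (op) {
    case UNARY_EXP:        return std::exp(x);
    case UNARY_LOG:        return std::log(x);
    case UNARY_NEG:        return -x;
    case UNARY_RECIPROCAL: return 1.0 / x;
    case UNARY_SQUARE:     return x * x;
    default:
      std::cerr << "Unary operation " << op_name << " is not supported.\n";
      return 0.0;
  }
}

int main() {
  std::cout << Apply("Square", 3.0) << "\n";  // prints 9
  return 0;
}
```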
mindspore/ccsrc/kernel/gpu/nn/flatten_gpu_kernel.h (+1 / -5)

```diff
@@ -48,14 +48,10 @@ class FlattenGpuFwdKernel : public GpuKernel {
   }
   bool Init(const CNodePtr &kernel_node) override {
     auto shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    input_size_ = sizeof(T);
     for (size_t i = 0; i < shape.size(); ++i) {
-      if (input_size_ == 0) {
-        input_size_ = 1;
-      }
       input_size_ *= shape[i];
     }
-    input_size_ = input_size_ * sizeof(T);
     InitSizeLists();
     return true;
   }
```
mindspore/ccsrc/vm/backend.cc (+6 / -0, file mode changed 100755 → 100644)

```diff
@@ -189,6 +189,12 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args) {
     } else if (utils::isa<PyObjectRef>(arg)) {
       auto value = utils::cast<PyObjectRef>(arg).object_;
       inputs.push_back(py::cast<tensor::TensorPtr>(value));
+    } else if (utils::isa<VectorRefPtr>(arg)) {
+      auto args_new = utils::cast<VectorRef>(arg);
+      (void)std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs),
+                           [](const BaseRef &v) { return utils::cast<tensor::TensorPtr>(v); });
     } else {
       MS_LOG(WARNING) << "Invalid input type.";
     }
   }
```
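The added branch flattens a nested VectorRef argument into the inputs list with std::transform and std::back_inserter. The same pattern in isolation, with plain ints standing in for BaseRef and tensor::TensorPtr:

```cpp
#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

int main() {
  // Stand-ins: args_new plays the role of the nested VectorRef, inputs the tensor list.
  std::vector<int> args_new = {1, 2, 3};
  std::vector<int> inputs = {0};

  // Append a transformed copy of every element, as the new backend.cc branch does
  // with utils::cast<tensor::TensorPtr> inside the lambda.
  std::transform(args_new.begin(), args_new.end(), std::back_inserter(inputs),
                 [](const int &v) { return v * 10; });

  for (int x : inputs) std::cout << x << ' ';  // prints: 0 10 20 30
  std::cout << '\n';
  return 0;
}
```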