Commit 97d21ba0

!502 Gpu Support Gelu & GeluGrad

Merge pull request !502 from chenweifeng/gelu

Authored on May 07, 2020 by mindspore-ci-bot; committed by Gitee on May 07, 2020.
Parents: a97f30ba, a304304c
Showing 8 changed files with 441 additions and 0 deletions (+441 −0).
mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu    +65  -0
mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh   +27  -0
mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc    +29  -0
mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h     +75  -0
mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc         +24  -0
mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h          +72  -0
tests/st/ops/gpu/test_gelu_grad_op.py                +61  -0
tests/st/ops/gpu/test_gelu_op.py                     +88  -0
mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cu  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
#include "device/gpu/cuda_common.h"
template <typename T>
__global__ void GeluKernel(size_t size, T *input_addr, T *output_addr) {
  // formula:
  // gelu(x) = 0.5 * x * (1.0 + tanh(y))
  // tanh(y) = 2 / (1 + exp(-2y)) - 1
  // y = sqrt(2/pi) * (x + 0.044715 * x^3)
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    float x = input_addr[pos];
    float tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
    output_addr[pos] = 0.5 * x * (1.0 + tanh_res);
  }
}

template <typename T>
void Gelu(size_t size, T *input_addr, T *output_addr, cudaStream_t cuda_stream) {
  GeluKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input_addr, output_addr);
  return;
}
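For reference: 0.7978845608 in GeluKernel is sqrt(2/pi), and 0.044715 is the standard cubic coefficient of the tanh-based GELU approximation. A minimal NumPy sketch (not part of this commit) checking that approximation against the exact erf-based definition gelu(x) = x * Phi(x):

import math
import numpy as np

x = np.linspace(-5.0, 5.0, 101)
# tanh approximation, same constants as GeluKernel above
approx = 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))
# exact form: Phi is the standard normal CDF, expressed via erf
exact = np.array([0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))) for v in x])
print(np.abs(approx - exact).max())  # the gap stays small across this range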
template <typename T>
__global__ void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr) {
  // formula:
  // dx = dy * y'
  // y' = 0.5 * (1 + tanh(tanh_para)) +
  //      0.5 * x * (1 - tanh(tanh_para) * tanh(tanh_para)) * mul_right
  // tanh_para = sqrt(2/pi) * (x + 0.044715 * x^3)
  // mul_right = sqrt(2/pi) * (1 + 3 * 0.044715 * x^2)
  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
    T x = x_addr[pos];
    T tanh_res = tanh(0.7978845608 * (x + 0.044715 * x * x * x));
    T mul_right = 0.7978845608 + 0.1070322244 * x * x;
    T y_res = 0.5 * (1 + tanh_res) + 0.5 * x * (1 - tanh_res * tanh_res) * mul_right;
    dx_addr[pos] = dy_addr[pos] * y_res;
  }
}

template <typename T>
void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream) {
  GeluGradKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, dy_addr, x_addr, dx_addr);
}
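The folded constant 0.1070322244 in mul_right is 3 * 0.044715 * sqrt(2/pi). As a sanity check, a small NumPy sketch (an illustration, not part of the commit) comparing this analytic derivative against a central finite difference:

import numpy as np

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

def gelu_grad(x):
    # same algebra as GeluGradKernel, in float64
    tanh_res = np.tanh(0.7978845608 * (x + 0.044715 * x ** 3))
    mul_right = 0.7978845608 + 0.1070322244 * x * x
    return 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res ** 2) * mul_right

x = np.linspace(-3.0, 3.0, 61)
eps = 1e-5
numeric = (gelu(x + eps) - gelu(x - eps)) / (2.0 * eps)
assert np.allclose(gelu_grad(x), numeric, atol=1e-4)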
template void Gelu(size_t size, float *input_addr, float *output_addr, cudaStream_t cuda_stream);
template void GeluGradKernel(size_t size, float *dy_addr, float *x_addr, float *dx_addr,
                             cudaStream_t cuda_stream);
mindspore/ccsrc/kernel/gpu/cuda_impl/gelu_impl.cuh  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void Gelu(size_t input_size, T *input_addr, T *output_addr, cudaStream_t cuda_stream);

template <typename T>
void GeluGradKernel(size_t size, T *dy_addr, T *x_addr, T *dx_addr, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_GELU_H_
mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.cc  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/gelu_grad_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(GeluGrad,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      GeLUGpuGradKernel, float)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/gelu_grad_kernel.h  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class GeLUGpuGradKernel : public GpuKernel {
 public:
  GeLUGpuGradKernel() : input_size_(0) {}
  ~GeLUGpuGradKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    T *dy_addr = GetDeviceAddress<T>(inputs, 0);
    T *x_addr = GetDeviceAddress<T>(inputs, 1);
    T *dx_addr = GetDeviceAddress<T>(outputs, 0);

    GeluGradKernel(input_size_ / sizeof(T), dy_addr, x_addr, dx_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    input_size_ = sizeof(T);
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (auto dim : input_shape) {
      input_size_ *= dim;
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(input_size_);
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(input_size_);
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  size_t input_size_;
};
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GRAD_KERNEL_H_
mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.cc  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/gelu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Gelu, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
                      GeluGpuKernel, float)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/nn/gelu_kernel.h  (new file, mode 100644)
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/kernel_constants.h"
#include "kernel/gpu/cuda_impl/gelu_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class GeluGpuKernel : public GpuKernel {
 public:
  GeluGpuKernel() : input_size_(0) {}
  ~GeluGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
    T *input_addr = GetDeviceAddress<T>(inputs, 0);
    T *output_addr = GetDeviceAddress<T>(outputs, 0);

    Gelu(input_size_ / sizeof(T), input_addr, output_addr, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    InitResource();
    input_size_ = sizeof(T);
    auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (auto dim : input_shape) {
      input_size_ *= dim;
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(input_size_);
    output_size_list_.push_back(input_size_);
  }

 private:
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
  size_t input_size_;
};
}  // namespace kernel
}  // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_GELU_GPU_KERNEL_H_
tests/st/ops/gpu/test_gelu_grad_op.py  (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import mindspore.context as context
from mindspore.ops import composite as C

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class GeluNet(nn.Cell):
    def __init__(self):
        super(GeluNet, self).__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        return self.gelu(x)


class Grad(nn.Cell):
    def __init__(self, network):
        super(Grad, self).__init__()
        self.grad = C.GradOperation(name="get_all", get_all=True, sens_param=True)
        self.network = network

    def construct(self, input_data, sens):
        gout = self.grad(self.network)(input_data, sens)
        return gout


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelugrad():
    x_ms = Tensor(np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736,
                            0.9607501, 0.14567593, 0.12261796, 0.37054458, 0.46421242]).astype(np.float32))
    dy_ms = Tensor(np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393,
                             0.03287048, 0.55681044, 0.966908, 0.06015943, 0.6099489]).astype(np.float32))

    net = GeluNet()
    grad = Grad(net)

    output = grad(x_ms, dy_ms)
    print(output)
    expect = [0.50963277, 0.9414753, 0.2667653, 0.21358444, 0.25243032,
              0.0352667, 0.34266686, 0.57757664, 0.04707306, 0.51536125]
    assert np.allclose(output[0].asnumpy(), expect)
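The expected values are the elementwise product dy * gelu'(x). A hedged NumPy sketch (mine, not part of the commit) reproducing them with the same derivative formula as gelu_impl.cu:

import numpy as np

def gelu_grad(x):
    # mirrors GeluGradKernel: 0.1070322244 = 3 * 0.044715 * sqrt(2/pi)
    tanh_res = np.tanh(0.7978845608 * (x + 0.044715 * x ** 3))
    mul_right = 0.7978845608 + 0.1070322244 * x * x
    return 0.5 * (1.0 + tanh_res) + 0.5 * x * (1.0 - tanh_res ** 2) * mul_right

x = np.array([0.58401114, 0.68800163, 0.9760397, 0.14702141, 0.46563736,
              0.9607501, 0.14567593, 0.12261796, 0.37054458, 0.46421242])
dy = np.array([0.5559598, 0.96994054, 0.24770357, 0.34646875, 0.2984393,
               0.03287048, 0.55681044, 0.966908, 0.06015943, 0.6099489])
print(dy * gelu_grad(x))  # should match `expect` above to within float32 rounding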
tests/st/ops/gpu/test_gelu_op.py  (new file, mode 100644)
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
import mindspore.nn as nn
import numpy as np
import mindspore.context as context

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class GeluNet(nn.Cell):
    def __init__(self):
        super(GeluNet, self).__init__()
        self.gelu = P.Gelu()

    def construct(self, x):
        return self.gelu(x)


def GeluCompute(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x * x * x)))


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_1d():
    x_np = np.random.random((50,)).astype(np.float32)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_2d():
    x_np = np.random.random((50, 40)).astype(np.float32)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_4d():
    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32)
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_gelu_neg():
    x_np = np.random.random((32, 3, 224, 224)).astype(np.float32) * -1
    y_np = GeluCompute(x_np)

    x_ms = Tensor(x_np)
    net = GeluNet()
    y_ms = net(x_ms)

    assert np.allclose(y_np, y_ms.asnumpy())