magicwindyyd / mindspore · Commit 694a8213
Forked from MindSpore / mindspore (in sync with the fork source project).
Commit 694a8213
Authored Jun 17, 2020 by lizhenyu
Parent: 65f2212f

add adam optimizer
Showing 5 changed files with 355 additions and 0 deletions (+355 −0).
- mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu (+56 −0)
- mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh (+25 −0)
- mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc (+54 −0)
- mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h (+142 −0)
- tests/st/ops/gpu/test_adam_op.py (+78 −0)
mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu (new file, mode 100644)
```cuda
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernel/gpu/cuda_impl/adam_impl.cuh"

template <typename T>
__device__ __forceinline__ T SqrtFunc(T input) {
  return sqrt(input);
}

template <>
__device__ __forceinline__ half SqrtFunc(half input) {
  return hsqrt(input);
}

template <typename T>
__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
                                const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon,
                                T *variable, T *m, T *v) {
  const T one = static_cast<T>(1.0);
  const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]);
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
    m[i] += (gradient[i] - m[i]) * (one - beta1[0]);
    v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]);
    variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]);
  }
}

template <typename T>
void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
               const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
               cudaStream_t cuda_stream) {
  ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
    size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
}

template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power,
                               const float *beta2_power, const float *learning_rate, const float *beta1,
                               const float *beta2, const float *epsilon, float *variable, float *m, float *v,
                               cudaStream_t cuda_stream);
template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power,
                              const half *beta2_power, const half *learning_rate, const half *beta1,
                              const half *beta2, const half *epsilon, half *variable, half *m, half *v,
                              cudaStream_t cuda_stream);
```
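Unrolled, `ApplyAdamKernel` is the standard bias-corrected Adam update applied element-wise, with both bias corrections folded into the step size. In the code's notation, where `beta1_power` and `beta2_power` hold the accumulated powers of the decay rates:

```latex
\hat{\alpha} = \alpha \cdot \frac{\sqrt{1 - \beta_2^{\,t}}}{1 - \beta_1^{\,t}}, \qquad
m_i \mathrel{+}= (g_i - m_i)(1 - \beta_1), \qquad
v_i \mathrel{+}= (g_i^2 - v_i)(1 - \beta_2), \qquad
\theta_i \mathrel{-}= \frac{\hat{\alpha}\, m_i}{\sqrt{v_i} + \epsilon}
```

The in-place forms are algebraically identical to the usual moment updates $m \leftarrow \beta_1 m + (1 - \beta_1) g$ and $v \leftarrow \beta_2 v + (1 - \beta_2) g^2$.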
mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh (new file, mode 100644)
```cuda
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_

#include "device/gpu/cuda_common.h"

template <typename T>
void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
               const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v,
               cudaStream_t cuda_stream);

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
```
mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc (new file, mode 100644)
```cpp
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "kernel/gpu/nn/adam_gpu_kernel.h"

namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Adam,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddInputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32)
                        .AddOutputAttr(kNumberTypeFloat32),
                      AdamGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Adam,
                      KernelAttr()
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddInputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16)
                        .AddOutputAttr(kNumberTypeFloat16),
                      AdamGpuKernel, half)
}  // namespace kernel
}  // namespace mindspore
```
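The registration above exposes the kernel as the op named `Adam`, with ten float32 or float16 inputs and three outputs. For orientation, here is a hypothetical sketch of driving that op directly through the ops layer rather than through the `nn.optim.Adam` wrapper used in the test below. It assumes the primitive is exposed as `P.Adam` with inputs in the order read back in `adam_gpu_kernel.h` (var, m, v, beta1_power, beta2_power, lr, beta1, beta2, epsilon, gradient); verify against the ops module of this revision before relying on it.

```python
# Hypothetical direct use of the registered "Adam" op. The P.Adam primitive
# and its input order are assumptions inferred from adam_gpu_kernel.h, not
# part of this commit.
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class OneAdamStep(nn.Cell):
    def __init__(self, shape):
        super(OneAdamStep, self).__init__()
        self.adam = P.Adam()  # assumed to dispatch to AdamGpuKernel on GPU
        self.var = Parameter(Tensor(np.ones(shape, np.float32)), name="var")
        self.m = Parameter(Tensor(np.zeros(shape, np.float32)), name="m")
        self.v = Parameter(Tensor(np.zeros(shape, np.float32)), name="v")

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
        # Same order as the GetDeviceAddress indices 0..9 in Launch().
        return self.adam(self.var, self.m, self.v, beta1_power, beta2_power,
                         lr, beta1, beta2, epsilon, grad)


def scalar(x):
    return Tensor(np.array(x, np.float32))


step = OneAdamStep((4, 4))
step(scalar(0.9), scalar(0.999), scalar(0.01), scalar(0.9), scalar(0.999),
     scalar(1e-8), Tensor(np.ones((4, 4), np.float32)))
```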
mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h (new file, mode 100644)
```cpp
/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/adam_impl.cuh"

namespace mindspore {
namespace kernel {
template <typename T>
class AdamGpuKernel : public GpuKernel {
 public:
  AdamGpuKernel()
      : variable_size_(0),
        m_size_(0),
        v_size_(0),
        beta1_power_size_(0),
        beta2_power_size_(0),
        learning_rate_size_(0),
        beta1_size_(0),
        beta2_size_(0),
        epsilon_size_(0),
        gradient_size_(0) {}
  ~AdamGpuKernel() override = default;

  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
              const std::vector<AddressPtr> &, void *stream_ptr) override {
    T *variable = GetDeviceAddress<T>(inputs, 0);
    T *m = GetDeviceAddress<T>(inputs, 1);
    T *v = GetDeviceAddress<T>(inputs, 2);
    T *beta1_power = GetDeviceAddress<T>(inputs, 3);
    T *beta2_power = GetDeviceAddress<T>(inputs, 4);
    T *learning_rate = GetDeviceAddress<T>(inputs, 5);
    T *beta1 = GetDeviceAddress<T>(inputs, 6);
    T *beta2 = GetDeviceAddress<T>(inputs, 7);
    T *epsilon = GetDeviceAddress<T>(inputs, 8);
    T *gradient = GetDeviceAddress<T>(inputs, 9);
    ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2,
              epsilon, variable, m, v, reinterpret_cast<cudaStream_t>(stream_ptr));
    return true;
  }

  bool Init(const CNodePtr &kernel_node) override {
    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
    if (input_num != 10) {
      MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 10 inputs.";
      return false;
    }

    variable_size_ = sizeof(T);
    m_size_ = sizeof(T);
    v_size_ = sizeof(T);
    beta1_power_size_ = sizeof(T);
    beta2_power_size_ = sizeof(T);
    learning_rate_size_ = sizeof(T);
    beta1_size_ = sizeof(T);
    beta2_size_ = sizeof(T);
    epsilon_size_ = sizeof(T);
    gradient_size_ = sizeof(T);

    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
    for (size_t i = 0; i < variable_shape.size(); i++) {
      variable_size_ *= variable_shape[i];
    }
    auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
    for (size_t i = 0; i < m_shape.size(); i++) {
      m_size_ *= m_shape[i];
    }
    auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
    for (size_t i = 0; i < v_shape.size(); i++) {
      v_size_ *= v_shape[i];
    }
    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9);
    for (size_t i = 0; i < gradient_shape.size(); i++) {
      gradient_size_ *= gradient_shape[i];
    }
    InitSizeLists();
    return true;
  }

 protected:
  void InitSizeLists() override {
    input_size_list_.push_back(variable_size_);
    input_size_list_.push_back(m_size_);
    input_size_list_.push_back(v_size_);
    input_size_list_.push_back(beta1_power_size_);
    input_size_list_.push_back(beta2_power_size_);
    input_size_list_.push_back(learning_rate_size_);
    input_size_list_.push_back(beta1_size_);
    input_size_list_.push_back(beta2_size_);
    input_size_list_.push_back(epsilon_size_);
    input_size_list_.push_back(gradient_size_);
    output_size_list_.push_back(0);
    output_size_list_.push_back(0);
    output_size_list_.push_back(0);
  }

 private:
  size_t variable_size_;
  size_t m_size_;
  size_t v_size_;
  size_t beta1_power_size_;
  size_t beta2_power_size_;
  size_t learning_rate_size_;
  size_t beta1_size_;
  size_t beta2_size_;
  size_t epsilon_size_;
  size_t gradient_size_;

  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
```
tests/st/ops/gpu/test_adam_op.py (new file, mode 100644)
```python
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")


class NetAdam(nn.Cell):
    def __init__(self):
        super(NetAdam, self).__init__()
        self.batch_size = 1
        self.reshape = P.Reshape()
        weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
        self.fc1 = Dense(16, 10, weight_init=weight)

    def construct(self, input_x):
        output = self.reshape(input_x, (self.batch_size, -1))
        output = self.fc1(output)
        return output


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam():
    epoch = 3
    net = NetAdam()
    optimizer = Adam(filter(lambda x: x.requires_grad, net.get_parameters()), learning_rate=0.01)
    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
    net_with_criterion = WithLossCell(net, criterion)
    train_network = TrainOneStepCell(net_with_criterion, optimizer)
    train_network.set_train()

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    losses1 = []
    for _ in range(epoch):
        data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
        label = Tensor(np.array([0]).astype(np.int32))
        loss = train_network(data, label)
        losses1.append(loss.asnumpy())
    assert losses1[0] > losses1[1]
    assert losses1[1] > losses1[2]

    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
    losses2 = []
    for _ in range(epoch):
        data = Tensor(np.arange(0, 16).reshape(1, 1, 4, 4).astype(np.float32) * 0.01)
        label = Tensor(np.array([0]).astype(np.int32))
        loss = train_network(data, label)
        losses2.append(loss.asnumpy())
    assert losses2[0] > losses2[1]
    assert losses2[1] > losses2[2]
```
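The system test above only asserts that the loss decreases over three steps in graph and PyNative mode. For checking the kernel's arithmetic itself, a minimal NumPy replica of the element-wise update in adam_impl.cu can serve as a reference; the `adam_reference` helper below is an illustrative sketch, not part of the commit, for comparing against the op's outputs on the same inputs.

```python
import numpy as np


def adam_reference(var, m, v, grad, beta1_power, beta2_power, lr, beta1, beta2, eps):
    """NumPy replica of ApplyAdamKernel: same update, same evaluation order."""
    new_lr = lr * np.sqrt(1.0 - beta2_power) / (1.0 - beta1_power)
    m = m + (grad - m) * (1.0 - beta1)           # m += (g - m) * (1 - beta1)
    v = v + (grad * grad - v) * (1.0 - beta2)    # v += (g^2 - v) * (1 - beta2)
    var = var - new_lr * m / (np.sqrt(v) + eps)  # var -= lr_t * m / (sqrt(v) + eps)
    return var, m, v


# One step from zero moments; after step t, beta1_power = beta1**t and
# beta2_power = beta2**t, so the first step uses the betas themselves.
rng = np.random.default_rng(0)
var = rng.standard_normal((4, 4)).astype(np.float32)
grad = rng.standard_normal((4, 4)).astype(np.float32)
var1, m1, v1 = adam_reference(var, np.zeros_like(var), np.zeros_like(var), grad,
                              0.9, 0.999, 0.01, 0.9, 0.999, 1e-8)
```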