提交 694a8213 编写于 作者: L lizhenyu

add adam optimizer

上级 65f2212f
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/cuda_impl/adam_impl.cuh"
template <typename T>
__device__ __forceinline__ T SqrtFunc(T input) {
return sqrt(input);
}
template <>
__device__ __forceinline__ half SqrtFunc(half input) {
return hsqrt(input);
}
template <typename T>
__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable,
T *m, T *v) {
const T one = static_cast<T>(1.0);
const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]);
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
m[i] += (gradient[i] - m[i]) * (one - beta1[0]);
v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]);
variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]);
}
}
template <typename T>
void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate,
const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) {
ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
}
template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power,
const float *beta2_power, const float *learning_rate, const float *beta1,
const float *beta2, const float *epsilon, float *variable, float *m, float *v,
cudaStream_t cuda_stream);
template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power,
const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon,
half *variable, half *m, half *v, cudaStream_t cuda_stream);
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
#include "device/gpu/cuda_common.h"
template <typename T>
void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate,
const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/nn/adam_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(Adam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
AdamGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(Adam,
KernelAttr()
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddInputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16)
.AddOutputAttr(kNumberTypeFloat16),
AdamGpuKernel, half)
} // namespace kernel
} // namespace mindspore
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
#include "kernel/gpu/cuda_impl/adam_impl.cuh"
namespace mindspore {
namespace kernel {
template <typename T>
class AdamGpuKernel : public GpuKernel {
public:
AdamGpuKernel()
: variable_size_(0),
m_size_(0),
v_size_(0),
beta1_power_size_(0),
beta2_power_size_(0),
learning_rate_size_(0),
beta1_size_(0),
beta2_size_(0),
epsilon_size_(0),
gradient_size_(0) {}
~AdamGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
void *stream_ptr) override {
T *variable = GetDeviceAddress<T>(inputs, 0);
T *m = GetDeviceAddress<T>(inputs, 1);
T *v = GetDeviceAddress<T>(inputs, 2);
T *beta1_power = GetDeviceAddress<T>(inputs, 3);
T *beta2_power = GetDeviceAddress<T>(inputs, 4);
T *learning_rate = GetDeviceAddress<T>(inputs, 5);
T *beta1 = GetDeviceAddress<T>(inputs, 6);
T *beta2 = GetDeviceAddress<T>(inputs, 7);
T *epsilon = GetDeviceAddress<T>(inputs, 8);
T *gradient = GetDeviceAddress<T>(inputs, 9);
ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
variable, m, v, reinterpret_cast<cudaStream_t>(stream_ptr));
return true;
}
bool Init(const CNodePtr &kernel_node) override {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 10) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ftrl needs 10 inputs.";
return false;
}
variable_size_ = sizeof(T);
m_size_ = sizeof(T);
v_size_ = sizeof(T);
beta1_power_size_ = sizeof(T);
beta2_power_size_ = sizeof(T);
learning_rate_size_ = sizeof(T);
beta1_size_ = sizeof(T);
beta2_size_ = sizeof(T);
epsilon_size_ = sizeof(T);
gradient_size_ = sizeof(T);
auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < variable_shape.size(); i++) {
variable_size_ *= variable_shape[i];
}
auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < m_shape.size(); i++) {
m_size_ *= m_shape[i];
}
auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < v_shape.size(); i++) {
v_size_ *= v_shape[i];
}
auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9);
for (size_t i = 0; i < gradient_shape.size(); i++) {
gradient_size_ *= gradient_shape[i];
}
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(variable_size_);
input_size_list_.push_back(m_size_);
input_size_list_.push_back(v_size_);
input_size_list_.push_back(beta1_power_size_);
input_size_list_.push_back(beta2_power_size_);
input_size_list_.push_back(learning_rate_size_);
input_size_list_.push_back(beta1_size_);
input_size_list_.push_back(beta2_size_);
input_size_list_.push_back(epsilon_size_);
input_size_list_.push_back(gradient_size_);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
output_size_list_.push_back(0);
}
private:
size_t variable_size_;
size_t m_size_;
size_t v_size_;
size_t beta1_power_size_;
size_t beta2_power_size_;
size_t learning_rate_size_;
size_t beta1_size_;
size_t beta2_size_;
size_t epsilon_size_;
size_t gradient_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import pytest
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.nn import Dense
from mindspore.nn import TrainOneStepCell, WithLossCell
from mindspore.nn.optim import Adam
from mindspore.ops import operations as P
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class NetAdam(nn.Cell):
def __init__(self):
super(NetAdam, self).__init__()
self.batch_size = 1
self.reshape = P.Reshape()
weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
self.fc1 = Dense(16, 10, weight_init=weight)
def construct(self, input_x):
output = self.reshape(input_x, (self.batch_size, -1))
output = self.fc1(output)
return output
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_adam():
epoch = 3
net = NetAdam()
optimizer = Adam(filter(lambda x: x.requires_grad,
net.get_parameters()), learning_rate=0.01)
criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
net_with_criterion = WithLossCell(net, criterion)
train_network = TrainOneStepCell(
net_with_criterion, optimizer)
train_network.set_train()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
losses1 = []
for _ in range(epoch):
data = Tensor(np.arange(0, 16).reshape(
1, 1, 4, 4).astype(np.float32) * 0.01)
label = Tensor(np.array([0]).astype(np.int32))
loss = train_network(data, label)
losses1.append(loss.asnumpy())
assert losses1[0] > losses1[1]
assert losses1[1] > losses1[2]
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
losses2 = []
for _ in range(epoch):
data = Tensor(np.arange(0, 16).reshape(
1, 1, 4, 4).astype(np.float32) * 0.01)
label = Tensor(np.array([0]).astype(np.int32))
loss = train_network(data, label)
losses2.append(loss.asnumpy())
assert losses2[0] > losses2[1]
assert losses2[1] > losses2[2]
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册