diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu
new file mode 100644
index 0000000000000000000000000000000000000000..3ec63ee03a7d28731b891ca2d6506f2fcd8a0bb4
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cu
@@ -0,0 +1,56 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/cuda_impl/adam_impl.cuh"
+
+template <typename T>
+__device__ __forceinline__ T SqrtFunc(T input) {
+  return sqrt(input);
+}
+
+template <>
+__device__ __forceinline__ half SqrtFunc(half input) {
+  return hsqrt(input);
+}
+
+template <typename T>
+__global__ void ApplyAdamKernel(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power,
+                                const T *learning_rate, const T *beta1, const T *beta2, const T *epsilon, T *variable,
+                                T *m, T *v) {
+  const T one = static_cast<T>(1.0);
+  const T new_learning_rate = learning_rate[0] * SqrtFunc(one - beta2_power[0]) / (one - beta1_power[0]);
+
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
+    m[i] += (gradient[i] - m[i]) * (one - beta1[0]);
+    v[i] += (gradient[i] * gradient[i] - v[i]) * (one - beta2[0]);
+    variable[i] -= new_learning_rate * m[i] / (SqrtFunc(v[i]) + epsilon[0]);
+  }
+}
+
+template <typename T>
+void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate,
+               const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream) {
+  ApplyAdamKernel<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(
+    size, gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, variable, m, v);
+}
+
+template void ApplyAdam<float>(const size_t size, const float *gradient, const float *beta1_power,
+                               const float *beta2_power, const float *learning_rate, const float *beta1,
+                               const float *beta2, const float *epsilon, float *variable, float *m, float *v,
+                               cudaStream_t cuda_stream);
+template void ApplyAdam<half>(const size_t size, const half *gradient, const half *beta1_power, const half *beta2_power,
+                              const half *learning_rate, const half *beta1, const half *beta2, const half *epsilon,
+                              half *variable, half *m, half *v, cudaStream_t cuda_stream);
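
Note: the fused kernel folds the two bias corrections into a rescaled learning rate, lr * sqrt(1 - beta2^t) / (1 - beta1^t), and then applies the standard exponential-moving-average updates in place. The following NumPy sketch (not part of the patch; names are illustrative) mirrors the kernel's arithmetic one-for-one and can serve as a host-side cross-check:

import numpy as np

def apply_adam_reference(variable, m, v, gradient, beta1_power, beta2_power,
                         learning_rate, beta1, beta2, epsilon):
    """Host-side mirror of ApplyAdamKernel. Hyper-parameters are Python floats,
    matching the kernel's scalar reads from element [0] of each device buffer."""
    one = 1.0
    new_lr = learning_rate * np.sqrt(one - beta2_power) / (one - beta1_power)
    m += (gradient - m) * (one - beta1)              # m <- beta1*m + (1-beta1)*g
    v += (gradient * gradient - v) * (one - beta2)   # v <- beta2*v + (1-beta2)*g*g
    variable -= new_lr * m / (np.sqrt(v) + epsilon)
    return variable, m, v
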
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..f48a113c261b262af892aeaa9cf92c3f9af3f5df
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/adam_impl.cuh
@@ -0,0 +1,25 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
+
+#include "device/gpu/cuda_common.h"
+template <typename T>
+void ApplyAdam(const size_t size, const T *gradient, const T *beta1_power, const T *beta2_power, const T *learning_rate,
+               const T *beta1, const T *beta2, const T *epsilon, T *variable, T *m, T *v, cudaStream_t cuda_stream);
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMP_ADAM_IMPL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc
new file mode 100644
index 0000000000000000000000000000000000000000..049a5cc280735218882d25887ba7f7b609e94ec2
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.cc
@@ -0,0 +1,54 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "kernel/gpu/nn/adam_gpu_kernel.h"
+
+namespace mindspore {
+namespace kernel {
+MS_REG_GPU_KERNEL_ONE(Adam,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      AdamGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Adam,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddInputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16)
+                        .AddOutputAttr(kNumberTypeFloat16),
+                      AdamGpuKernel, half)
+}  // namespace kernel
+}  // namespace mindspore
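
The two registrations above bind the Adam primitive to AdamGpuKernel for float32 and float16, with ten inputs in the order variable, m, v, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon, gradient. A minimal sketch of exercising the kernel through the raw primitive, assuming P.Adam can be constructed with its default arguments (the frontend signature may differ between versions):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

class ApplyAdamNet(nn.Cell):
    def __init__(self, shape):
        super(ApplyAdamNet, self).__init__()
        self.apply_adam = P.Adam()
        # var/m/v are Parameters because the kernel updates them in place.
        self.var = Parameter(Tensor(np.ones(shape).astype(np.float32)), name="var")
        self.m = Parameter(Tensor(np.zeros(shape).astype(np.float32)), name="m")
        self.v = Parameter(Tensor(np.zeros(shape).astype(np.float32)), name="v")

    def construct(self, beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad):
        # Input order matches the ten-input registration above.
        return self.apply_adam(self.var, self.m, self.v, beta1_power, beta2_power,
                               lr, beta1, beta2, epsilon, grad)

The scalar hyper-parameters would be passed as 0-d float32 Tensors, matching the scalar input attributes in the registration.
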
diff --git a/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h
new file mode 100644
index 0000000000000000000000000000000000000000..93c6381ab3452b19a407cfec5d1ad4859a8deffc
--- /dev/null
+++ b/mindspore/ccsrc/kernel/gpu/nn/adam_gpu_kernel.h
@@ -0,0 +1,142 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
+
+#include <vector>
+#include "kernel/gpu/gpu_kernel.h"
+#include "kernel/gpu/gpu_kernel_factory.h"
+#include "kernel/gpu/cuda_impl/adam_impl.cuh"
+namespace mindspore {
+namespace kernel {
+template <typename T>
+class AdamGpuKernel : public GpuKernel {
+ public:
+  AdamGpuKernel()
+      : variable_size_(0),
+        m_size_(0),
+        v_size_(0),
+        beta1_power_size_(0),
+        beta2_power_size_(0),
+        learning_rate_size_(0),
+        beta1_size_(0),
+        beta2_size_(0),
+        epsilon_size_(0),
+        gradient_size_(0) {}
+
+  ~AdamGpuKernel() override = default;
+
+  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
+  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
+  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
+
+  bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
+              void *stream_ptr) override {
+    T *variable = GetDeviceAddress<T>(inputs, 0);
+    T *m = GetDeviceAddress<T>(inputs, 1);
+    T *v = GetDeviceAddress<T>(inputs, 2);
+    T *beta1_power = GetDeviceAddress<T>(inputs, 3);
+    T *beta2_power = GetDeviceAddress<T>(inputs, 4);
+    T *learning_rate = GetDeviceAddress<T>(inputs, 5);
+    T *beta1 = GetDeviceAddress<T>(inputs, 6);
+    T *beta2 = GetDeviceAddress<T>(inputs, 7);
+    T *epsilon = GetDeviceAddress<T>(inputs, 8);
+    T *gradient = GetDeviceAddress<T>(inputs, 9);
+    ApplyAdam(inputs[0]->size / sizeof(T), gradient, beta1_power, beta2_power, learning_rate, beta1, beta2, epsilon,
+              variable, m, v, reinterpret_cast<cudaStream_t>(stream_ptr));
+    return true;
+  }
+
+  bool Init(const CNodePtr &kernel_node) override {
+    size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
+    if (input_num != 10) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but adam needs 10 inputs.";
+      return false;
+    }
+
+    variable_size_ = sizeof(T);
+    m_size_ = sizeof(T);
+    v_size_ = sizeof(T);
+    beta1_power_size_ = sizeof(T);
+    beta2_power_size_ = sizeof(T);
+    learning_rate_size_ = sizeof(T);
+    beta1_size_ = sizeof(T);
+    beta2_size_ = sizeof(T);
+    epsilon_size_ = sizeof(T);
+    gradient_size_ = sizeof(T);
+
+    auto variable_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
+    for (size_t i = 0; i < variable_shape.size(); i++) {
+      variable_size_ *= variable_shape[i];
+    }
+
+    auto m_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+    for (size_t i = 0; i < m_shape.size(); i++) {
+      m_size_ *= m_shape[i];
+    }
+
+    auto v_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+    for (size_t i = 0; i < v_shape.size(); i++) {
+      v_size_ *= v_shape[i];
+    }
+
+    auto gradient_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 9);
+    for (size_t i = 0; i < gradient_shape.size(); i++) {
+      gradient_size_ *= gradient_shape[i];
+    }
+
+    InitSizeLists();
+    return true;
+  }
+
+ protected:
+  void InitSizeLists() override {
+    input_size_list_.push_back(variable_size_);
+    input_size_list_.push_back(m_size_);
+    input_size_list_.push_back(v_size_);
+    input_size_list_.push_back(beta1_power_size_);
+    input_size_list_.push_back(beta2_power_size_);
+    input_size_list_.push_back(learning_rate_size_);
+    input_size_list_.push_back(beta1_size_);
+    input_size_list_.push_back(beta2_size_);
+    input_size_list_.push_back(epsilon_size_);
+    input_size_list_.push_back(gradient_size_);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+    output_size_list_.push_back(0);
+  }
+
+ private:
+  size_t variable_size_;
+  size_t m_size_;
+  size_t v_size_;
+  size_t beta1_power_size_;
+  size_t beta2_power_size_;
+  size_t learning_rate_size_;
+  size_t beta1_size_;
+  size_t beta2_size_;
+  size_t epsilon_size_;
+  size_t gradient_size_;
+
+  std::vector<size_t> input_size_list_;
+  std::vector<size_t> output_size_list_;
+  std::vector<size_t> workspace_size_list_;
+};
+}  // namespace kernel
+}  // namespace mindspore
+
+#endif  // MINDSPORE_CCSRC_KERNEL_GPU_NN_ADAM_GPU_KERNEL_H_
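
A note on the bookkeeping above: Init() records per-input byte sizes (each scalar hyper-parameter contributes sizeof(T)), the three outputs are registered with size 0 because variable, m, and v are updated in place, and Launch() recovers the element count as inputs[0]->size / sizeof(T). The same arithmetic, mirrored in Python for a concrete case (illustrative only):

import numpy as np

shape = (10, 16)                                # variable/m/v/gradient shape
itemsize = np.dtype(np.float32).itemsize       # sizeof(T) == 4 for float32
variable_size = itemsize * int(np.prod(shape)) # variable_size_ == 640 bytes
scalar_size = itemsize                         # beta1_power_size_ .. epsilon_size_
num_elements = variable_size // itemsize       # inputs[0]->size / sizeof(T)
assert num_elements == 160
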
diff --git a/tests/st/ops/gpu/test_adam_op.py b/tests/st/ops/gpu/test_adam_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e2bb0ddab38fb12cb83618468a2d08ec54d8b6e
--- /dev/null
+++ b/tests/st/ops/gpu/test_adam_op.py
@@ -0,0 +1,78 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import numpy as np
+import pytest
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.nn import Dense
+from mindspore.nn import TrainOneStepCell, WithLossCell
+from mindspore.nn.optim import Adam
+from mindspore.ops import operations as P
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class NetAdam(nn.Cell):
+    def __init__(self):
+        super(NetAdam, self).__init__()
+        self.batch_size = 1
+        self.reshape = P.Reshape()
+        weight = Tensor(np.ones([10, 16]).astype(np.float32) * 0.01)
+        self.fc1 = Dense(16, 10, weight_init=weight)
+
+    def construct(self, input_x):
+        output = self.reshape(input_x, (self.batch_size, -1))
+        output = self.fc1(output)
+        return output
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_adam():
+    epoch = 3
+    net = NetAdam()
+    optimizer = Adam(filter(lambda x: x.requires_grad,
+                            net.get_parameters()), learning_rate=0.01)
+    criterion = nn.SoftmaxCrossEntropyWithLogits(is_grad=False, sparse=True)
+    net_with_criterion = WithLossCell(net, criterion)
+    train_network = TrainOneStepCell(
+        net_with_criterion, optimizer)
+    train_network.set_train()
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    losses1 = []
+    for _ in range(epoch):
+        data = Tensor(np.arange(0, 16).reshape(
+            1, 1, 4, 4).astype(np.float32) * 0.01)
+        label = Tensor(np.array([0]).astype(np.int32))
+        loss = train_network(data, label)
+        losses1.append(loss.asnumpy())
+    assert losses1[0] > losses1[1]
+    assert losses1[1] > losses1[2]
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    losses2 = []
+    for _ in range(epoch):
+        data = Tensor(np.arange(0, 16).reshape(
+            1, 1, 4, 4).astype(np.float32) * 0.01)
+        label = Tensor(np.array([0]).astype(np.int32))
+        loss = train_network(data, label)
+        losses2.append(loss.asnumpy())
+    assert losses2[0] > losses2[1]
+    assert losses2[1] > losses2[2]