From 5432fcb43ffd11db20294a62caf962df7f9cae82 Mon Sep 17 00:00:00 2001 From: wilfChen Date: Sat, 18 Apr 2020 09:26:45 +0800 Subject: [PATCH] gpu support RMSProp kernel --- .../kernel/gpu/cuda_impl/rmsprop_impl.cu | 68 ++++++++ .../kernel/gpu/cuda_impl/rmsprop_impl.cuh | 30 ++++ .../ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc | 49 ++++++ .../ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h | 110 +++++++++++++ tests/st/ops/test_rmsprop.py | 152 ++++++++++++++++++ 5 files changed, 409 insertions(+) create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu create mode 100644 mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh create mode 100644 mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc create mode 100644 mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h create mode 100644 tests/st/ops/test_rmsprop.py diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu new file mode 100644 index 000000000..31a4d97df --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cu @@ -0,0 +1,68 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" +#include "device/gpu/cuda_common.h" + +template +__global__ void RmsPropKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_square, T*moment, T* gradients, const size_t size) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; + moment[i] = momentum[0] * moment[i] + learning_rate[0] * rsqrt(mean_square[i] + epsilon[0]) * gradients[i]; + variable[i] -= moment[i]; + } +} + +template +void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, + T* variable, T* mean_square, T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream) { + RmsPropKernel<<>>(learning_rate, decay, momentum, epsilon, + variable, mean_square, moment, gradients, size); +} + +template +__global__ void RmsPropCenterKernel(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, + T* variable, T* mean_gradients, T* mean_square, T*moment, T* gradients, + const size_t size) { + for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (size); i += blockDim.x * gridDim.x) { + mean_gradients[i] = decay[0] * mean_gradients[i] + (1.0 - decay[0]) * gradients[i]; + mean_square[i] = decay[0] * mean_square[i] + (1.0 - decay[0]) * gradients[i] * gradients[i]; + moment[i] = momentum[0] * moment[i] + learning_rate[0] * + rsqrt(mean_square[i] - mean_gradients[i] * mean_gradients[i] + epsilon[0]) * gradients[i]; + variable[i] -= moment[i]; + } +} + +template +void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_gradients, T* mean_square, T*moment, T* gradients, const size_t size, + cudaStream_t cuda_stream) { + RmsPropCenterKernel<<>>(learning_rate, decay, momentum, epsilon, + variable, mean_gradients, mean_square, + moment, gradients, size); +} + +template +void RmsProp(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, + float* variable, float* mean_square, float* moment, float* gradients, const size_t size, + cudaStream_t cuda_stream); + +template +void RmsPropCenter(const float* learning_rate, const float* decay, const float* momentum, const float* epsilon, + float* variable, float* mean_gradients, float* mean_square, float*moment, float* gradients, + const size_t size, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh new file mode 100644 index 000000000..62d7e19ba --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/rmsprop_impl.cuh @@ -0,0 +1,30 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ +#include "device/gpu/cuda_common.h" + +template +void RmsProp(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, T* mean_square, + T* moment, T* gradients, const size_t size, cudaStream_t cuda_stream); + +template +void RmsPropCenter(const T* learning_rate, const T* decay, const T* momentum, const T* epsilon, T* variable, + T* mean_gradients, T* mean_square, T* moment, T* gradients, const size_t size, + cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RMSPROP_H_ diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc new file mode 100644 index 000000000..85aabe575 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.cc @@ -0,0 +1,49 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel/gpu/nn/rmsprop_gpu_kernel.h" + +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(ApplyRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropGpuKernel, float) + +MS_REG_GPU_KERNEL_ONE(ApplyCenteredRMSProp, + KernelAttr() + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddInputAttr(kNumberTypeFloat32) + .AddOutputAttr(kNumberTypeFloat32), + RMSPropGpuKernel, float) + +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h new file mode 100644 index 000000000..d1ca53110 --- /dev/null +++ b/mindspore/ccsrc/kernel/gpu/nn/rmsprop_gpu_kernel.h @@ -0,0 +1,110 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_NN_RMSPROP_KERNEL_H_ + +#include +#include "kernel/gpu/gpu_kernel.h" +#include "kernel/gpu/gpu_kernel_factory.h" +#include "kernel/gpu/cuda_impl/rmsprop_impl.cuh" + +namespace mindspore { +namespace kernel { +template +class RMSPropGpuKernel : public GpuKernel { + public: + RMSPropGpuKernel() : size_(1), use_center_(false) {} + ~RMSPropGpuKernel() override = default; + + const std::vector &GetInputSizeList() const override { return input_size_list_; } + const std::vector &GetOutputSizeList() const override { return output_size_list_; } + const std::vector &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector &inputs, const std::vector &, + const std::vector &outputs, uintptr_t stream) override { + if (!use_center_) { + T *variable = GetDeviceAddress(inputs, 0); + T *mean_square = GetDeviceAddress(inputs, 1); + T *moment = GetDeviceAddress(inputs, 2); + T *gradients = GetDeviceAddress(inputs, 3); + T *learning_rate = GetDeviceAddress(inputs, 4); + T *decay = GetDeviceAddress(inputs, 5); + T *momentum = GetDeviceAddress(inputs, 6); + T *epsilon = GetDeviceAddress(inputs, 7); + + RmsProp(learning_rate, decay, momentum, epsilon, variable, mean_square, moment, gradients, size_, + reinterpret_cast(stream)); + } else { + T *variable = GetDeviceAddress(inputs, 0); + T *mean_gradients = GetDeviceAddress(inputs, 1); + T *mean_square = GetDeviceAddress(inputs, 2); + T *moment = GetDeviceAddress(inputs, 3); + T *gradients = GetDeviceAddress(inputs, 4); + T *learning_rate = GetDeviceAddress(inputs, 5); + T *decay = GetDeviceAddress(inputs, 6); + T *momentum = GetDeviceAddress(inputs, 7); + T *epsilon = GetDeviceAddress(inputs, 8); + + RmsPropCenter(learning_rate, decay, momentum, epsilon, variable, mean_gradients, mean_square, moment, gradients, + size_, reinterpret_cast(stream)); + } + return true; + } + bool Init(const CNodePtr &kernel_node) override { + auto node_name = AnfAlgo::GetCNodeName(kernel_node); + if (node_name == "ApplyCenteredRMSProp") { + use_center_ = true; + } + + auto input_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); + for (auto &dim : input_shape) { + size_ *= dim; + } + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + size_t input_size = size_ * sizeof(T); + input_size_list_.push_back(input_size); + if (use_center_) { + input_size_list_.push_back(input_size); + } + + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(input_size); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + input_size_list_.push_back(sizeof(T)); + output_size_list_.push_back(0); + } + + private: + size_t size_; + bool use_center_; + + std::vector input_size_list_; + std::vector output_size_list_; + std::vector workspace_size_list_; +}; +} // namespace kernel +} // namespace mindspore + +#endif diff --git a/tests/st/ops/test_rmsprop.py b/tests/st/ops/test_rmsprop.py new file mode 100644 index 000000000..dcf65be2d --- /dev/null +++ b/tests/st/ops/test_rmsprop.py @@ -0,0 +1,152 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import pytest +from mindspore import Tensor +from mindspore.ops import operations as P +import mindspore.nn as nn +import numpy as np +import mindspore.context as context + +context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + + +class NetRMSProp(nn.Cell): + def __init__(self, use_centered): + super(NetRMSProp, self).__init__() + self.use_centered = use_centered + if use_centered: + self.rms_opt = P.ApplyCenteredRMSProp() + else: + self.rms_opt = P.ApplyRMSProp() + + def construct(self, var, g, mg, rms, mom, lr, decay, momentum, epsilon): + if self.use_centered: + return self.rms_opt(var, mg, rms, mom, g, lr, decay, momentum, epsilon) + else: + return self.rms_opt(var, rms, mom, g, lr, decay, momentum, epsilon) + +def rmsprop_numpy(variable, gradients, mean_square, moment, + learning_rate, decay, momentum, epsilon): + mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients + moment = momentum * moment + learning_rate / np.sqrt(mean_square + epsilon) * gradients + variable = variable - moment + +def rmspropcented_numpy(variable, gradients, mean_gradients, mean_square, moment, + learning_rate, decay, momentum, epsilon): + mean_gradients = mean_gradients * decay + (1.0 - decay) * gradients + mean_square = mean_square * decay + (1.0 - decay) * gradients * gradients + moment = momentum * moment + learning_rate / np.sqrt(mean_square -mean_gradients * mean_gradients + epsilon) * gradients + variable = variable - moment + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_rmsprop(): + learning_rate, decay, momentum, epsilon, centered = [0.5, 0.8, 0.9, 1e-3, True] + + variable_np = np.array([1.0, 2.0], dtype=np.float32) + gradients_np = np.array([0.1, 0.2], dtype=np.float32) + mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32) + mean_square_np = np.array([epsilon, epsilon], dtype=np.float32) + moment_np = np.array([0.0, 0.0], dtype=np.float32) + + variable_ms = Tensor(variable_np) + gradients_ms = Tensor(gradients_np) + mean_gradients_ms = Tensor(mean_gradients_np) + mean_square_ms = Tensor(mean_square_np) + moment_ms = Tensor(moment_np) + + if centered: + rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + else: + rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + + net = NetRMSProp(centered) + _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, + moment_ms, learning_rate, decay, momentum, epsilon) + + error = np.ones(shape=variable_np.shape) * 10e-6 + diff = variable_ms.asnumpy() - variable_np + assert np.all(diff < error) + + error = np.ones(shape=gradients_np.shape) * 10e-6 + diff = gradients_ms.asnumpy() - gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_gradients_np.shape) * 10e-6 + diff = mean_gradients_ms.asnumpy() - mean_gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_square_np.shape) * 10e-6 + diff = mean_square_ms.asnumpy() - mean_square_np + assert np.all(diff < error) + + error = np.ones(shape=moment_np.shape) * 10e-6 + diff = moment_ms.asnumpy() - moment_np + assert np.all(diff < error) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_rmspropcenter(): + learning_rate, decay, momentum, epsilon, centered = [0.1, 0.3, 0.9, 1.0, False] + + variable_np = np.array([1.0, 2.0], dtype=np.float32) + gradients_np = np.array([0.1, 0.2], dtype=np.float32) + mean_gradients_np = np.array([0.0, 0.0], dtype=np.float32) + mean_square_np = np.array([epsilon, epsilon], dtype=np.float32) + moment_np = np.array([0.0, 0.0], dtype=np.float32) + + variable_ms = Tensor(variable_np) + gradients_ms = Tensor(gradients_np) + mean_gradients_ms = Tensor(mean_gradients_np) + mean_square_ms = Tensor(mean_square_np) + moment_ms = Tensor(moment_np) + + if centered: + rmspropcented_numpy(variable_np, gradients_np, mean_gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + else: + rmsprop_numpy(variable_np, gradients_np, mean_square_np, moment_np, + learning_rate, decay, momentum, epsilon) + + net = NetRMSProp(centered) + _ = net(variable_ms, gradients_ms, mean_gradients_ms, mean_square_ms, moment_ms, + learning_rate, decay, momentum, epsilon) + + error = np.ones(shape=variable_np.shape) * 10e-6 + diff = variable_ms.asnumpy() - variable_np + assert np.all(diff < error) + + error = np.ones(shape=gradients_np.shape) * 10e-6 + diff = gradients_ms.asnumpy() - gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_gradients_np.shape) * 10e-6 + diff = mean_gradients_ms.asnumpy() - mean_gradients_np + assert np.all(diff < error) + + error = np.ones(shape=mean_square_np.shape) * 10e-6 + diff = mean_square_ms.asnumpy() - mean_square_np + assert np.all(diff < error) + + error = np.ones(shape=moment_np.shape) * 10e-6 + diff = moment_ms.asnumpy() - moment_np + assert np.all(diff < error) \ No newline at end of file -- GitLab