diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu
index 5e7a25b8e64964257c2c14f18c31784edf47480f..4a8af83aa45e785f89a54201e4ed7978df953fcd 100755
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cu
@@ -60,6 +60,34 @@ __global__ void SquareKernel(T *input, T *output, size_t count) {
   return;
 }
 template <typename T>
+__global__ void SqrtKernel(T *input, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = sqrt(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void SqrtKernel(half *input, half *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = hsqrt(input[i]);
+  }
+  return;
+}
+template <typename T>
+__global__ void RsqrtKernel(T *input, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = rsqrt(input[i]);
+  }
+  return;
+}
+template <>
+__global__ void RsqrtKernel(half *input, half *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    output[i] = hrsqrt(input[i]);
+  }
+  return;
+}
+template <typename T>
 __global__ void ZeroslikeKernel(T *output, size_t count) {
   T zero = 0.0;
   for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
@@ -93,6 +121,21 @@ void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
   return;
 }
 template <typename T>
+void Pow(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
+  PowKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
+void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
+  SqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
+void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream) {
+  RsqrtKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(input, output, count);
+  return;
+}
+template <typename T>
 void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream) {
   ZeroslikeKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>(output, count);
   return;
@@ -103,10 +146,14 @@ template void Logarithm<float>(float *input, float *output, size_t count, cudaSt
 template void Negative<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Square<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
+template void Sqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
+template void Rsqrt<float>(float *input, float *output, size_t count, cudaStream_t cuda_stream);
 template void Zeroslike<float>(float *output, size_t count, cudaStream_t cuda_stream);
 template void Exponential<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Logarithm<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Negative<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Reciprocal<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Square<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
+template void Sqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
+template void Rsqrt<half>(half *input, half *output, size_t count, cudaStream_t cuda_stream);
 template void Zeroslike<half>(half *output, size_t count, cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh
index 8ba9cb4a52688812e2d3aade5072ca405fcf3c85..623b1a8c03e2d7432db0fb9442b8f53645d2deb4 100755
--- a/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh
+++ b/mindspore/ccsrc/kernel/gpu/cuda_impl/unary_op_impl.cuh
@@ -29,6 +29,10 @@ void Reciprocal(T *input, T *output, size_t count, cudaStream_t cuda_stream);
 template <typename T>
 void Square(T *input, T *output, size_t count, cudaStream_t cuda_stream);
 template <typename T>
+void Sqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
+template <typename T>
+void Rsqrt(T *input, T *output, size_t count, cudaStream_t cuda_stream);
+template <typename T>
 void Zeroslike(T *output, size_t count, cudaStream_t cuda_stream);
 
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_UNARYOPIMPL_H_
diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc
index bfdbe114224e782a67c9088a9647335d93f95707..77f53fc4173c3d4aa5dfc5051ce5d280e5aeb622 100644
--- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.cc
@@ -42,5 +42,9 @@ MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddO
                       UnaryOpGpuKernel, float)
 MS_REG_GPU_KERNEL_ONE(Square, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
                       UnaryOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(Sqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(Rsqrt, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+                      UnaryOpGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
index d8fea7370b211b7712448f83718b61d9f6d9d653..6e011f6e3775fb91294b25a160dcfbe10d05c8d7 100644
--- a/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/math/unary_op_gpu_kernel.h
@@ -34,6 +34,8 @@ enum UnaryOptype {
   UNARY_OP_RECIPROCAL,
   UNARY_OP_ZEROSLIKE,
   UNARY_OP_SQUARE,
+  UNARY_OP_SQRT,
+  UNARY_OP_RSQRT,
   UNARY_OP_INVALID_TYPE = 255
 };
 static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY_OP_EXP},
@@ -41,7 +43,9 @@ static const std::map<std::string, UnaryOptype> kUnaryOpTypeMap = {{"Exp", UNARY
                                                                    {"Neg", UNARY_OP_NEG},
                                                                    {"Reciprocal", UNARY_OP_RECIPROCAL},
                                                                    {"ZerosLike", UNARY_OP_ZEROSLIKE},
-                                                                   {"Square", UNARY_OP_SQUARE}};
+                                                                   {"Square", UNARY_OP_SQUARE},
+                                                                   {"Sqrt", UNARY_OP_SQRT},
+                                                                   {"Rsqrt", UNARY_OP_RSQRT}};
 template <typename T>
 class UnaryOpGpuKernel : public GpuKernel {
  public:
@@ -80,6 +84,14 @@ class UnaryOpGpuKernel : public GpuKernel {
         Square(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case UNARY_OP_SQRT: {
+        Sqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
+      case UNARY_OP_RSQRT: {
+        Rsqrt(input_addr, output_addr, inputs[0]->size / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       case UNARY_OP_ZEROSLIKE: {
         Zeroslike(output_addr, output_size_ / sizeof(T), reinterpret_cast<cudaStream_t>(stream_ptr));
         return true;
diff --git a/tests/st/ops/gpu/test_sqrt_op.py b/tests/st/ops/gpu/test_sqrt_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..7dd9b757479c8e42ed5e61e28f23d4ee4796967c
--- /dev/null
+++ b/tests/st/ops/gpu/test_sqrt_op.py
@@ -0,0 +1,38 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+from mindspore import Tensor
+from mindspore.ops import operations as P
+import mindspore.nn as nn
+import numpy as np
+import mindspore.context as context
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_sqrt():
+    x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)
+
+    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
+    output_ms = P.Sqrt()(Tensor(x_np))
+    output_np = np.sqrt(x_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Rsqrt()(Tensor(x_np))
+    output_np = 1 / np.sqrt(x_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
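Not part of the patch above: a minimal graph-mode sketch of how the newly registered Sqrt and Rsqrt GPU kernels could be exercised, complementing the PyNative-mode test in test_sqrt_op.py. The SqrtNet cell name is made up for illustration, and float32 inputs are assumed because only kNumberTypeFloat32 registrations are added in unary_op_gpu_kernel.cc.

import numpy as np

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class SqrtNet(nn.Cell):
    """Hypothetical cell applying the Sqrt and Rsqrt primitives."""
    def __init__(self):
        super(SqrtNet, self).__init__()
        self.sqrt = P.Sqrt()
        self.rsqrt = P.Rsqrt()

    def construct(self, x):
        # Both primitives dispatch to the new GPU kernels when run on a GPU device.
        return self.sqrt(x), self.rsqrt(x)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x_np = np.random.rand(2, 3, 4, 4).astype(np.float32)  # float32 only, per the registrations
y_sqrt, y_rsqrt = SqrtNet()(Tensor(x_np))
assert np.allclose(y_sqrt.asnumpy(), np.sqrt(x_np))
assert np.allclose(y_rsqrt.asnumpy(), 1 / np.sqrt(x_np))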