diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..1400f623b7efe34761344bcd145b6e2b07b46779 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.cc @@ -0,0 +1,26 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h" +namespace mindspore { +namespace kernel { +MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32), + OnesLikeGpuKernel, float) +MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16), + OnesLikeGpuKernel, half) +MS_REG_GPU_KERNEL_ONE(OnesLike, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32), + OnesLikeGpuKernel, int) +} // namespace kernel +} // namespace mindspore diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..355ed9e05f1a880199c81a34115e7dc692bc2151 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/arrays/oneslike_gpu_kernel.h @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_ + +#include <vector> +#include "backend/kernel_compiler/gpu/gpu_kernel.h" +#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h" +#include "backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh" +namespace mindspore { +namespace kernel { +template <typename T> +class OnesLikeGpuKernel : public GpuKernel { + public: + OnesLikeGpuKernel() : input_size_(0), output_size_(0) {} + ~OnesLikeGpuKernel() override = default; + + const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; } + const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; } + const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; } + + bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &, + const std::vector<AddressPtr> &outputs, void *stream_ptr) override { + T *input = GetDeviceAddress<T>(inputs, 0); + T *output = GetDeviceAddress<T>(outputs, 0); + int size = SizeToInt(input_size_ / sizeof(T)); + CalOnesLike(size, input, output, reinterpret_cast<cudaStream_t>(stream_ptr)); + return true; + } + + bool Init(const CNodePtr &kernel_node) override { + size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node); + if (input_num != 1) { + MS_LOG(ERROR) << "Input number is " << input_num << ", but oneslike needs 1 input."; + return false; + } + size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node); + if (output_num != 1) { + MS_LOG(ERROR) << "Output number is " << output_num << ", but oneslike needs 1 output."; + return false; + } + auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0); + size_t shape_size = input_shape.size(); + + input_size_ = 1; + for (size_t i = 0; i < shape_size; i++) { + input_size_ *= input_shape[i]; + } + input_size_ *= sizeof(T); + output_size_ = input_size_; + InitSizeLists(); + return true; + } + + protected: + void InitSizeLists() override { + input_size_list_.push_back(input_size_); + output_size_list_.push_back(output_size_); + return; + } + + private: + std::vector<size_t> input_size_list_; + std::vector<size_t> output_size_list_; + std::vector<size_t> workspace_size_list_; + size_t input_size_; + size_t output_size_; +}; +} // namespace kernel +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_ONESLIKE_H_ diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu new file mode 100644 index 0000000000000000000000000000000000000000..dc1d9cd206e179dfe65cc8e1cadc51cfd9fa1656 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cu @@ -0,0 +1,37 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include <cuda_runtime.h> +#include "oneslike_impl.cuh" +#include "runtime/device/gpu/cuda_common.h" +template <typename T> +__global__ void OnesLike(const int size, const T* input, T* output) { + int one = 1; + T val = static_cast<T>(one); + for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) { + output[pos] = val; + } + return; +} +template <typename T> +void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream) { + OnesLike<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, output); + return; +} + +template void CalOnesLike<float>(const int size, const float* input, float* output, cudaStream_t cuda_stream); +template void CalOnesLike<half>(const int size, const half* input, half* output, cudaStream_t cuda_stream); +template void CalOnesLike<int>(const int size, const int* input, int* output, cudaStream_t cuda_stream); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh new file mode 100644 index 0000000000000000000000000000000000000000..81e92c1d099f78282c67e8c61db2a86db1c10324 --- /dev/null +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/oneslike_impl.cuh @@ -0,0 +1,23 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ +#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ + +template <typename T> +void CalOnesLike(const int size, const T* input, T* output, cudaStream_t cuda_stream); + +#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_ONESLIKE_H_ diff --git a/tests/st/ops/gpu/test_oneslike_op.py b/tests/st/ops/gpu/test_oneslike_op.py new file mode 100644 index 0000000000000000000000000000000000000000..e721d157295974ca4333e971b8a575119c830a74 --- /dev/null +++ b/tests/st/ops/gpu/test_oneslike_op.py @@ -0,0 +1,85 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np +import pytest + +import mindspore.context as context +import mindspore.nn as nn +from mindspore import Tensor +from mindspore.ops import operations as P + +context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + + +class NetOnesLike(nn.Cell): + def __init__(self): + super(NetOnesLike, self).__init__() + self.ones_like = P.OnesLike() + + def construct(self, x): + return self.ones_like(x) + + +@pytest.mark.level0 +@pytest.mark.platform_x86_gpu_training +@pytest.mark.env_onecard +def test_OnesLike(): + x0_np = np.random.uniform(-2, 2, (2, 3, 4, 4)).astype(np.float32) + x1_np = np.random.uniform(-2, 2, 1).astype(np.float16) + x2_np = np.zeros([3, 3, 3], dtype=np.int32) + + x0 = Tensor(x0_np) + x1 = Tensor(x1_np) + x2 = Tensor(x2_np) + + context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU") + ones_like = NetOnesLike() + output0 = ones_like(x0) + expect0 = np.ones_like(x0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = ones_like(x1) + expect1 = np.ones_like(x1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + context.set_context(mode=context.GRAPH_MODE, device_target="GPU") + ones_like = NetOnesLike() + output0 = ones_like(x0) + expect0 = np.ones_like(x0_np) + diff0 = output0.asnumpy() - expect0 + error0 = np.ones(shape=expect0.shape) * 1.0e-5 + assert np.all(diff0 < error0) + assert output0.shape == expect0.shape + + output1 = ones_like(x1) + expect1 = np.ones_like(x1_np) + diff1 = output1.asnumpy() - expect1 + error1 = np.ones(shape=expect1.shape) * 1.0e-5 + assert np.all(diff1 < error1) + assert output1.shape == expect1.shape + + output2 = ones_like(x2) + expect2 = np.ones_like(x2_np) + diff2 = output2.asnumpy() - expect2 + error2 = np.ones(shape=expect2.shape) * 1.0e-5 + assert np.all(diff2 < error2) + assert output2.shape == expect2.shape