diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
index 6f993945625acdbebc9fb635e567dd28d7c69c6d..19a1273cb3d17d8ebc78a65115ae0105c2b62b7c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cu
@@ -24,6 +24,23 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size
   return;
 }
 
+// Grid-stride kernel: for every i in [0, count) writes
+// output[i] = curand_uniform(...) * (high - low) + low, with low/high taken from input1/input2.
+// A bound whose element count is exactly 1 is broadcast; the bound buffers are only read.
+template <typename T>
+__global__ void UniformKernel(int seed, curandState *globalState, T *input1, size_t input_size_1,
+                              T *input2, size_t input_size_2, T *output, size_t count) {
+  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
+    // Keep the bounds in locals: storing them back to input1[i]/input2[i] would write past the
+    // end of a single-element (broadcast) bound buffer for every i > 0.
+    const T low = (input_size_1 == 1 ? input1[0] : input1[i]);
+    const T high = (input_size_2 == 1 ? input2[0] : input2[i]);
+    curand_init(seed, i, 0, &globalState[i]);
+    output[i] = curand_uniform(&globalState[i]) * (high - low) + low;
+  }
+  return;
+}
+
 template <typename T>
 void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
   int RNG_seed = 0;
@@ -38,5 +55,19 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
   return;
 }
 
+// Launches UniformKernel on cuda_stream. A seed of 0 is replaced by the current host time
+// (time(NULL)), so results are not reproducible in that case.
+template <typename T>
+void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1,
+                 T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) {
+  seed = (seed == 0 ? static_cast<int>(time(NULL)) : seed);
+  UniformKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
+    (seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
+  return;
+}
+
 template void StandardNormal<float>(int seed, int seed2, curandState *globalState,
                                     float *output, size_t count, cudaStream_t cuda_stream);
+template void UniformReal<float>(int seed, curandState *globalState, float *input1, size_t input_size_1,
+                                 float *input2, size_t input_size_2, float *output, size_t count,
+                                 cudaStream_t cuda_stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
index b099ead9bf3ab756e7ee90b4f77c54ccb389be3b..f5699cee0ad6f6aa9676008b1cd8d26c14a4c546 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/random_op_impl.cuh
@@ -23,4 +23,10 @@
 template <typename T>
 void StandardNormal(int seed, int seed2, curandState *globalState,
                     T *output, size_t count, cudaStream_t cuda_stream);
+// Fills output[0..count) with uniform samples scaled to the bounds in input1 and input2.
+// input_size_1 / input_size_2 are element counts; a count of 1 broadcasts that bound.
+template <typename T>
+void UniformReal(int seed, curandState *globalState,
+                 T *input1, size_t input_size_1, T *input2, size_t input_size_2,
+                 T *output, size_t count, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
index c72c271c522b4f568a3418dc646f2ab0575edd14..8dfd4eef08c80c41f805e87f6806c04e70f3a37f 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.cc
@@ -20,5 +20,12 @@ namespace mindspore {
 namespace kernel {
 MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
                       RandomOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(UniformReal,
+                      KernelAttr()
+                        .AddInputAttr(kNumberTypeInt32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddInputAttr(kNumberTypeFloat32)
+                        .AddOutputAttr(kNumberTypeFloat32),
+                      RandomOpGpuKernel, float)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
index c77339f7658954f86658d11a64b1931d2f43822a..98a421c92279627436c3a74aeaf7661b78dfac55 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/random_op_gpu_kernel.h
@@ -28,17 +28,23 @@
 
 namespace mindspore {
 namespace kernel {
-enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 };
+enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };
 
-const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}};
+const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
+                                                              {"UniformReal", RANDOM_OP_UNIFORM_REAL}};
 template <typename T>
 class RandomOpGpuKernel : public GpuKernel {
  public:
   RandomOpGpuKernel()
       : random_op_type_(RANDOM_OP_INVALID_TYPE),
         input_size_0_(0),
+        // Inputs 1 and 2 start at one element; Init() multiplies in their dims for UniformReal.
+        input_size_1_(sizeof(T)),
+        input_size_2_(sizeof(T)),
         output_size_(sizeof(T)),
-        workspace_size_(sizeof(curandState)) {}
+        workspace_size_(sizeof(curandState)),
+        seed_(0),
+        seed2_(0) {}
   ~RandomOpGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -57,12 +63,23 @@ class RandomOpGpuKernel : public GpuKernel {
                        reinterpret_cast<cudaStream_t>(stream_ptr));
         break;
       }
+      case RANDOM_OP_UNIFORM_REAL: {
+        // Inputs 1 and 2 hold the bounds; their element counts are forwarded so that a
+        // single-element bound can be broadcast over the whole output.
+        T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
+        T *input_addr_2 = GetDeviceAddress<T>(inputs, 2);
+        UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
+                    inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
+        break;
+      }
       default: {
         MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported.";
       }
     }
     return true;
   }
+
   bool Init(const CNodePtr &kernel_node) override {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
     auto iter = kRandomOpTypeMap.find(kernel_name);
@@ -72,10 +89,14 @@ class RandomOpGpuKernel : public GpuKernel {
       random_op_type_ = iter->second;
     }
     size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
-    if (input_num != 1) {
+    if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) {
       MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input.";
       return false;
     }
+    if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) {
+      MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs.";
+      return false;
+    }
     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
     if (output_num != 1) {
       MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output.";
@@ -86,13 +107,26 @@ class RandomOpGpuKernel : public GpuKernel {
       input_size_0_ += input_shape_0[i];
     }
     input_size_0_ *= sizeof(int);
+    // Byte sizes of the two bound inputs: sizeof(T) times the product of their inferred dims.
+    if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
+      auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
+      for (size_t i = 0; i < input_shape_1.size(); i++) {
+        input_size_1_ *= input_shape_1[i];
+      }
+      auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
+      for (size_t i = 0; i < input_shape_2.size(); i++) {
+        input_size_2_ *= input_shape_2[i];
+      }
+    }
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     for (size_t i = 0; i < output_shape.size(); i++) {
       output_size_ *= output_shape[i];
       workspace_size_ *= output_shape[i];
     }
     seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
-    seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
+    if (random_op_type_ == RANDOM_OP_NORMAL) {
+      seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
+    }
     InitSizeLists();
     return true;
   }
@@ -100,6 +134,10 @@ class RandomOpGpuKernel : public GpuKernel {
  protected:
   void InitSizeLists() override {
     input_size_list_.push_back(input_size_0_);
+    if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
+      input_size_list_.push_back(input_size_1_);
+      input_size_list_.push_back(input_size_2_);
+    }
     output_size_list_.push_back(output_size_);
     workspace_size_list_.push_back(workspace_size_);
   }
@@ -107,6 +145,8 @@ class RandomOpGpuKernel : public GpuKernel {
  private:
   RandomOptype random_op_type_;
   size_t input_size_0_;
+  size_t input_size_1_;
+  size_t input_size_2_;
   size_t output_size_;
   size_t workspace_size_;
   int seed_;
diff --git a/tests/st/ops/gpu/test_uniform_real.py b/tests/st/ops/gpu/test_uniform_real.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fa4b0eb0b75e1fbeda58f9ba2b5ce82c6ebff5d
--- /dev/null
+++ b/tests/st/ops/gpu/test_uniform_real.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common import dtype as mstype
+
+context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+
+
+class Net(nn.Cell):
+    """Cell that calls P.UniformReal with a fixed output shape and call-time bounds."""
+    def __init__(self, shape, seed=0):
+        super(Net, self).__init__()
+        self.uniformreal = P.UniformReal(seed=seed)
+        self.shape = shape
+
+    def construct(self, a, b):
+        return self.uniformreal(self.shape, a, b)
+
+
+def test_net_1D():
+    """Sample a (3, 2, 4) tensor with scalar bounds 0.0 and 1.0; check shape and value range."""
+    shape = (3, 2, 4)
+    low, high = 0.0, 1.0
+    output = Net(shape, seed=10)(Tensor(low, mstype.float32), Tensor(high, mstype.float32))
+    samples = output.asnumpy()
+    print(samples)
+    assert output.shape == shape
+    assert ((samples >= low) & (samples <= high)).all()