提交 e7a99397 编写于 作者: P peixu_ren

Add random uniform real op at GPU end

上级 16079e63
......@@ -24,6 +24,18 @@ __global__ void NormalKernel(int seed, curandState *globalState, T *output, size
return;
}
template <typename T>
__global__ void UniformKernel(int seed, curandState *globalState, T *input1, size_t input_size_1,
T *input2, size_t input_size_2, T *output, size_t count) {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (count); i += blockDim.x * gridDim.x) {
input1[i] = (input_size_1 == 1 ? input1[0] : input1[i]);
input2[i] = (input_size_2 == 1 ? input2[0] : input2[i]);
curand_init(seed, i, 0, &globalState[i]);
output[i] = curand_uniform(&globalState[i]) * (input2[i] - input1[i]) + input1[i];
}
return;
}
template <typename T>
void StandardNormal(int seed, int seed2, curandState *globalState, T *output, size_t count, cudaStream_t cuda_stream) {
int RNG_seed = 0;
......@@ -38,5 +50,17 @@ void StandardNormal(int seed, int seed2, curandState *globalState, T *output, si
return;
}
template <typename T>
void UniformReal(int seed, curandState *globalState, T *input1, size_t input_size_1,
T *input2, size_t input_size_2, T *output, size_t count, cudaStream_t cuda_stream) {
seed = (seed == 0 ? time(NULL):seed);
UniformKernel<<<GET_BLOCKS(count), GET_THREADS, 0, cuda_stream>>>
(seed, globalState, input1, input_size_1, input2, input_size_2, output, count);
return;
}
template void StandardNormal<float>(int seed, int seed2, curandState *globalState,
float *output, size_t count, cudaStream_t cuda_stream);
template void UniformReal<float>(int seed, curandState *globalState, float *input1, size_t input_size_1,
float *input2, size_t input_size_2, float *output, size_t count,
cudaStream_t cuda_stream);
......@@ -23,4 +23,8 @@
template <typename T>
void StandardNormal(int seed, int seed2, curandState *globalState,
T *output, size_t count, cudaStream_t cuda_stream);
template <typename T>
void UniformReal(int seed, curandState *globalState,
T *input1, size_t input_size_1, T *input2, size_t input_size_2,
T *output, size_t count, cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_RANDOMOPIMPL_H_
......@@ -20,5 +20,12 @@ namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(StandardNormal, KernelAttr().AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeFloat32),
RandomOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(UniformReal,
KernelAttr()
.AddInputAttr(kNumberTypeInt32)
.AddInputAttr(kNumberTypeFloat32)
.AddInputAttr(kNumberTypeFloat32)
.AddOutputAttr(kNumberTypeFloat32),
RandomOpGpuKernel, float)
} // namespace kernel
} // namespace mindspore
......@@ -28,17 +28,22 @@
namespace mindspore {
namespace kernel {
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_INVALID_TYPE = 255 };
enum RandomOptype { RANDOM_OP_NORMAL = 0, RANDOM_OP_UNIFORM_REAL, RANDOM_OP_INVALID_TYPE = 255 };
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL}};
const std::map<std::string, RandomOptype> kRandomOpTypeMap = {{"StandardNormal", RANDOM_OP_NORMAL},
{"UniformReal", RANDOM_OP_UNIFORM_REAL}};
template <typename T>
class RandomOpGpuKernel : public GpuKernel {
public:
RandomOpGpuKernel()
: random_op_type_(RANDOM_OP_INVALID_TYPE),
input_size_0_(0),
input_size_0_(sizeof(int)),
input_size_1_(sizeof(T)),
input_size_2_(sizeof(T)),
output_size_(sizeof(T)),
workspace_size_(sizeof(curandState)) {}
workspace_size_(sizeof(curandState)),
seed_(0),
seed2_(0) {}
~RandomOpGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
......@@ -57,12 +62,21 @@ class RandomOpGpuKernel : public GpuKernel {
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
case RANDOM_OP_UNIFORM_REAL: {
T *input_addr_1 = GetDeviceAddress<T>(inputs, 1);
T *input_addr_2 = GetDeviceAddress<T>(inputs, 2);
UniformReal(seed_, devStates, input_addr_1, inputs[1]->size / sizeof(T), input_addr_2,
inputs[2]->size / sizeof(T), output_addr, outputs[0]->size / sizeof(T),
reinterpret_cast<cudaStream_t>(stream_ptr));
break;
}
default: {
MS_LOG(EXCEPTION) << "Random operation " << random_op_type_ << " is not supported.";
}
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
auto iter = kRandomOpTypeMap.find(kernel_name);
......@@ -72,10 +86,14 @@ class RandomOpGpuKernel : public GpuKernel {
random_op_type_ = iter->second;
}
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 1) {
if (random_op_type_ == RANDOM_OP_NORMAL && input_num != 1) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 1 input.";
return false;
}
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL && input_num != 3) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but random op needs 3 inputs.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
if (output_num != 1) {
MS_LOG(ERROR) << "Output number is " << output_num << ", but random op needs 1 output.";
......@@ -86,13 +104,25 @@ class RandomOpGpuKernel : public GpuKernel {
input_size_0_ += input_shape_0[i];
}
input_size_0_ *= sizeof(int);
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
auto input_shape_1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
for (size_t i = 0; i < input_shape_1.size(); i++) {
input_size_1_ *= input_shape_1[i];
}
auto input_shape_2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 2);
for (size_t i = 0; i < input_shape_2.size(); i++) {
input_size_2_ *= input_shape_2[i];
}
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
workspace_size_ *= output_shape[i];
}
seed_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed"));
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
if (random_op_type_ == RANDOM_OP_NORMAL) {
seed2_ = GetValue<int>(AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("seed2"));
}
InitSizeLists();
return true;
}
......@@ -100,6 +130,10 @@ class RandomOpGpuKernel : public GpuKernel {
protected:
void InitSizeLists() override {
input_size_list_.push_back(input_size_0_);
if (random_op_type_ == RANDOM_OP_UNIFORM_REAL) {
input_size_list_.push_back(input_size_1_);
input_size_list_.push_back(input_size_2_);
}
output_size_list_.push_back(output_size_);
workspace_size_list_.push_back(workspace_size_);
}
......@@ -107,6 +141,8 @@ class RandomOpGpuKernel : public GpuKernel {
private:
RandomOptype random_op_type_;
size_t input_size_0_;
size_t input_size_1_;
size_t input_size_2_;
size_t output_size_;
size_t workspace_size_;
int seed_;
......
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
class Net(nn.Cell):
def __init__(self, shape, seed=0):
super(Net, self).__init__()
self.uniformreal = P.UniformReal(seed=seed)
self.shape = shape
def construct(self, a, b):
return self.uniformreal(self.shape, a, b)
def test_net_1D():
seed = 10
shape = (3, 2, 4)
a = 0.0
b = 1.0
net = Net(shape, seed)
ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
output = net(ta, tb)
print(output.asnumpy())
assert output.shape == (3, 2, 4)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册