// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <random>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"

namespace paddle {
namespace operators {

template <typename T>
class GPURandintKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    // Resolve the output shape, which may come from the "ShapeTensor" input,
    // the "ShapeTensorList" input, or the static shape already set on "Out".
    std::vector<int64_t> new_shape;
    auto list_new_shape_tensor =
        context.MultiInput<framework::Tensor>("ShapeTensorList");
    if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) {
      if (context.HasInput("ShapeTensor")) {
        auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor");
        new_shape = GetNewDataFromShapeTensor<int64_t>(shape_tensor);
      } else if (list_new_shape_tensor.size() > 0) {
        new_shape =
            GetNewDataFromShapeTensorList<int64_t>(list_new_shape_tensor);
      }
    }

    platform::CPUPlace cpu;
    auto dtype = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("dtype"));
    auto* out = context.Output<framework::LoDTensor>("Out");
    if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
    T low = static_cast<T>(context.Attr<int>("low"));
    T high = static_cast<T>(context.Attr<int>("high")) - 1;

    // Generate the random integers into a temporary CPU tensor first.
    framework::LoDTensor tensor;
    tensor.Resize(out->dims());
    tensor.mutable_data(cpu, dtype);
    T* data = tensor.mutable_data<T>(cpu);

    int64_t size = out->numel();
    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    std::minstd_rand engine;
    if (seed == 0) {
      std::random_device rd;
      seed = rd();
    }
    engine.seed(seed);
    // Sample uniformly from the closed interval [low, high - 1],
    // i.e. the half-open interval [low, high).
    std::uniform_int_distribution<> dist(context.Attr<int>("low"),
                                         context.Attr<int>("high") - 1);

    for (int64_t i = 0; i < size; ++i) data[i] = dist(engine);

    if (platform::is_gpu_place(context.GetPlace())) {
      // Copy tensor to out
      framework::TensorCopy(tensor, context.GetPlace(), out);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(randint, ops::GPURandintKernel<int>,
                        ops::GPURandintKernel<int64_t>)