// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/operators/uniform_random_op.h" #include "paddle/fluid/platform/enforce.h" namespace paddle { namespace operators { template class CPURandintKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { std::vector new_shape; auto list_new_shape_tensor = ctx.MultiInput("ShapeTensorList"); if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) { if (ctx.HasInput("ShapeTensor")) { auto* shape_tensor = ctx.Input("ShapeTensor"); new_shape = GetNewDataFromShapeTensor(shape_tensor); } else if (list_new_shape_tensor.size() > 0) { new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor); } } auto* out = ctx.Output("Out"); if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape)); T* data = out->mutable_data(ctx.GetPlace()); int64_t size = out->numel(); std::random_device rd; std::mt19937 gen(rd()); std::uniform_int_distribution<> dist(ctx.Attr("low"), ctx.Attr("high") - 1); for (int64_t i = 0; i < size; ++i) data[i] = dist(gen); } }; class RandintOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE_EQ( ctx->HasOutput("Out"), true, platform::errors::InvalidArgument("Output(Out) of RandintOp is null.")); PADDLE_ENFORCE_LT( ctx->Attrs().Get("low"), ctx->Attrs().Get("high"), platform::errors::InvalidArgument("randint's low must less then high, " "but received: low = %d, high = %d.", ctx->Attrs().Get("low"), ctx->Attrs().Get("high"))); if (ctx->HasInputs("ShapeTensorList")) { // top prority shape auto inputs_name = ctx->Inputs("ShapeTensorList"); PADDLE_ENFORCE_GT( inputs_name.size(), 0, platform::errors::InvalidArgument( "Input(ShapeTensorList)'size of Op(randint) can't be zero." "Please check the Attr(shape)'s size of" "Op(fluid.layers.randint).)")); auto out_dims = std::vector(inputs_name.size(), -1); ctx->SetOutputDim("Out", framework::make_ddim(out_dims)); return; } auto& shape = ctx->Attrs().Get>("shape"); if (ctx->HasInput("ShapeTensor") && shape.empty()) { auto shape_dims = ctx->GetInputDim("ShapeTensor"); PADDLE_ENFORCE_EQ(shape_dims.size(), 1, platform::errors::InvalidArgument( "ShapeError: Input(ShapeTensor)' dimension size of " "Op(randint) must be 1." "But received ShapeTensor's dimensions = %d.", shape_dims.size())); int num_ele = 1; for (int i = 0; i < shape_dims.size(); ++i) { num_ele *= shape_dims[i]; } auto vec_dims = std::vector(num_ele, -1); auto out_dims = framework::make_ddim(vec_dims); ctx->SetOutputDim("Out", out_dims); return; } PADDLE_ENFORCE_EQ(shape.empty(), false, platform::errors::InvalidArgument( "if there is no Input(ShapeTensorList) and no " "Input(ShapeTensor),the " "attr(shape) information must " "be set by Attr(shape).")); std::vector tensor_shape; tensor_shape.reserve(shape.size()); for (auto dim : shape) { tensor_shape.push_back(static_cast(dim)); } ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape)); } protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { return framework::OpKernelType( static_cast(ctx.Attr("dtype")), ctx.GetPlace()); } }; class RandintOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("ShapeTensor", "(Tensor or Tensor, optional) . If provided, " "randint" "according to " "this given shape. It means that it has a higher priority than " "Attr(shape) but a lower priority than Input(ShapeTensor).") .AsDispensable(); AddInput("ShapeTensorList", "(vector> or vector>, optional). " "If provided, randint use this. The shape of the tensor " "must be [1], it has the highest priority comparing with " "Input(ShapeTensor) and attr(shape).") .AsDuplicable() .AsDispensable(); AddOutput("Out", "The output tensor of randint op"); AddComment(R"DOC( This operator initializes a tensor with random integers sampled from a uniform distribution. The random result is in set [low, high). )DOC"); AddAttr>("shape", "The shape of the output tensor.") .SetDefault({}); AddAttr("low", "The lower bound on the range of random values to generate."); AddAttr("high", "The upper bound on the range of random values to generate."); AddAttr("dtype", "Output tensor data type. [Default INT64].") .SetDefault(framework::proto::VarType::INT64); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR( randint, ops::RandintOp, ops::RandintOpMaker, paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker) REGISTER_OP_CPU_KERNEL(randint, ops::CPURandintKernel, ops::CPURandintKernel)