diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index d13eeabcb96fcf8c4073184bc078186eb6b5f089..4929a7edc21a020ec688a6a00705712dd41ab477 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -12,18 +12,68 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/sampling_id_op.h" +#include +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; +template +class SamplingIdKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + const int width = static_cast(input->dims()[1]); + + std::vector ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); + + std::vector ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + double r = getRandReal(); + int idx = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { + idx = j; + break; + } + } + ids[i] = ins_vector[i * width + idx]; + } + + std::vector out_dim; + out_dim.push_back(static_cast(batch_size)); + + Tensor* output = context.Output("Out"); + output->Resize(framework::make_ddim(out_dim)); + output->mutable_data(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } + + private: + double getRandReal() const { + std::random_device + rd; // Will be used to obtain a seed for the random number engine + std::mt19937 gen(rd()); // Standard mersenne_twister_engine seeded with + // rd() + std::uniform_real_distribution<> dis(1.0, 2.0); + return dis(gen); + } +}; + class SamplingIdOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext *ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h deleted file mode 100644 index 7f3ca8e761cc252b21a83e9be57d0639801734d7..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/sampling_id_op.h +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/lod_tensor.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class SamplingIdKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const int batch_size = static_cast(input->dims()[0]); - const int width = static_cast(input->dims()[1]); - - std::vector ins_vector; - framework::TensorToVector(*input, context.device_context(), &ins_vector); - - std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - double r = this->getRandReal(); - int idx = width - 1; - for (int j = 0; j < width; ++j) { - if ((r -= ins_vector[i * width + j]) < 0) { - idx = j; - break; - } - } - ids[i] = ins_vector[i * width + idx]; - } - - std::vector out_dim; - out_dim.push_back(static_cast(batch_size)); - - Tensor* output = context.Output("Out"); - output->Resize(framework::make_ddim(out_dim)); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(ids, context.device_context(), output); - } - - private: - double getRandReal() const { - std::call_once(init_flag_, &SamplingIdKernel::getRndInstance); - return rnd(); - } - - static void getRndInstance() { - // Will be used to obtain a seed for the random number engine - std::random_device rd; - // Standard mersenne_twister_engine seeded with rd() - std::mt19937 gen(rd()); - std::uniform_real_distribution<> dis(0, 1); - rnd = std::bind(dis, std::ref(gen)); - } - - static std::once_flag init_flag_; - static std::function rnd; -}; -} // namespace operators -} // namespace paddle