From e0ab2f71589a71e918a94dd307d18f9a54864199 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Sun, 5 Aug 2018 15:21:32 +0800 Subject: [PATCH] new sampling op --- paddle/fluid/operators/sampling_id_op.cc | 64 ++++++++++++++++++++++ paddle/fluid/operators/sampling_id_op.cu | 40 ++++++++++++++ paddle/fluid/operators/sampling_id_op.h | 68 ++++++++++++++++++++++++ 3 files changed, 172 insertions(+) create mode 100644 paddle/fluid/operators/sampling_id_op.cc create mode 100644 paddle/fluid/operators/sampling_id_op.cu create mode 100644 paddle/fluid/operators/sampling_id_op.h diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc new file mode 100644 index 000000000..20e3d4321 --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -0,0 +1,64 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/sampling_id_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class SamplingIdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of RowConvOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of RowConvOp should not be null."); + + auto input_dims = ctx->GetInputDim("X"); + + framework::DDim dims = input_dims; + ctx->SetOutputDim("Out", dims); + ctx->ShareLoD("X", "Out"); + } +}; + +class SamplingIdOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "The input tensor of softmax. " + "2-D with shape [batch_size, input_feature_dimensions]."); + AddOutput("Out", "Sliced data tensor."); + + AddComment(R"DOC( +SamplingId Operator. + @brief A layer for sampling id from multinomial distribution from the + input layer. Sampling one id for one sample. The result is stored in + output_.ids. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + slice, ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu new file mode 100644 index 000000000..4fa10de2c --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -0,0 +1,40 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/operators/sampling_id_op.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +class SamplingIdOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext *ctx) const override {} +} +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(samplingid, ops::SamplingIdOp, ops::SamplingIdOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL( + slice, ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel, + ops::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h new file mode 100644 index 000000000..eeb72d8f7 --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.h @@ -0,0 +1,68 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class SamplingIdKernel : public framework::OpKernel { + /// Produces random floating-point values, uniformly distributed on [0, 1). + std::uniform_real_distribution rand1_; + + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + const int width = static_cast(input->dims()[1]); + + std::vector ids(batchSize); + auto& reng = get(); + + for (size_t i = 0; i < batchSize; ++i) { + double r = rand1_(reng); + int id = dim - 1; + for (int j = 0; j < dim; ++j) { + if ((r -= buf[i * dim + j]) < 0) { + id = j; + break; + } + } + ids[i] = id; + } + + std::vector out_dim; + out_dim.push_back(static_cast(batch_size)); + + Tensor* output = context.Output("Output"); + output->Resize(framework::make_ddim(in_dim)); + output->mutable_data(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } + + std::default_random_engine& get() { + auto engine = new std::default_random_engine; + engine->seed(defaultSeed); + return *engine; + } + + private: + unsigned int defaultSeed = 0; +} +} // namespace operators +} // namespace paddle -- GitLab