diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index ca7b2469010d57fc6bcb8b6cea8149fdbb091e58..724463c95c4a29fb5c00fe791b389d3908771640 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -12,67 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include -#include -#include -#include -#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/sampling_id_op.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; -template -class SamplingIdKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const int batch_size = static_cast(input->dims()[0]); - const int width = static_cast(input->dims()[1]); - - PADDLE_ENFORCE_GE(batch_size, 0, - "batch_size(dims[0]) must be nonnegative."); - PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); - - std::vector ins_vector; - framework::TensorToVector(*input, context.device_context(), &ins_vector); - - unsigned int seed = static_cast(context.Attr("seed")); - std::minstd_rand engine; - if (seed == 0) { - seed = std::random_device()(); - } - engine.seed(seed); - std::uniform_real_distribution dist( - static_cast(context.Attr("min")), - static_cast(context.Attr("max"))); - - std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - T r = dist(engine); - int idx = width - 1; - for (int j = 0; j < width; ++j) { - if ((r -= ins_vector[i * width + j]) < 0) { - idx = j; - break; - } - } - ids[i] = ins_vector[i * width + idx]; - } - - std::vector out_dim; - out_dim.push_back(static_cast(batch_size)); - - Tensor* output = context.Output("Out"); - output->Resize(framework::make_ddim(out_dim)); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(ids, context.device_context(), output); - } -}; - class SamplingIdOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index 114df044afcf2bef971ddd294aee7b2f4779aec4..a4f0470314d00b5e370fd478736b54579c88448c 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -11,83 +11,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" -template -struct UniformGenerator { - T min_, max_; - unsigned int seed_; +#include "paddle/fluid/operators/sampling_id_op.h" - __host__ __device__ UniformGenerator(T min, T max, int seed) - : min_(min), max_(max), seed_(seed) {} - - __host__ __device__ T operator()(const unsigned int n) const { - thrust::minstd_rand rng; - rng.seed(seed_); - thrust::uniform_real_distribution dist(min_, max_); - rng.discard(n); - return dist(rng); - } -}; - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class SamplingIdGPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - const Tensor* input = context.Input("X"); - const int batch_size = static_cast(input->dims()[0]); - const int width = static_cast(input->dims()[1]); - - PADDLE_ENFORCE_GE(batch_size, 0, - "batch_size(dims[0]) must be nonnegative."); - PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); - - std::vector ins_vector; - framework::TensorToVector(*input, context.device_context(), &ins_vector); - - unsigned int seed = static_cast(context.Attr("seed")); - if (seed == 0) { - std::random_device rd; - seed = rd(); - } - T min = static_cast(context.Attr("min")); - T max = static_cast(context.Attr("max")); - UniformGenerator gen = UniformGenerator(min, max, seed); - - std::vector ids(batch_size); - for (size_t i = 0; i < batch_size; ++i) { - T r = gen(0); - int idx = width - 1; - for (int j = 0; j < width; ++j) { - if ((r -= ins_vector[i * width + j]) < 0) { - idx = j; - break; - } - } - ids[i] = ins_vector[i * width + idx]; - } - - std::vector out_dim; - out_dim.push_back(static_cast(batch_size)); - - Tensor* output = context.Output("Out"); - output->Resize(framework::make_ddim(out_dim)); - output->mutable_data(context.GetPlace()); - framework::TensorFromVector(ids, context.device_context(), output); - } -}; - -} // namespace operators -} // namespace paddle - -REGISTER_OP_CUDA_KERNEL(sampling_id, - paddle::operators::SamplingIdGPUKernel, - paddle::operators::SamplingIdGPUKernel); +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(sampling_id, paddle::operators::SamplingIdKernel, + paddle::operators::SamplingIdKernel); diff --git a/paddle/fluid/operators/sampling_id_op.h b/paddle/fluid/operators/sampling_id_op.h new file mode 100644 index 0000000000000000000000000000000000000000..f730a9746da56ca82090122193ec54efb774483e --- /dev/null +++ b/paddle/fluid/operators/sampling_id_op.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class SamplingIdKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* input = context.Input("X"); + const int batch_size = static_cast(input->dims()[0]); + const int width = static_cast(input->dims()[1]); + + PADDLE_ENFORCE_GE(batch_size, 0, + "batch_size(dims[0]) must be nonnegative."); + PADDLE_ENFORCE_GE(width, 0, "width(dims[1]) must be nonnegative."); + + std::vector ins_vector; + framework::TensorToVector(*input, context.device_context(), &ins_vector); + + unsigned int seed = static_cast(context.Attr("seed")); + std::minstd_rand engine; + if (seed == 0) { + seed = std::random_device()(); + } + engine.seed(seed); + std::uniform_real_distribution dist( + static_cast(context.Attr("min")), + static_cast(context.Attr("max"))); + + std::vector ids(batch_size); + for (size_t i = 0; i < batch_size; ++i) { + T r = dist(engine); + int idx = width - 1; + for (int j = 0; j < width; ++j) { + if ((r -= ins_vector[i * width + j]) < 0) { + idx = j; + break; + } + } + ids[i] = ins_vector[i * width + idx]; + } + + std::vector out_dim; + out_dim.push_back(static_cast(batch_size)); + + Tensor* output = context.Output("Out"); + output->Resize(framework::make_ddim(out_dim)); + output->mutable_data(context.GetPlace()); + framework::TensorFromVector(ids, context.device_context(), output); + } +}; + +} // namespace operators +} // namespace paddle