From 470fb7c5c39ad0f84baf15de67f618e8826b6d79 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 15 Aug 2018 18:15:16 +0800 Subject: [PATCH] bug fix --- paddle/fluid/operators/sampling_id_op.cc | 6 +++--- paddle/fluid/operators/sampling_id_op.cu | 7 ++++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index 2549758a8e..e88310745f 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -36,15 +36,15 @@ class SamplingIdKernel : public framework::OpKernel { std::vector ins_vector; framework::TensorToVector(*input, context.device_context(), &ins_vector); - unsigned int seed = static_cast(ctx.Attr("seed")); + unsigned int seed = static_cast(context.Attr("seed")); std::minstd_rand engine; if (seed == 0) { seed = std::random_device()(); } engine.seed(seed); std::uniform_real_distribution dist( - static_cast(ctx.Attr("min")), - static_cast(ctx.Attr("max"))); + static_cast(context.Attr("min")), + static_cast(context.Attr("max"))); std::vector ids(batch_size); for (size_t i = 0; i < batch_size; ++i) { diff --git a/paddle/fluid/operators/sampling_id_op.cu b/paddle/fluid/operators/sampling_id_op.cu index 791675b73b..b104710374 100644 --- a/paddle/fluid/operators/sampling_id_op.cu +++ b/paddle/fluid/operators/sampling_id_op.cu @@ -39,7 +39,7 @@ namespace operators { using Tensor = framework::Tensor; template -class SamplingIdKernel : public framework::OpKernel { +class SamplingIdGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { const Tensor* input = context.Input("X"); @@ -83,5 +83,6 @@ class SamplingIdKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CPU_KERNEL(sampling_id, paddle::operators::SamplingIdKernel, - paddle::operators::SamplingIdKernel); +REGISTER_OP_CUDA_KERNEL(sampling_id, + paddle::operators::SamplingIdGPUKernel, + paddle::operators::SamplingIdGPUKernel); -- GitLab