From 972ae6e98ffbddac7b68242f946934b07b275e01 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 9 Apr 2018 14:27:19 +0800 Subject: [PATCH] random selected rows value --- paddle/fluid/operators/uniform_random_op.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index 87699362b2b..a50add9739d 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -24,7 +24,15 @@ template class CPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - auto* tensor = ctx.Output("Out"); + framework::Tensor* tensor(nullptr); + auto out_var = ctx.OutputVar("Out"); + if (out_var->IsType()) { + tensor = ctx.Output("Out"); + } else if (out_var->IsType()) { + tensor = ctx.Output("Out")->mutable_value(); + } else { + PADDLE_THROW("Only support LoDTensor and SelectedRows."); + } T* data = tensor->mutable_data(ctx.GetPlace()); unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; @@ -36,6 +44,7 @@ class CPUUniformRandomKernel : public framework::OpKernel { static_cast(ctx.Attr("min")), static_cast(ctx.Attr("max"))); int64_t size = tensor->numel(); + VLOG(3) << "size = " << size; for (int64_t i = 0; i < size; ++i) { data[i] = dist(engine); } @@ -55,6 +64,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { "uniform_random's min must less then max"); auto& shape = ctx->Attrs().Get>("shape"); std::vector temp; + VLOG(3) << "shape.size() = " << shape.size(); temp.reserve(shape.size()); for (auto dim : shape) { temp.push_back(static_cast(dim)); -- GitLab