From cb7bbf426c1be2d4a0989855f6440b0b8313f6b0 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Tue, 10 Apr 2018 14:28:35 +0800 Subject: [PATCH] revert uniform_random_op --- paddle/fluid/operators/uniform_random_op.cc | 13 +----- paddle/fluid/operators/uniform_random_op.cu | 13 +----- .../tests/unittests/test_uniform_random_op.py | 46 ++----------------- 3 files changed, 7 insertions(+), 65 deletions(-) diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc index d8b38fb7eb..87699362b2 100644 --- a/paddle/fluid/operators/uniform_random_op.cc +++ b/paddle/fluid/operators/uniform_random_op.cc @@ -24,19 +24,8 @@ template class CPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { - framework::Tensor* tensor(nullptr); - auto out_var = ctx.OutputVar("Out"); - if (out_var->IsType()) { - tensor = ctx.Output("Out"); - } else if (out_var->IsType()) { - auto shape = ctx.Attr>("shape"); - tensor = ctx.Output("Out")->mutable_value(); - tensor->Resize(framework::make_ddim(shape)); - } else { - PADDLE_THROW("Only support LoDTensor and SelectedRows."); - } + auto* tensor = ctx.Output("Out"); T* data = tensor->mutable_data(ctx.GetPlace()); - data[0] = static_cast(1000); unsigned int seed = static_cast(ctx.Attr("seed")); std::minstd_rand engine; if (seed == 0) { diff --git a/paddle/fluid/operators/uniform_random_op.cu b/paddle/fluid/operators/uniform_random_op.cu index 115c859527..1232cd1eb3 100644 --- a/paddle/fluid/operators/uniform_random_op.cu +++ b/paddle/fluid/operators/uniform_random_op.cu @@ -43,18 +43,7 @@ template class GPUUniformRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - framework::Tensor* tensor(nullptr); - auto out_var = ctx.OutputVar("Out"); - if (out_var->IsType()) { - tensor = ctx.Output("Out"); - } else if (out_var->IsType()) { - auto shape = ctx.Attr>("shape"); - tensor = ctx.Output("Out")->mutable_value(); - tensor->Resize(framework::make_ddim(shape)); - } else { - PADDLE_THROW("Only support LoDTensor and SelectedRows."); - } - + auto* tensor = context.Output("Out"); T* data = tensor->mutable_data(context.GetPlace()); unsigned int seed = static_cast(context.Attr("seed")); if (seed == 0) { diff --git a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py index 3331e99c36..75ff85a55f 100644 --- a/python/paddle/fluid/tests/unittests/test_uniform_random_op.py +++ b/python/paddle/fluid/tests/unittests/test_uniform_random_op.py @@ -15,16 +15,6 @@ import unittest import numpy as np from op_test import OpTest -import paddle.fluid.core as core -from paddle.fluid.op import Operator - - -def output_hist(out): - hist, _ = np.histogram(out, range=(-5, 10)) - hist = hist.astype("float32") - hist /= float(out.size) - prob = 0.1 * np.ones((10)) - return hist, prob class TestUniformRandomOp(OpTest): @@ -43,37 +33,11 @@ class TestUniformRandomOp(OpTest): self.check_output_customized(self.verify_output) def verify_output(self, outs): - hist, prob = output_hist(outs[0]) - self.assertTrue( - np.allclose( - hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) - - -class TestUniformRandomOpSelectedRows(unittest.TestCase): - def get_places(self): - places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) - return places - - def test_check_output(self): - for place in self.get_places(): - self.check_with_place(place) - - def check_with_place(self, place): - scope = core.Scope() - out = scope.var("X").get_selected_rows() - - op = Operator( - "uniform_random", - Out="X", - shape=[1000, 784], - min=-5.0, - max=10.0, - seed=10) - op.run(scope, place) - out_tensor = out.get_tensor() - hist, prob = output_hist(np.array(out_tensor)) + tensor = outs[0] + hist, _ = np.histogram(outs[0], range=(-5, 10)) + hist = hist.astype("float32") + hist /= float(outs[0].size) + prob = 0.1 * np.ones((10)) self.assertTrue( np.allclose( hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) -- GitLab