From 6bac3e17b5b1f9e6a0ebb34ff43e959a971ef111 Mon Sep 17 00:00:00 2001 From: dongzhihong Date: Wed, 9 Aug 2017 15:01:37 +0800 Subject: [PATCH] "remove unused test net modified" --- paddle/operators/gaussian_random_op.cc | 7 ++++--- paddle/operators/gaussian_random_op.cu | 5 ++--- .../v2/framework/tests/test_gaussian_random_op.py | 4 +++- python/paddle/v2/framework/tests/test_net.py | 12 ++++++------ 4 files changed, 15 insertions(+), 13 deletions(-) diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index b0b68ff36d3..ef417ae2f06 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -22,8 +22,8 @@ template class GaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - T mean = static_cast(context.op_.GetAttr("mean")); - T std = static_cast(context.op_.GetAttr("std")); + float mean = context.op_.GetAttr("mean"); + float std = context.op_.GetAttr("std"); auto* tensor = context.Output(0); T* data = tensor->mutable_data(context.GetPlace()); @@ -35,7 +35,8 @@ class GaussianRandomKernel : public framework::OpKernel { } std::mt19937 g(seed); std::normal_distribution distribution(mean, std); - for (int i = 0; i < framework::product(tensor->dims()); ++i) { + ssize_t size = framework::product(tensor->dims()); + for (int i = 0; i < size; ++i) { data[i] = distribution(g); } } diff --git a/paddle/operators/gaussian_random_op.cu b/paddle/operators/gaussian_random_op.cu index 164753f946d..54e4ae5d2b2 100644 --- a/paddle/operators/gaussian_random_op.cu +++ b/paddle/operators/gaussian_random_op.cu @@ -26,8 +26,8 @@ template class GaussianRandomKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - T mean = static_cast(context.op_.GetAttr("mean")); - T std = static_cast(context.op_.GetAttr("std")); + float mean = context.op_.GetAttr("mean"); + float std = context.op_.GetAttr("std"); auto* tensor = context.Output(0); T* data = tensor->mutable_data(context.GetPlace()); @@ -40,7 +40,6 @@ class GaussianRandomKernel : public framework::OpKernel { &g, CURAND_RNG_PSEUDO_DEFAULT)); PADDLE_ENFORCE( platform::dynload::curandSetPseudoRandomGeneratorSeed(g, seed)); - // auto g = const_cast(ctx)->RandGenerator(); curandGenerateNormal(g, data, framework::product(tensor->dims()), mean, std); } diff --git a/python/paddle/v2/framework/tests/test_gaussian_random_op.py b/python/paddle/v2/framework/tests/test_gaussian_random_op.py index 0ff8c89a14b..20c68007b5c 100644 --- a/python/paddle/v2/framework/tests/test_gaussian_random_op.py +++ b/python/paddle/v2/framework/tests/test_gaussian_random_op.py @@ -14,13 +14,15 @@ class GaussianRandomTest(unittest.TestCase): def test_gaussian_random(self, place): scope = core.Scope() scope.new_var("Out").get_tensor() + op = Operator( "gaussian_random", Out="Out", dims=[1000, 784], mean=.0, std=1., - seed=0) + seed=10) + op.infer_shape(scope) context = core.DeviceContext.create(place) op.run(scope, context) diff --git a/python/paddle/v2/framework/tests/test_net.py b/python/paddle/v2/framework/tests/test_net.py index 7df9b997b19..b30896553de 100644 --- a/python/paddle/v2/framework/tests/test_net.py +++ b/python/paddle/v2/framework/tests/test_net.py @@ -16,13 +16,13 @@ class TestNet(unittest.TestCase): net.complete_add_op(True) expected = ''' - Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out). - Op(add_two), inputs:(X, Y), outputs:(Out). - Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out). +Op(plain_net), inputs:(@EMPTY@, X, Y, w), outputs:(@TEMP@fc@0, Out, fc.out). + Op(add_two), inputs:(X, Y), outputs:(Out). + Op(plain_net), inputs:(@EMPTY@, X, w), outputs:(@TEMP@fc@0, fc.out). Op(fc), inputs:(X, w, @EMPTY@), outputs:(fc.out, @TEMP@fc@0). - Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0). - Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out). - ''' + Op(mul), inputs:(X, w), outputs:(@TEMP@fc@0). + Op(sigmoid), inputs:(@TEMP@fc@0), outputs:(fc.out). +''' self.assertEqual(expected, "\n" + str(net)) -- GitLab