From 2b61db07d1300165e28596a8595bfef56265f37a Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Wed, 17 Apr 2019 14:56:03 +0800 Subject: [PATCH] fix sampling id op bug (#16909) * fix sampling id op bug, test=develop --- paddle/fluid/operators/sampling_id_op.cc | 8 +++---- .../tests/unittests/test_sampling_id_op.py | 22 +++++++++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index a4f41a17042..36712a8d06d 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -28,15 +28,15 @@ class SamplingIdOp : public framework::OperatorWithKernel { "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SamplingIdOp should not be null."); - PADDLE_ENFORCE( - ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), - "min must less then max"); + PADDLE_ENFORCE_LT(ctx->Attrs().Get("min"), + ctx->Attrs().Get("max"), "min must less then max"); auto input_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE(input_dims.size() == 2, "Input(X, Filter) should be 2-D tensor."); - framework::DDim dims = input_dims; + auto dim0 = input_dims[0]; + framework::DDim dims = framework::make_ddim({dim0}); ctx->SetOutputDim("Out", dims); ctx->ShareLoD("X", "Out"); } diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 674ef2ddf44..0c784d3e49d 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -17,6 +17,7 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core +import paddle.fluid as fluid from paddle.fluid.op import Operator @@ -57,5 +58,26 @@ class TestSamplingIdOp(OpTest): pass +class TestSamplingIdShape(unittest.TestCase): + def test_shape(self): + x = fluid.layers.data(name='x', shape=[3], dtype='float32') + output = fluid.layers.sampling_id(x) + + place = fluid.CPUPlace() + exe = fluid.Executor(place=place) + exe.run(fluid.default_startup_program()) + + feed = { + 'x': np.array( + [[0.2, 0.3, 0.5], [0.2, 0.3, 0.4]], dtype='float32') + } + output_np = exe.run(feed=feed, fetch_list=[output])[0] + + self.assertEqual(output.shape[0], -1) + self.assertEqual(len(output.shape), 1) + self.assertEqual(output_np.shape[0], 2) + self.assertEqual(len(output_np.shape), 1) + + if __name__ == "__main__": unittest.main() -- GitLab