diff --git a/paddle/fluid/operators/sampling_id_op.cc b/paddle/fluid/operators/sampling_id_op.cc index a4f41a170426a4650fd3bf8f7fec4758ff34e1b9..36712a8d06d3e9a6f582f8296e2c0c4b4b302eb1 100644 --- a/paddle/fluid/operators/sampling_id_op.cc +++ b/paddle/fluid/operators/sampling_id_op.cc @@ -28,15 +28,15 @@ class SamplingIdOp : public framework::OperatorWithKernel { "Input(X) of SamplingIdOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SamplingIdOp should not be null."); - PADDLE_ENFORCE( - ctx->Attrs().Get("min") < ctx->Attrs().Get("max"), - "min must less then max"); + PADDLE_ENFORCE_LT(ctx->Attrs().Get("min"), + ctx->Attrs().Get("max"), "min must less then max"); auto input_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE(input_dims.size() == 2, "Input(X, Filter) should be 2-D tensor."); - framework::DDim dims = input_dims; + auto dim0 = input_dims[0]; + framework::DDim dims = framework::make_ddim({dim0}); ctx->SetOutputDim("Out", dims); ctx->ShareLoD("X", "Out"); } diff --git a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py index 674ef2ddf44edb4246c9d952cb75b36fe3d6ddc8..0c784d3e49d85f0b5750c5e6d7307be754b43ab2 100644 --- a/python/paddle/fluid/tests/unittests/test_sampling_id_op.py +++ b/python/paddle/fluid/tests/unittests/test_sampling_id_op.py @@ -17,6 +17,7 @@ import numpy as np from op_test import OpTest import paddle.fluid.core as core +import paddle.fluid as fluid from paddle.fluid.op import Operator @@ -57,5 +58,26 @@ class TestSamplingIdOp(OpTest): pass +class TestSamplingIdShape(unittest.TestCase): + def test_shape(self): + x = fluid.layers.data(name='x', shape=[3], dtype='float32') + output = fluid.layers.sampling_id(x) + + place = fluid.CPUPlace() + exe = fluid.Executor(place=place) + exe.run(fluid.default_startup_program()) + + feed = { + 'x': np.array( + [[0.2, 0.3, 0.5], [0.2, 0.3, 0.4]], dtype='float32') + } + output_np = exe.run(feed=feed, fetch_list=[output])[0] + + self.assertEqual(output.shape[0], -1) + self.assertEqual(len(output.shape), 1) + self.assertEqual(output_np.shape[0], 2) + self.assertEqual(len(output_np.shape), 1) + + if __name__ == "__main__": unittest.main()