From 9ee77b1f4138835dc13bb1bd3945320b5d5850ad Mon Sep 17 00:00:00 2001 From: ShenLiang Date: Thu, 17 Sep 2020 14:40:55 +0800 Subject: [PATCH] Fix elementwise_floordiv op (#27352) * fix floordiv --- .../operators/elementwise/elementwise_floordiv_op.h | 11 +++++++++-- .../tests/unittests/test_elementwise_floordiv_op.py | 7 +++++++ 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h index 5dc93740949..721c23e3830 100644 --- a/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h +++ b/paddle/fluid/operators/elementwise/elementwise_floordiv_op.h @@ -61,8 +61,15 @@ void elementwise_floor_div(const framework::ExecutionContext &ctx, const framework::Tensor *x, const framework::Tensor *y, framework::Tensor *z) { int axis = ctx.Attr("axis"); - ElementwiseComputeEx, DeviceContext, T>( - ctx, x, y, axis, FloorDivFunctor(), z); + auto x_dims = x->dims(); + auto y_dims = y->dims(); + if (x_dims.size() >= y_dims.size()) { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, FloorDivFunctor(), z); + } else { + ElementwiseComputeEx, DeviceContext, T>( + ctx, x, y, axis, InverseFloorDivFunctor(), z); + } } template diff --git a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py index f339081e31b..007affc1408 100644 --- a/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py +++ b/python/paddle/fluid/tests/unittests/test_elementwise_floordiv_op.py @@ -67,6 +67,13 @@ class TestElementwiseModOp_scalar(TestElementwiseModOp): self.out = np.floor_divide(self.x, self.y) +class TestElementwiseModOpInverse(TestElementwiseModOp): + def init_input_output(self): + self.x = np.random.uniform(0, 10000, [10]).astype(self.dtype) + self.y = np.random.uniform(0, 1000, [10, 10]).astype(self.dtype) + self.out = np.floor_divide(self.x, self.y) + + class TestFloorDivideOp(unittest.TestCase): def test_name(self): with fluid.program_guard(fluid.Program()): -- GitLab