diff --git a/paddle/fluid/operators/controlflow/compare_op.cc b/paddle/fluid/operators/controlflow/compare_op.cc index 60f29ba39a8ee64f9fe5d95e685cac1fb52dfd21..4940649c2a32649a068c364081071ac840b4e25a 100644 --- a/paddle/fluid/operators/controlflow/compare_op.cc +++ b/paddle/fluid/operators/controlflow/compare_op.cc @@ -111,8 +111,16 @@ class CompareOp : public framework::OperatorWithKernel { framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); // CompareOp kernel's device type is decided by input tensor place bool force_cpu = ctx.Attr("force_cpu"); - kt.place_ = force_cpu ? platform::CPUPlace() - : ctx.Input("X")->place(); + if (force_cpu) { + kt.place_ = platform::CPUPlace(); + } else { + if (ctx.Input("X")->place().type() != + typeid(platform::CUDAPinnedPlace)) { + kt.place_ = ctx.Input("X")->place(); + } else { + kt.place_ = ctx.GetPlace(); + } + } return kt; } }; diff --git a/python/paddle/fluid/tests/unittests/test_compare_op.py b/python/paddle/fluid/tests/unittests/test_compare_op.py index cfad50409802d4f3d35c9da3b22597c681da91b1..25ae65aa7c968b2e6f1f9429d1a4e4e618fe7033 100644 --- a/python/paddle/fluid/tests/unittests/test_compare_op.py +++ b/python/paddle/fluid/tests/unittests/test_compare_op.py @@ -38,6 +38,7 @@ def create_test_class(op_type, typename, callback): self.check_output() def test_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[2], dtype='int32') y = fluid.layers.data(name='y', shape=[2], dtype='int32') @@ -80,6 +81,7 @@ def create_paddle_case(op_type, callback): self.place = paddle.CUDAPlace(0) def test_api(self): + paddle.enable_static() with program_guard(Program(), Program()): x = fluid.data(name='x', shape=[4], dtype='int64') y = fluid.data(name='y', shape=[4], dtype='int64') @@ -92,6 +94,7 @@ def create_paddle_case(op_type, callback): self.assertEqual((res == self.real_result).all(), True) def test_broadcast_api_1(self): + paddle.enable_static() with program_guard(Program(), Program()): x = paddle.static.data( name='x', shape=[1, 2, 1, 3], dtype='int32') @@ -108,6 +111,7 @@ def create_paddle_case(op_type, callback): self.assertEqual((res == real_result).all(), True) def test_attr_name(self): + paddle.enable_static() with program_guard(Program(), Program()): x = fluid.layers.data(name='x', shape=[4], dtype='int32') y = fluid.layers.data(name='y', shape=[4], dtype='int32') @@ -130,6 +134,7 @@ create_paddle_case('not_equal', lambda _a, _b: _a != _b) class TestCompareOpError(unittest.TestCase): def test_errors(self): + paddle.enable_static() with program_guard(Program(), Program()): # The input x and y of compare_op must be Variable. x = fluid.layers.data(name='x', shape=[1], dtype="float32") @@ -140,6 +145,7 @@ class TestCompareOpError(unittest.TestCase): class API_TestElementwise_Equal(unittest.TestCase): def test_api(self): + paddle.enable_static() with fluid.program_guard(fluid.Program(), fluid.Program()): label = fluid.layers.assign(np.array([3, 3], dtype="int32")) limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) @@ -159,5 +165,31 @@ class API_TestElementwise_Equal(unittest.TestCase): self.assertEqual((res == np.array([True, True])).all(), True) +class TestCompareOpPlace(unittest.TestCase): + def test_place_1(self): + paddle.enable_static() + place = paddle.CPUPlace() + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + label = fluid.layers.assign(np.array([3, 3], dtype="int32")) + limit = fluid.layers.assign(np.array([3, 2], dtype="int32")) + out = fluid.layers.less_than(label, limit, force_cpu=True) + exe = fluid.Executor(place) + res, = exe.run(fetch_list=[out]) + self.assertEqual((res == np.array([False, False])).all(), True) + + def test_place_2(self): + place = paddle.CPUPlace() + data_place = place + if core.is_compiled_with_cuda(): + place = paddle.CUDAPlace(0) + data_place = paddle.CUDAPinnedPlace() + paddle.disable_static(place) + data = np.array([9], dtype="int64") + data_tensor = paddle.to_tensor(data, place=data_place) + result = data_tensor == 0 + self.assertEqual((result.numpy() == np.array([False])).all(), True) + + if __name__ == '__main__': unittest.main()