diff --git a/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc b/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
index f671d1b62a5692ddf5b662f443b0707435a9fb93..5c14849dd01c3635688879daa1721b29b1195ad9 100644
--- a/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
+++ b/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
@@ -44,7 +44,8 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
                     feed_list.end();
     if (is_feed) {
       auto input_shape = node->Var()->GetShape();
-      if (input_shape[0] <= -1) {
+      // NOTE: some tensors may be 0-dim tensors
+      if (!input_shape.empty() && input_shape[0] <= -1) {
         input_shape[0] = micro_batch_size;
         node->Var()->SetShape(input_shape);
         need_infer_shape = true;
diff --git a/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py b/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py
index d3473bc2b0d4a1d2c0ba62acb00e2685c2006d04..9cb70d59d3e5931091e83f1f76d673f8aa77c11d 100644
--- a/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py
+++ b/python/paddle/fluid/tests/unittests/ipu/op_test_ipu.py
@@ -245,6 +245,9 @@ class IPUOpTest(IPUTest):
             raise ValueError("output_dict is empty")
         cpu_fp32 = output_dict[ExecutionMode.CPU_FP32]
         ipu_fp32 = output_dict[ExecutionMode.IPU_FP32]
+        # Convert 0-dim tensor
+        if isinstance(cpu_fp32, np.ndarray) and cpu_fp32.shape == ():
+            cpu_fp32 = cpu_fp32.reshape(1)
         if len(cpu_fp32) != len(ipu_fp32):
             raise ValueError("different outputs number between ipu and cpu.")
         for cpu_fp32_res, ipu_fp32_res in zip(cpu_fp32, ipu_fp32):
diff --git a/python/paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py b/python/paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py
index f0e36fe6f4334fc08561c081e832f95d9dd05830..e8f3cae36739f915d39ffb3567e0b66c4feea933 100644
--- a/python/paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py
+++ b/python/paddle/fluid/tests/unittests/ipu/test_transpose_op_ipu.py
@@ -83,6 +83,7 @@ class TestCase_ZeroDim(TestBase):
     def set_data_feed(self):
         data = np.random.uniform(size=[])
         self.feed_fp32 = {"x": data.astype(np.float32)}
+        self.feed_fp16 = {"x": data.astype(np.float16)}
 
     def set_op_attrs(self):
         self.attrs = {"perm": []}