未验证 提交 016766cc 编写于 作者: A Allen Guo 提交者: GitHub

fix runtime error (#47133)

上级 b438dff3
......@@ -44,7 +44,8 @@ void InferShapePass::ApplyImpl(ir::Graph* graph) const {
feed_list.end();
if (is_feed) {
auto input_shape = node->Var()->GetShape();
if (input_shape[0] <= -1) {
// NOTE: some tensors may be 0-dim tensors
if (!input_shape.empty() && input_shape[0] <= -1) {
input_shape[0] = micro_batch_size;
node->Var()->SetShape(input_shape);
need_infer_shape = true;
......
......@@ -245,6 +245,9 @@ class IPUOpTest(IPUTest):
raise ValueError("output_dict is empty")
cpu_fp32 = output_dict[ExecutionMode.CPU_FP32]
ipu_fp32 = output_dict[ExecutionMode.IPU_FP32]
# Convert 0-dim tensor
if isinstance(cpu_fp32, np.ndarray) and cpu_fp32.shape == ():
cpu_fp32 = cpu_fp32.reshape(1)
if len(cpu_fp32) != len(ipu_fp32):
raise ValueError("different outputs number between ipu and cpu.")
for cpu_fp32_res, ipu_fp32_res in zip(cpu_fp32, ipu_fp32):
......
......@@ -83,6 +83,7 @@ class TestCase_ZeroDim(TestBase):
def set_data_feed(self):
data = np.random.uniform(size=[])
self.feed_fp32 = {"x": data.astype(np.float32)}
self.feed_fp16 = {"x": data.astype(np.float16)}
def set_op_attrs(self):
self.attrs = {"perm": []}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册