diff --git a/mindspore/lite/src/ops/space_to_depth.cc b/mindspore/lite/src/ops/space_to_depth.cc index e0df66b25fd5464af9e040981b1343b56e4714ad..79f3f175b3d3f2c4e055e24edb6176ab9cc177e8 100644 --- a/mindspore/lite/src/ops/space_to_depth.cc +++ b/mindspore/lite/src/ops/space_to_depth.cc @@ -58,9 +58,9 @@ int SpaceToDepth::InferShape(std::vector inputs, std::ve } int32_t block_size = GetBlockSize(); - if (input_shape[NHWC_C] % (block_size * block_size) != 0 || input_shape[NHWC_C] == 0) { - MS_LOG(ERROR) << "input dimension c size " << input_shape[NHWC_C] << " should be mulitple of block_size(" - << block_size << ") * block_size)!"; + if (input_shape[NHWC_H] % block_size != 0 || input_shape[NHWC_H] == 0 || input_shape[NHWC_W] % block_size != 0 || + input_shape[NHWC_W] == 0) { + MS_LOG(ERROR) << "input dimension h or w size error!"; return 1; } std::vector output_shape(input_shape.size());