diff --git a/mace/kernels/pooling.h b/mace/kernels/pooling.h index f5aa606b2a671f19749c1df98a479311b300d73b..94a388bec7227d3d39f0caa60fddb30a13c059ae 100644 --- a/mace/kernels/pooling.h +++ b/mace/kernels/pooling.h @@ -373,10 +373,10 @@ struct PoolingFunctor: PoolingFunctorBase { MACE_UNUSED(future); MACE_CHECK(dilations_[0] == 1 && dilations_[1] == 1, "Quantized pooling does not support dilation > 1 yet."); - MACE_CHECK(input_tensor->scale() == output_tensor->scale(), - "Quantized pooling's input and output scale are not equal."); - MACE_CHECK(input_tensor->zero_point() == output_tensor->zero_point(), - "Quantized pooling's input and output zero_point are not equal"); + // Use the same scale and zero point with input and output. + output_tensor->SetScale(input_tensor->scale()); + output_tensor->SetZeroPoint(input_tensor->zero_point()); + std::vector output_shape(4); std::vector filter_shape = { input_tensor->dim(3), kernels_[0], kernels_[1], input_tensor->dim(3)}; diff --git a/mace/ops/depthwise_conv2d_test.cc b/mace/ops/depthwise_conv2d_test.cc index 0da8041d9b69a71a3148aba97ec748217398517e..a2d57911db1b7c136c87ec5c7b5ac6616f6ce289 100644 --- a/mace/ops/depthwise_conv2d_test.cc +++ b/mace/ops/depthwise_conv2d_test.cc @@ -497,11 +497,11 @@ void TestQuant(const index_t batch, TEST_F(DepthwiseConv2dOpTest, Quant) { QuantSimpleValidTest(); - TestQuant(1, 1, 2, 3, 3, 3, 3, VALID, {1, 1}); - TestQuant(1, 1, 2, 3, 3, 3, 3, SAME, {1, 1}); - TestQuant(1, 1, 2, 3, 3, 3, 3, FULL, {1, 1}); - TestQuant(1, 2, 2, 3, 3, 3, 3, SAME, {1, 1}); - TestQuant(1, 2, 2, 3, 3, 3, 3, SAME, {2, 2}); + TestQuant(1, 1, 1024, 7, 7, 3, 3, VALID, {1, 1}); + TestQuant(1, 1, 1024, 7, 7, 3, 3, SAME, {1, 1}); + TestQuant(1, 1, 1024, 7, 7, 3, 3, FULL, {1, 1}); + TestQuant(1, 2, 1024, 7, 7, 3, 3, SAME, {1, 1}); + TestQuant(1, 2, 1024, 7, 7, 3, 3, SAME, {2, 2}); TestQuant(1, 1, 512, 14, 14, 3, 3, SAME, {1, 1}); TestQuant(1, 1, 512, 14, 13, 5, 5, SAME, {2, 2}); TestQuant(1, 1, 256, 28, 28, 3, 3, SAME, {1, 1}); diff --git a/mace/ops/pooling_test.cc b/mace/ops/pooling_test.cc index af900f58ecc2a56d368b22ce0338ae723fec7da3..f0ac2572decce5df78832ba273b2a8d71480b50a 100644 --- a/mace/ops/pooling_test.cc +++ b/mace/ops/pooling_test.cc @@ -603,12 +603,7 @@ void TestQuant(const index_t batch, .AddIntArg("pooling_type", pooling) .AddIntArg("T", DT_UINT8) .Finalize(net.NewOperatorDef()); - net.Setup(DeviceType::CPU); - Tensor *q_input = net.GetTensor("QuantizedInput"); - Tensor *q_output = net.GetTensor("QuantizedOutput"); - q_output->SetScale(q_input->scale()); - q_output->SetZeroPoint(q_input->zero_point()); - net.Run(); + net.RunOp(); OpDefBuilder("Dequantize", "DeQuantizeTest") .Input("QuantizedOutput")