diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index 28d0faa47b24af493c5a23e7e7fdb6e90732c208..864c882a3eca7b6c1e2a7355830af8744699415a 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -457,11 +457,11 @@ static void TestComplexConvNxNS12(const std::vector &shape) { srand(time(NULL)); // generate random input - index_t batch = 1; + index_t batch = 3 + (rand() % 10); index_t height = shape[0]; index_t width = shape[1]; - index_t input_channels = shape[2]; - index_t output_channels = shape[3]; + index_t input_channels = shape[2] + (rand() % 10); + index_t output_channels = shape[3] + (rand() % 10); // Construct graph OpsTestNet net; OpDefBuilder("Conv2D", "Conv2dTest")