diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index de3015cdf02467be13d8b99afe5d362fac42dd52..3235c9027f16ffa0b250beaf9f64073f620ace78 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -143,7 +143,7 @@ class OpsTestNet { template void AddRandomInput(const std::string &name, const std::vector &shape, - bool positive = false) { + bool positive = true) { Tensor *input = ws_.CreateTensor(name, GetDeviceAllocator(D), DataTypeToEnum::v()); input->Resize(shape); @@ -318,7 +318,8 @@ class OpsTestBase : public ::testing::Test { template void GenerateRandomRealTypeData(const std::vector &shape, - std::vector *res) { + std::vector *res, + bool positive = true) { MACE_CHECK_NOTNULL(res); std::random_device rd; @@ -331,9 +332,14 @@ void GenerateRandomRealTypeData(const std::vector &shape, if (DataTypeToEnum::value == DT_HALF) { std::generate(res->begin(), res->end(), - [&gen, &nd] { return half_float::half_cast(nd(gen)); }); + [&gen, &nd, positive] { + return half_float::half_cast( + positive ? std::abs(nd(gen)) : nd(gen)); + }); } else { - std::generate(res->begin(), res->end(), [&gen, &nd] { return nd(gen); }); + std::generate(res->begin(), res->end(), [&gen, &nd, positive] { + return positive ? std::abs(nd(gen)) : nd(gen); + }); } }