提交 d74ea085 编写于 作者: A Adam 提交者: Tao Luo

Add relative error measure when (value > 1) (#21144)

* Add relative error measure when value > 1
test=develop

* Move code to CheckError function
test=develop
上级 3976bbe2
...@@ -92,6 +92,14 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) { ...@@ -92,6 +92,14 @@ void PrintConfig(const PaddlePredictor::Config *config, bool use_analysis) {
LOG(INFO) << analysis_config->ToNativeConfig(); LOG(INFO) << analysis_config->ToNativeConfig();
} }
void CheckError(float data_ref, float data) {
if (std::abs(data_ref) > 1) {
CHECK_LE(std::abs((data_ref - data) / data_ref), FLAGS_accuracy);
} else {
CHECK_LE(std::abs(data_ref - data), FLAGS_accuracy);
}
}
// Compare result between two PaddleTensor // Compare result between two PaddleTensor
void CompareResult(const std::vector<PaddleTensor> &outputs, void CompareResult(const std::vector<PaddleTensor> &outputs,
const std::vector<PaddleTensor> &ref_outputs) { const std::vector<PaddleTensor> &ref_outputs) {
...@@ -118,7 +126,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -118,7 +126,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
float *pdata = static_cast<float *>(out.data.data()); float *pdata = static_cast<float *>(out.data.data());
float *pdata_ref = static_cast<float *>(ref_out.data.data()); float *pdata_ref = static_cast<float *>(ref_out.data.data());
for (size_t j = 0; j < size; ++j) { for (size_t j = 0; j < size; ++j) {
CHECK_LE(std::abs(pdata_ref[j] - pdata[j]), FLAGS_accuracy); CheckError(pdata_ref[j], pdata[j]);
} }
break; break;
} }
...@@ -169,7 +177,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs, ...@@ -169,7 +177,7 @@ void CompareResult(const std::vector<PaddleTensor> &outputs,
float *pdata_ref = ref_out.data<float>(&place, &ref_size); float *pdata_ref = ref_out.data<float>(&place, &ref_size);
EXPECT_EQ(size, ref_size); EXPECT_EQ(size, ref_size);
for (size_t j = 0; j < size; ++j) { for (size_t j = 0; j < size; ++j) {
CHECK_LE(std::abs(pdata_ref[j] - pdata[j]), FLAGS_accuracy); CheckError(pdata_ref[j], pdata[j]);
} }
break; break;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册