提交 e4a7ca7f 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3637 Lowering value checking threshold to support training with very small steps or

Merge pull request !3637 from thlinh/dev_Jul28_lower_checking_threshold
...@@ -484,16 +484,18 @@ class PConstant : public PBase<PConstant<T> > { ...@@ -484,16 +484,18 @@ class PConstant : public PBase<PConstant<T> > {
TypeId tensor_type = tensor_ptr->Dtype()->type_id(); TypeId tensor_type = tensor_ptr->Dtype()->type_id();
if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) { if ((tensor_type == TypeId::kNumberTypeFloat32) || (tensor_type == TypeId::kNumberTypeFloat)) {
float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c()); float *data2 = reinterpret_cast<float *>(tensor_ptr->data_c());
auto threshold = FLT_EPSILON * FLT_EPSILON;
for (int i = 0; i < tensor_ptr->DataSize(); i++) { for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > FLT_EPSILON) { if (fabs(data2[i] - check_value_) > threshold) {
return false; return false;
} }
} }
return true; return true;
} else if (tensor_type == TypeId::kNumberTypeFloat64) { } else if (tensor_type == TypeId::kNumberTypeFloat64) {
double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c()); double *data2 = reinterpret_cast<double *>(tensor_ptr->data_c());
auto threshold = DBL_EPSILON * DBL_EPSILON;
for (int i = 0; i < tensor_ptr->DataSize(); i++) { for (int i = 0; i < tensor_ptr->DataSize(); i++) {
if (fabs(data2[i] - check_value_) > DBL_EPSILON) { if (fabs(data2[i] - check_value_) > threshold) {
return false; return false;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册