diff --git a/nets/loss.py b/nets/loss.py index 43b54edc0cebcd5b101dbee6d4640a21491a2345..8de636c2224a21c6bd702069076017557df79d45 100644 --- a/nets/loss.py +++ b/nets/loss.py @@ -7,7 +7,7 @@ from nets.ious import box_ciou # 平滑标签 #---------------------------------------------------# def _smooth_labels(y_true, label_smoothing): - num_classes = K.shape(y_true)[-1] + num_classes = tf.cast(K.shape(y_true)[-1], dtype=K.floatx()) label_smoothing = K.constant(label_smoothing, dtype=K.floatx()) return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes #---------------------------------------------------#