From c3852b08adabcde76e102b9a5792954935bca053 Mon Sep 17 00:00:00 2001 From: Weilong Wu Date: Fri, 29 Apr 2022 11:32:01 +0800 Subject: [PATCH] [Eager] Support test_label_smooth_functional switch to eager mode (#42366) --- .../fluid/tests/unittests/test_label_smooth_functional.py | 2 -- python/paddle/nn/functional/common.py | 8 ++++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py b/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py index 83c8ced79b..54f5e64fda 100644 --- a/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py +++ b/python/paddle/fluid/tests/unittests/test_label_smooth_functional.py @@ -19,8 +19,6 @@ import paddle.fluid.dygraph as dg import paddle.nn.functional as F import paddle.fluid.initializer as I import unittest -from paddle.fluid.framework import _enable_legacy_dygraph -_enable_legacy_dygraph() class LabelSmoothTestCase(unittest.TestCase): diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 907fd4e625..fe37b8fb97 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -1633,14 +1633,14 @@ def label_smooth(label, prior_dist=None, epsilon=0.1, name=None): #[[[0.03333334 0.93333334 0.03333334] # [0.93333334 0.03333334 0.93333334]]] """ + if epsilon > 1. or epsilon < 0.: + raise ValueError("The value of epsilon must be between 0 and 1.") + if in_dygraph_mode(): return _C_ops.final_state_label_smooth(label, prior_dist, float(epsilon)) - if epsilon > 1. or epsilon < 0.: - raise ValueError("The value of epsilon must be between 0 and 1.") - - if paddle.in_dynamic_mode(): + elif paddle.in_dynamic_mode(): return _C_ops.label_smooth(label, prior_dist, 'epsilon', float(epsilon)) check_variable_and_dtype(label, 'label', ['float32', 'float64'], -- GitLab