未验证 提交 d2939cab 编写于 作者: N Nyakku Shigure 提交者: GitHub

[Dy2St] fix train step random failed on Windows (#52580)

上级 a62de41a
......@@ -58,6 +58,7 @@ class TestTrainStepTinyModel(unittest.TestCase):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 5
self.rtol = 1e-4
def get_train_step_losses(self, func, steps):
losses = []
......@@ -87,7 +88,9 @@ class TestTrainStepTinyModel(unittest.TestCase):
for dygraph_loss, static_loss in zip(dygraph_losses, static_losses):
dygraph_loss = dygraph_loss.numpy()
static_loss = static_loss.numpy()
np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-4)
np.testing.assert_allclose(
dygraph_loss, static_loss, rtol=self.rtol
)
class TestTrainStepTinyModelAdadelta(TestTrainStepTinyModel):
......@@ -99,6 +102,7 @@ class TestTrainStepTinyModelAdadelta(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelAdagrad(TestTrainStepTinyModel):
......@@ -110,6 +114,7 @@ class TestTrainStepTinyModelAdagrad(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelAdam(TestTrainStepTinyModel):
......@@ -121,6 +126,7 @@ class TestTrainStepTinyModelAdam(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelAdamax(TestTrainStepTinyModel):
......@@ -132,6 +138,7 @@ class TestTrainStepTinyModelAdamax(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelAdamW(TestTrainStepTinyModel):
......@@ -143,6 +150,7 @@ class TestTrainStepTinyModelAdamW(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLamb(TestTrainStepTinyModel):
......@@ -156,6 +164,7 @@ class TestTrainStepTinyModelLamb(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelMomentum(TestTrainStepTinyModel):
......@@ -167,6 +176,7 @@ class TestTrainStepTinyModelMomentum(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelRMSProp(TestTrainStepTinyModel):
......@@ -178,6 +188,7 @@ class TestTrainStepTinyModelRMSProp(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRNoamDecay(TestTrainStepTinyModel):
......@@ -191,6 +202,7 @@ class TestTrainStepTinyModelLRNoamDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRPiecewiseDecay(TestTrainStepTinyModel):
......@@ -206,6 +218,7 @@ class TestTrainStepTinyModelLRPiecewiseDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRNaturalExpDecay(TestTrainStepTinyModel):
......@@ -221,6 +234,7 @@ class TestTrainStepTinyModelLRNaturalExpDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRInverseTimeDecay(TestTrainStepTinyModel):
......@@ -234,6 +248,7 @@ class TestTrainStepTinyModelLRInverseTimeDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRPolynomialDecay(TestTrainStepTinyModel):
......@@ -250,6 +265,7 @@ class TestTrainStepTinyModelLRPolynomialDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRLinearWarmup(TestTrainStepTinyModel):
......@@ -267,6 +283,7 @@ class TestTrainStepTinyModelLRLinearWarmup(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRExponentialDecay(TestTrainStepTinyModel):
......@@ -280,6 +297,7 @@ class TestTrainStepTinyModelLRExponentialDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRMultiStepDecay(TestTrainStepTinyModel):
......@@ -297,6 +315,7 @@ class TestTrainStepTinyModelLRMultiStepDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRStepDecay(TestTrainStepTinyModel):
......@@ -313,6 +332,7 @@ class TestTrainStepTinyModelLRStepDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRLambdaDecay(TestTrainStepTinyModel):
......@@ -328,6 +348,7 @@ class TestTrainStepTinyModelLRLambdaDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRReduceOnPlateau(TestTrainStepTinyModel):
......@@ -344,6 +365,7 @@ class TestTrainStepTinyModelLRReduceOnPlateau(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRCosineAnnealingDecay(TestTrainStepTinyModel):
......@@ -359,6 +381,7 @@ class TestTrainStepTinyModelLRCosineAnnealingDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRMultiplicativeDecay(TestTrainStepTinyModel):
......@@ -374,6 +397,7 @@ class TestTrainStepTinyModelLRMultiplicativeDecay(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLROneCycleLR(TestTrainStepTinyModel):
......@@ -387,6 +411,7 @@ class TestTrainStepTinyModelLROneCycleLR(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
class TestTrainStepTinyModelLRCyclicLR(TestTrainStepTinyModel):
......@@ -404,6 +429,7 @@ class TestTrainStepTinyModelLRCyclicLR(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
if __name__ == "__main__":
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
from test_train_step import (
......@@ -33,6 +34,9 @@ class TestTrainStepResNet18Adam(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
if platform.system() == 'Windows':
self.rtol = 1e-3
if __name__ == "__main__":
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
from test_train_step import (
......@@ -33,6 +34,9 @@ class TestTrainStepResNet18Sgd(TestTrainStepTinyModel):
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
self.rtol = 1e-4
if platform.system() == 'Windows':
self.rtol = 1e-3
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册