From 2c5a007a92113f2f0b1a472606e86c637053d2e3 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 18 Oct 2021 10:56:33 +0800
Subject: [PATCH] fix(mge/optimizer): allow lr to be 0

GitOrigin-RevId: dabd1fcc3390787988553abded8c572bfccf1957
---
 imperative/python/megengine/optimizer/adadelta.py     | 2 +-
 imperative/python/megengine/optimizer/adagrad.py      | 2 +-
 imperative/python/megengine/optimizer/adam.py         | 2 +-
 imperative/python/megengine/optimizer/adamw.py        | 2 +-
 imperative/python/megengine/optimizer/sgd.py          | 8 ++++----
 imperative/python/test/integration/test_trace_dump.py | 6 ------
 6 files changed, 8 insertions(+), 14 deletions(-)

diff --git a/imperative/python/megengine/optimizer/adadelta.py b/imperative/python/megengine/optimizer/adadelta.py
index 5544cce99..73041f8ea 100644
--- a/imperative/python/megengine/optimizer/adadelta.py
+++ b/imperative/python/megengine/optimizer/adadelta.py
@@ -63,7 +63,7 @@ class Adadelta(Optimizer):
         eps = param_group["eps"]
 
         def make_scalar(val):
-            return tensor(val)
+            return tensor(val, dtype="float32")
 
         # since `conver_inputs` is disabled for param updates,
         # scalar should be explicitly tansforred to tensor
diff --git a/imperative/python/megengine/optimizer/adagrad.py b/imperative/python/megengine/optimizer/adagrad.py
index 43708ae7e..6b7479b86 100644
--- a/imperative/python/megengine/optimizer/adagrad.py
+++ b/imperative/python/megengine/optimizer/adagrad.py
@@ -62,7 +62,7 @@ class Adagrad(Optimizer):
         eps = param_group["eps"]
 
         def make_scalar(val):
-            return tensor(val)
+            return tensor(val, dtype="float32")
 
         # since `conver_inputs` is disabled for param updates,
         # scalar should be explicitly tansforred to tensor
diff --git a/imperative/python/megengine/optimizer/adam.py b/imperative/python/megengine/optimizer/adam.py
index 794bdd948..368550346 100644
--- a/imperative/python/megengine/optimizer/adam.py
+++ b/imperative/python/megengine/optimizer/adam.py
@@ -61,7 +61,7 @@ class Adam(Optimizer):
         beta0, beta1 = param_group["betas"]
 
         def make_scalar(val):
-            return tensor(val)
+            return tensor(val, dtype="float32")
 
         # since `conver_inputs` is disabled for param updates,
         # scalar should be explicitly tansforred to tensor
diff --git a/imperative/python/megengine/optimizer/adamw.py b/imperative/python/megengine/optimizer/adamw.py
index cdbe96636..ddf731350 100644
--- a/imperative/python/megengine/optimizer/adamw.py
+++ b/imperative/python/megengine/optimizer/adamw.py
@@ -61,7 +61,7 @@ class AdamW(Optimizer):
         beta0, beta1 = param_group["betas"]
 
         def make_scalar(val):
-            return tensor(val)
+            return tensor(val, dtype="float32")
 
         # since `conver_inputs` is disabled for param updates,
         # scalar should be explicitly tansforred to tensor
diff --git a/imperative/python/megengine/optimizer/sgd.py b/imperative/python/megengine/optimizer/sgd.py
index a4d122820..95a3dd1c9 100644
--- a/imperative/python/megengine/optimizer/sgd.py
+++ b/imperative/python/megengine/optimizer/sgd.py
@@ -62,13 +62,13 @@ class SGD(Optimizer):
 
         # since `conver_inputs` is disabled for param updates,
         # scalar should be explicitly tansforred to tensor
-        _lr = tensor(lr)
-        _weight_decay = tensor(weight_decay)
-        _momentum = tensor(momentum)
+        _lr = tensor(lr, dtype="float32")
+        _weight_decay = tensor(weight_decay, dtype="float32")
+        _momentum = tensor(momentum, dtype="float32")
 
         inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
         if inplace_mode:
-            _neg_lr = tensor(-lr)
+            _neg_lr = tensor(-lr, dtype="float32")
             c1 = tensor([1.0])
 
         for param in param_group["params"]:
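The optimizer hunks above all make the same change: scalar hyperparameters are
wrapped with an explicit dtype="float32". The presumed failure mode (an
assumption; the commit message only says lr may now be 0) is that a Python int
such as lr=0 lets tensor() infer an integer dtype, which then clashes with the
float32 parameters because implicit input conversion is disabled for param
updates. A minimal sketch of that inference, assuming MegEngine follows
numpy-style dtype rules for Python scalars:

    # Sketch only, not part of the patch. Assumes MegEngine infers an
    # integer dtype from a Python int scalar, numpy-style.
    from megengine import tensor

    lr = 0                                    # e.g. SGD(params, lr=0)
    print(tensor(lr).dtype)                   # presumably int32: the old failure mode
    print(tensor(lr, dtype="float32").dtype)  # float32, matching the parameters
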
diff --git a/imperative/python/test/integration/test_trace_dump.py b/imperative/python/test/integration/test_trace_dump.py
index 085cc396a..d7deeab8a 100644
--- a/imperative/python/test/integration/test_trace_dump.py
+++ b/imperative/python/test/integration/test_trace_dump.py
@@ -133,12 +133,6 @@ def test_xornet_trace_dump():
 
     data = tensor(test_data.astype(np.float32))
     out = pred_fun(data)
-    pred_output = out.numpy()
-    pred_label = np.argmax(pred_output, 1)
-
-    with np.printoptions(precision=4, suppress=True):
-        print("Predicated probability:")
-        print(pred_output)
 
     with mkstemp() as out:
         pred_fun.dump(out, arg_names=["data"], output_names=["label"])
--
GitLab
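With the dtype fix applied, constructing an optimizer with a zero learning rate
should survive a full training step. A minimal end-to-end sketch, assuming the
standard MegEngine training APIs (GradManager, Optimizer.step); the network,
loss, and data below are hypothetical, not taken from the patch or its tests:

    # Sketch only: exercises SGD with lr=0, the case this patch allows.
    import numpy as np
    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine.autodiff import GradManager

    net = M.Linear(4, 2)
    opt = optim.SGD(net.parameters(), lr=0)  # lr=0 used to fail before this fix
    gm = GradManager().attach(net.parameters())

    x = mge.tensor(np.random.randn(8, 4).astype("float32"))
    y = mge.tensor(np.zeros((8, 2), dtype="float32"))
    with gm:
        loss = F.loss.square_loss(net(x), y)
        gm.backward(loss)
    opt.step()        # with lr=0 the parameters should be left unchanged
    opt.clear_grad()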