提交 2c5a007a 编写于 作者: M Megvii Engine Team

fix(mge/optimizer): allow lr to be 0

GitOrigin-RevId: dabd1fcc3390787988553abded8c572bfccf1957
上级 c50858ee
......@@ -63,7 +63,7 @@ class Adadelta(Optimizer):
eps = param_group["eps"]
def make_scalar(val):
return tensor(val)
return tensor(val, dtype="float32")
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
......
......@@ -62,7 +62,7 @@ class Adagrad(Optimizer):
eps = param_group["eps"]
def make_scalar(val):
return tensor(val)
return tensor(val, dtype="float32")
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
......
......@@ -61,7 +61,7 @@ class Adam(Optimizer):
beta0, beta1 = param_group["betas"]
def make_scalar(val):
return tensor(val)
return tensor(val, dtype="float32")
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
......
......@@ -61,7 +61,7 @@ class AdamW(Optimizer):
beta0, beta1 = param_group["betas"]
def make_scalar(val):
return tensor(val)
return tensor(val, dtype="float32")
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
......
......@@ -62,13 +62,13 @@ class SGD(Optimizer):
# since `convert_inputs` is disabled for param updates,
# scalars must be explicitly converted to tensors
_lr = tensor(lr)
_weight_decay = tensor(weight_decay)
_momentum = tensor(momentum)
_lr = tensor(lr, dtype="float32")
_weight_decay = tensor(weight_decay, dtype="float32")
_momentum = tensor(momentum, dtype="float32")
inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
if inplace_mode:
_neg_lr = tensor(-lr)
_neg_lr = tensor(-lr, dtype="float32")
c1 = tensor([1.0])
for param in param_group["params"]:
......
......@@ -133,12 +133,6 @@ def test_xornet_trace_dump():
data = tensor(test_data.astype(np.float32))
out = pred_fun(data)
pred_output = out.numpy()
pred_label = np.argmax(pred_output, 1)
with np.printoptions(precision=4, suppress=True):
print("Predicated probability:")
print(pred_output)
with mkstemp() as out:
pred_fun.dump(out, arg_names=["data"], output_names=["label"])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册