Commit 4b921e2f authored by littletomatodonkey

add opt doc

Parent 2cdafa10
...@@ -29,9 +29,17 @@ def cosine_decay_with_warmup(learning_rate,
                              step_each_epoch,
                              epochs=500,
                              warmup_minibatch=1000):
-    """Applies cosine decay to the learning rate.
+    """
+    Applies cosine decay to the learning rate.
     lr = 0.05 * (math.cos(epoch * (math.pi / 120)) + 1)
     decrease lr for every mini-batch and start with warmup.
+    args:
+        learning_rate(float): initial learning rate
+        step_each_epoch(int): number of steps per epoch in the training process
+        epochs(int): number of training epochs
+        warmup_minibatch(int): number of mini-batches used for warmup
+    return:
+        lr(tensor): learning rate tensor
     """
     global_step = _decay_step_counter()
     lr = fluid.layers.tensor.create_global_var(
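As a reading aid, not part of the diff itself: the new docstring describes linear warmup followed by per-mini-batch cosine decay. A minimal pure-Python sketch of that schedule, assuming warmup ramps linearly from 0 to the base rate and the cosine then decays over the remaining steps, could look like the following; the function name and constants are illustrative, not the fluid implementation above.

    import math

    def cosine_decay_with_warmup_ref(learning_rate, step_each_epoch,
                                     epochs=500, warmup_minibatch=1000):
        """Reference generator: yields the lr for each global step."""
        total_steps = step_each_epoch * epochs  # assumes total_steps > warmup_minibatch
        for global_step in range(total_steps):
            if global_step < warmup_minibatch:
                # linear warmup from 0 up to the base learning rate
                lr = learning_rate * global_step / warmup_minibatch
            else:
                # cosine decay from the base rate down toward 0
                frac = (global_step - warmup_minibatch) / (total_steps - warmup_minibatch)
                lr = 0.5 * learning_rate * (math.cos(frac * math.pi) + 1)
            yield lr

With learning_rate=0.1 and epochs=120 this reproduces the docstring's example curve 0.05 * (cos(epoch * pi / 120) + 1), evaluated per mini-batch rather than per epoch.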
...@@ -65,6 +73,7 @@ def AdamDecay(params, parameter_list=None):
         params(dict): the super parameters
         parameter_list (list): list of Variable names to update to minimize loss
     return:
+        optimizer: an Adam optimizer instance
     """
     base_lr = params['base_lr']
     beta1 = params['beta1']
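A hypothetical usage sketch for AdamDecay, not part of the diff: only the 'base_lr' and 'beta1' keys are visible in the hunk above; 'beta2' is an assumed key, added by analogy with Adam's standard hyperparameters, and may differ in the real config.

    # Hypothetical params dict; 'beta2' is an assumption, see note above.
    params = {
        'base_lr': 0.001,
        'beta1': 0.9,
        'beta2': 0.999,
    }
    optimizer = AdamDecay(params)
    # Per the new docstring this returns an Adam optimizer instance, so the
    # usual fluid pattern would apply, e.g. optimizer.minimize(loss).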
...@@ -121,6 +130,7 @@ def RMSProp(params, parameter_list=None):
         params(dict): the super parameters
         parameter_list (list): list of Variable names to update to minimize loss
     return:
+        optimizer: an RMSProp optimizer instance
     """
     base_lr = params.get("base_lr", 0.001)
     l2_decay = params.get("l2_decay", 0.00005)
...
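A final sketch, not part of the diff: because RMSProp reads its settings via params.get(...) with defaults, calling it with an empty dict falls back to the values visible in the hunk above.

    # Hypothetical usage; the defaults come from the params.get(...) calls above.
    optimizer = RMSProp({})  # base_lr=0.001, l2_decay=0.00005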