Unverified commit 7728efb4 authored by Nyakku Shigure, committed by GitHub

[Dy2St] support train step in to_static (#51693)

Co-authored-by: xiongkun <xiongkun03@baidu.com>
Parent 15aa73df
......@@ -548,7 +548,9 @@ class IpuDynamicPatcher:
self._caches[item_id] = (
concrete_program,
partial_program_from(concrete_program),
partial_program_from(
concrete_program, item.class_instance is not None
),
)
# Note: raise a warning if the number of traced programs exceeds `max_tracing_count`
current_tracing_count = len(self._caches)
......
......@@ -44,7 +44,7 @@ __all__ = [
'to_variable',
]
# Flag that indicates whether the code is running under `@to_static`
NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"
def in_declarative_mode():
......@@ -143,17 +143,16 @@ def _convert_into_variable(tensor):
# and necessary for inferring. It will be pruned if it's not necessary for inferring.
# But if its shape is empty while created from `create_variable()`, we consider this buffer
# non-persistable. See case of `drop_state` in lstm api.
is_persistable = len(tensor.shape) > 0
# non-persistable. See case of `dropout_state` in lstm api.
is_persistable = True
if tensor.name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX):
is_persistable = False
new_var = tensor._to_static_var(
to_parameter=False, persistable=is_persistable
)
# add param into parameter recorder to collect all the params used in this program.
if new_var.persistable is True:
# TODO(@xiongkun): 0d-tensor may be affected at present,
# but there is no particularly good method to identify whether 0d-tensor
# is used as buffer or "drop_out_state" in LSTM buffer variable.
from paddle.jit.dy2static.program_translator import (
ProgramTranslator,
)
......
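The hunk above replaces the old shape-based heuristic with an explicit naming convention: a buffer whose name ends with NON_PERSISTABLE_VAR_NAME_SUFFIX is converted into a non-persistable static variable. A minimal standalone sketch of that rule (illustrative only; `is_persistable_buffer` is a hypothetical helper, not the patched function):

NON_PERSISTABLE_VAR_NAME_SUFFIX = "__non_persistable"

def is_persistable_buffer(name: str) -> bool:
    # Buffers such as the LSTM `dropout_state` opt out of persistability
    # simply by carrying the suffix in their name.
    return not name.endswith(NON_PERSISTABLE_VAR_NAME_SUFFIX)

assert is_persistable_buffer("hidden_state")
assert not is_persistable_buffer("dropout_state" + NON_PERSISTABLE_VAR_NAME_SUFFIX)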
......@@ -1620,7 +1620,7 @@ class Variable(metaclass=VariableMetaClass):
"""
pass
@fake_interface_only
@non_static_only
def backward(self, retain_graph=False):
"""
**Notes**:
......@@ -1657,7 +1657,17 @@ class Variable(metaclass=VariableMetaClass):
loss.backward()
"""
pass
from .backward import append_backward
if retain_graph is True:
raise AssertionError(
"`retain_graph` == True is not supported in @to_static function. "
"Please set retain_graph = False."
)
param_grad_list = append_backward(self)
for param, param_grad in param_grad_list:
# set grad to simulate dygraph loss.backward() in static mode.
setattr(param, "grad", param_grad)
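With this change, calling loss.backward() inside a @to_static function no longer hits the dygraph-only fake interface: it appends the backward graph via append_backward and stores each parameter's gradient variable on a `grad` attribute, mimicking dygraph behavior. A hedged, standalone illustration of the same mechanism in a plain static program (not the patched method itself):

import paddle
from paddle.static import append_backward

paddle.enable_static()
x = paddle.static.data(name="x", shape=[4, 10], dtype="float32")
loss = paddle.mean(paddle.nn.Linear(10, 1)(x))
# This is essentially what the patched Variable.backward does for `loss`:
param_grad_list = append_backward(loss)
for param, param_grad in param_grad_list:
    param.grad = param_grad  # simulate dygraph loss.backward() in static mode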
@fake_interface_only
def gradient(self):
......@@ -7396,6 +7406,19 @@ def _get_var(name, program=None):
return program.global_block().var(name)
@signature_safe_contextmanager
def dygraph_guard_if_declarative():
from .dygraph.base import in_declarative_mode
from .dygraph import Tracer
if in_declarative_mode():
# Under the @paddle.jit.to_static decorator, we temporarily switch back to dygraph mode.
with _dygraph_guard(tracer=Tracer()):
yield
else:
yield
@signature_safe_contextmanager
def _dygraph_guard(tracer):
tmp_tracer = global_var._dygraph_tracer_
......
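The new dygraph_guard_if_declarative context manager is the workhorse used by the optimizer changes below: inside @paddle.jit.to_static it temporarily switches back to dygraph mode with a fresh Tracer, and outside to_static it is a no-op. A hedged sketch of the usage pattern (the guarded body here is only a placeholder):

from paddle.fluid.framework import dygraph_guard_if_declarative

with dygraph_guard_if_declarative():
    # Code that needs dygraph semantics (e.g. creating optimizer accumulators
    # or the global learning-rate variable) runs under a temporary Tracer when
    # called from a @paddle.jit.to_static function, and runs unchanged otherwise.
    pass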
......@@ -21,9 +21,10 @@ from paddle import _legacy_C_ops
from paddle.amp.auto_cast import _in_amp_guard, _in_pure_fp16_guard
from paddle.fluid import backward, core, framework, program_guard
from paddle.fluid.compiler import BuildStrategy
from paddle.fluid.data_feeder import convert_dtype
from paddle.fluid.dygraph.base import switch_to_static_graph
from paddle.fluid.framework import _apply_pass
from paddle.nn.layer import layers
from paddle.optimizer.lr import LRScheduler
from . import logging_utils
from .utils import RETURN_NO_VALUE_MAGIC_NUM, _out_grad_names, _param_grad_names
......@@ -205,6 +206,8 @@ class PartialProgramLayer:
self._cast_fp16_if_pure_fp16(in_vars)
attrs = self._prepare_attributes()
self._sync_lr_value_with_scheduler()
_legacy_C_ops.run_program(
self._valid_vars(in_vars),
self._valid_vars(self._params),
......@@ -219,6 +222,21 @@ class PartialProgramLayer:
restored_nest_out = self._restore_out(out_vars)
return self._remove_no_value(restored_nest_out)
def _sync_lr_value_with_scheduler(self):
"""Update the value of lr_var with the latest value computed by lr_scheduler."""
main_program = self._origin_main_program
if hasattr(main_program, 'lr_scheduler') and hasattr(
main_program, 'lr_var'
):
lr_scheduler = main_program.lr_scheduler
lr_var = main_program.lr_var
assert isinstance(lr_scheduler, LRScheduler), "must be LRScheduler"
lr_scheduler = self._origin_main_program.lr_scheduler
lr_value = lr_scheduler()
data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
lr_var.set_value(data)
def set_hooker(self, hooker):
self._hooker = hooker
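_sync_lr_value_with_scheduler runs right before each run_program call: if the traced main program carries an LRScheduler (attached by _create_global_learning_rate below as lr_scheduler/lr_var), the value currently computed by the scheduler is pushed into the static lr variable. A minimal sketch of the conversion it performs, where lr_var stands in for main_program.lr_var (kept commented because it needs a real static variable):

import numpy as np
import paddle
from paddle.fluid.data_feeder import convert_dtype

scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5)
lr_value = scheduler()  # scalar learning rate for the current step
# data = np.array(lr_value).astype(convert_dtype(lr_var.dtype))
# lr_var.set_value(data)  # overwrite the static lr variable with the new value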
......@@ -240,7 +258,8 @@ class PartialProgramLayer:
@LazyInitialized
def _double_grads(self):
return self._get_double_grads(self._origin_main_program)
# TODO: check the effects.
return None
# whole
@switch_to_static_graph
......@@ -658,23 +677,6 @@ class PartialProgramLayer:
self._params = required_params
def _get_double_grads(self, program):
double_grads = []
for block in program.blocks:
for name in block.vars:
if "@GRAD" in name:
var_desc = block.vars[name].desc
var_base = None
var_base = core.eager.Tensor(
var_desc.dtype(),
var_desc.shape(),
var_desc.name(),
var_desc.type(),
False,
)
double_grads.append(var_base)
return self._valid_vars(double_grads)
def _cast_fp16_if_pure_fp16(self, in_vars):
if _in_pure_fp16_guard():
for i, var in enumerate(in_vars):
......@@ -1053,9 +1055,11 @@ class PartialProgramLayer:
return vars if vars else None
def partial_program_from(concrete_program):
def partial_program_from(concrete_program, from_method=False):
inputs = concrete_program.inputs
if inputs and isinstance(inputs[0], layers.Layer):
# NOTE(SigureMo): Remove the first arg `self` from method args.
if inputs and from_method:
inputs = inputs[1:]
return PartialProgramLayer(
......
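partial_program_from now receives a from_method flag instead of sniffing the first input: when the traced callable is a bound method, ConcreteProgram.inputs starts with the layer instance itself, which must be dropped before building the PartialProgramLayer. A hedged illustration of that stripping rule (strip_bound_self is a hypothetical helper mirroring the change):

def strip_bound_self(inputs, from_method):
    # Drop the leading `self` only when the program was traced from a method.
    return inputs[1:] if inputs and from_method else inputs

assert strip_bound_self(["self", "x"], from_method=True) == ["x"]
assert strip_bound_self(["x"], from_method=False) == ["x"]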
......@@ -1225,7 +1225,9 @@ class ProgramCache:
)
)
partial_program = partial_program_from(concrete_program)
partial_program = partial_program_from(
concrete_program, cache_key.class_instance is not None
)
if core._is_fwd_prim_enabled() and not _in_amp_guard():
partial_program.set_hooker(
PrimHooker(concrete_program.main_program)
......
......@@ -22,6 +22,7 @@ import paddle
from paddle import _C_ops, _legacy_C_ops, framework, in_dynamic_mode
from paddle.common_ops_import import Variable
from paddle.fluid.data_feeder import check_type, check_variable_and_dtype
from paddle.fluid.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX
from paddle.fluid.framework import (
_non_static_mode,
default_startup_program,
......@@ -1428,7 +1429,8 @@ class RNNBase(LayerList):
# dropout state can also be hidden to avoid being saved
# should dropout state be persistable for static graph?
self._dropout_state = self.create_variable(
dtype=core.VarDesc.VarType.UINT8
dtype=core.VarDesc.VarType.UINT8,
name=f"dropout_state{NON_PERSISTABLE_VAR_NAME_SUFFIX}",
)
if in_dynamic_mode():
with paddle.no_grad():
......
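The RNN change above is the first user of the convention: dropout_state is created with the suffix in its name, so _convert_into_variable marks it non-persistable under @to_static. A hypothetical user-level layer applying the same pattern (not part of the patch):

import paddle
from paddle.fluid.dygraph.base import NON_PERSISTABLE_VAR_NAME_SUFFIX

class ScratchLayer(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        # This buffer is not saved as a persistable variable when the layer
        # is traced with @paddle.jit.to_static.
        self._scratch = self.create_variable(
            dtype="float32",
            name=f"scratch{NON_PERSISTABLE_VAR_NAME_SUFFIX}",
        )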
......@@ -389,7 +389,7 @@ class Adam(Optimizer):
return adam_op
@imperative_base.no_grad
@framework.dygraph_only
@framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
......@@ -412,6 +412,10 @@ class Adam(Optimizer):
adam.step()
adam.clear_grad()
"""
if paddle.fluid.dygraph.base.in_declarative_mode():
self._declarative_step()
return
if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
......
......@@ -530,7 +530,7 @@ class AdamW(Optimizer):
return " ".join(["Weight Decay, params:", ",".join(self._params_name)])
@imperative_base.no_grad
@framework.dygraph_only
@framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
......@@ -553,6 +553,10 @@ class AdamW(Optimizer):
opt.step()
opt.clear_grad()
"""
if paddle.fluid.dygraph.base.in_declarative_mode():
self._declarative_step()
return
if not isinstance(self._parameter_list[0], dict):
params_grads = []
for param in self._parameter_list:
......
......@@ -412,65 +412,71 @@ class Optimizer:
return self._opti_name_list
def _create_global_learning_rate(self):
# lr var can't be float16 or bfloat16, for pure fp16 or bf16 training, should extra handle the dtype for lr
_lr_dtype = (
paddle.get_default_dtype() if self._dtype is None else self._dtype
)
_lr_dtype = (
paddle.float32
if (
(
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16
)
or (
paddle.get_default_dtype() != "bfloat16"
and _lr_dtype == paddle.bfloat16
)
def do_create():
# lr var can't be float16 or bfloat16; for pure fp16 or bf16 training, the lr dtype needs extra handling
_lr_dtype = (
paddle.get_default_dtype()
if self._dtype is None
else self._dtype
)
else _lr_dtype
)
if isinstance(self._learning_rate, LRScheduler):
lr_var = self._global_learning_rate()
# only create global lr_var once
if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable(
name=lr_name,
shape=[],
persistable=True,
stop_gradient=True,
dtype=_lr_dtype,
_lr_dtype = (
paddle.float32
if (
(
paddle.get_default_dtype() != "float16"
and _lr_dtype == paddle.float16
)
or (
paddle.get_default_dtype() != "bfloat16"
and _lr_dtype == paddle.bfloat16
)
)
main_prog = framework.default_main_program()
main_prog.lr_scheduler = self._learning_rate
main_prog.lr_var = lr_var
self._learning_rate_map[
framework.default_main_program()
] = lr_var
lr_value = float(self._learning_rate())
self.helper.set_variable_initializer(
lr_var,
initializer=paddle.nn.initializer.Constant(value=lr_value),
else _lr_dtype
)
elif isinstance(self._learning_rate, float):
# only create global lr_var once
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
else:
self._learning_rate_map[
framework.default_main_program()
] = paddle.static.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[],
value=float(self._learning_rate),
dtype=_lr_dtype,
persistable=True,
if isinstance(self._learning_rate, LRScheduler):
lr_var = self._global_learning_rate()
# only create global lr_var once
if not isinstance(lr_var, framework.Variable):
lr_name = unique_name.generate('learning_rate')
self._learning_rate._var_name = lr_name
lr_var = self.helper.create_global_variable(
name=lr_name,
shape=[],
persistable=True,
stop_gradient=True,
dtype=_lr_dtype,
)
main_prog = framework.default_main_program()
main_prog.lr_scheduler = self._learning_rate
main_prog.lr_var = lr_var
self._learning_rate_map[
framework.default_main_program()
] = lr_var
lr_value = float(self._learning_rate())
self.helper.set_variable_initializer(
lr_var,
initializer=paddle.nn.initializer.Constant(value=lr_value),
)
elif isinstance(self._learning_rate, float):
# only create global lr_var once
lr = self._global_learning_rate()
if isinstance(lr, framework.Variable):
return
else:
self._learning_rate_map[
framework.default_main_program()
] = paddle.static.create_global_var(
name=unique_name.generate("learning_rate"),
shape=[],
value=float(self._learning_rate),
dtype=_lr_dtype,
persistable=True,
)
with paddle.fluid.framework.dygraph_guard_if_declarative():
do_create()
@framework.dygraph_only
def set_lr(self, value):
......@@ -962,14 +968,15 @@ class Optimizer:
)
if isinstance(parameters_and_grads, list):
self._create_accumulators(
target_block,
[
p[0]
for p in parameters_and_grads
if not p[0].stop_gradient
],
)
with paddle.fluid.framework.dygraph_guard_if_declarative():
self._create_accumulators(
target_block,
[
p[0]
for p in parameters_and_grads
if not p[0].stop_gradient
],
)
else:
params_acc_dict = parameters_and_grads.copy()
params_acc_dict['params'] = [
......@@ -977,7 +984,8 @@ class Optimizer:
for p in params_acc_dict['params']
if not p[0].stop_gradient
]
self._create_accumulators(target_block, params_acc_dict)
with paddle.fluid.framework.dygraph_guard_if_declarative():
self._create_accumulators(target_block, params_acc_dict)
if framework._non_static_mode():
found_inf = self._get_auxiliary_var('found_inf')
......@@ -1329,7 +1337,7 @@ class Optimizer:
return no_grad_set
@framework.dygraph_only
@framework.non_static_only
def clear_grad(self, set_to_zero=True):
"""
Clear the gradients of all optimized parameters for model.
......@@ -1442,8 +1450,30 @@ class Optimizer:
return optimize_ops, params_grads
def _declarative_step(self):
"""
In declarative mode, `step()` is forwarded to `apply_gradients()`.
"""
params = (
paddle.static.default_main_program().global_block().all_parameters()
)
assert not isinstance(
self._parameter_list[0], dict
), "Only a list of parameters is supported when using the optimizer in @paddle.jit.to_static."
selected_params = {param.name for param in self._parameter_list}
parameters = [param for param in params if param.trainable]
parameters = list(
filter(
lambda x: x.name in selected_params and hasattr(x, "grad"),
parameters,
)
)
params_grads = [(param, param.grad) for param in parameters]
optimize_ops = self.apply_gradients(params_grads)
return
@imperative_base.no_grad()
@framework.dygraph_only
@framework.non_static_only
def step(self):
"""
Execute the optimizer and update parameters once.
......@@ -1466,6 +1496,9 @@ class Optimizer:
adam.step()
adam.clear_grad()
"""
if paddle.fluid.dygraph.base.in_declarative_mode():
self._declarative_step()
return
if not isinstance(self._param_groups[0], dict):
params_grads = []
......
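Taken together with the Variable.backward change, optimizer.step() and clear_grad() can now be traced, so a full dygraph-style train step works under @paddle.jit.to_static. A hedged end-to-end sketch mirroring the new unit tests added below:

import paddle

def train_step(net, x, loss_fn, opt):
    loss = loss_fn(net(x))
    loss.backward()   # routed to append_backward in static mode
    opt.step()        # forwarded to _declarative_step -> apply_gradients
    opt.clear_grad()
    return loss

net = paddle.nn.Linear(10, 1)
opt = paddle.optimizer.SGD(learning_rate=0.001, parameters=net.parameters())
static_train_step = paddle.jit.to_static(train_step)
loss = static_train_step(net, paddle.randn([4, 10]), lambda t: t.mean(), opt)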
......@@ -44,6 +44,13 @@ if(WIN32 AND NOT WITH_GPU)
)# disable on Windows CPU CI for timeout
endif()
if(NOT WITH_GPU)
# TODO(SigureMo): Temporarily disable train step on Windows CPU CI.
# We should remove this after fixing the performance issue.
list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_adam)
list(REMOVE_ITEM TEST_OPS test_train_step_resnet18_sgd)
endif()
foreach(TEST_OP ${TEST_OPS})
list(FIND TEST_EAGER_OPS ${TEST_OP} WAS_FOUND)
if(NOT WAS_FOUND EQUAL -1)
......@@ -79,3 +86,8 @@ if(APPLE)
set_tests_properties(test_mobile_net PROPERTIES TIMEOUT 300)
set_tests_properties(test_resnet_v2 PROPERTIES TIMEOUT 300)
endif()
if(WITH_GPU)
set_tests_properties(test_train_step_resnet18_sgd PROPERTIES TIMEOUT 240)
set_tests_properties(test_train_step_resnet18_adam PROPERTIES TIMEOUT 240)
endif()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import unittest
from functools import partial
import numpy as np
import paddle
def reset_seed():
paddle.seed(1010)
np.random.seed(1010)
random.seed(1010)
def loss_fn_tiny_model(x):
return x.mean()
def train_step_tiny_model(net, x, loss_fn, opt):
out = net(x)
loss = loss_fn(out)
loss.backward()
opt.step()
opt.clear_grad()
return loss
class TinyModel(paddle.nn.Layer):
def __init__(self):
super().__init__()
self.layer1 = paddle.nn.Linear(10, 10)
def forward(self, data):
return self.layer1(data)
class TestTrainStepTinyModel(unittest.TestCase):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 5
def get_train_step_losses(self, func, steps):
losses = []
net = self.net_creator()
lr = self.lr_creator()
optimizer = self.optimizer_creator(
learning_rate=lr, parameters=net.parameters()
)
for _ in range(steps):
loss = func(net, self.input, self.loss_fn, optimizer)
if isinstance(lr, paddle.optimizer.lr.ReduceOnPlateau):
lr.step(loss)
elif isinstance(lr, paddle.optimizer.lr.LRScheduler):
lr.step()
losses.append(loss)
return losses
def test_train_step(self):
reset_seed()
dygraph_losses = self.get_train_step_losses(
self.train_step_func, self.steps
)
reset_seed()
static_func = paddle.jit.to_static(self.train_step_func)
static_losses = self.get_train_step_losses(static_func, self.steps)
self.assertEqual(len(dygraph_losses), len(static_losses))
for dygraph_loss, static_loss in zip(dygraph_losses, static_losses):
dygraph_loss = dygraph_loss.numpy()
static_loss = static_loss.numpy()
np.testing.assert_allclose(dygraph_loss, static_loss, rtol=1e-4)
class TestTrainStepTinyModelAdadelta(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.Adadelta
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelAdagrad(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.Adagrad
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelAdam(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.Adam
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelAdamax(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.Adamax
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelAdamW(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.AdamW
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLamb(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = partial(
paddle.optimizer.Lamb, lamb_weight_decay=0.01
)
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelMomentum(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.Momentum
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelRMSProp(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.RMSProp
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRNoamDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.NoamDecay, d_model=0.01, warmup_steps=100
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRPiecewiseDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.PiecewiseDecay,
boundaries=[3, 6, 9],
values=[0.1, 0.2, 0.3, 0.4],
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRNaturalExpDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.NaturalExpDecay,
learning_rate=0.5,
gamma=0.1,
)
self.optimizer_creator = partial(paddle.optimizer.SGD)
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRInverseTimeDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.InverseTimeDecay, learning_rate=0.5, gamma=0.1
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRPolynomialDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.PolynomialDecay,
learning_rate=0.5,
decay_steps=20,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRLinearWarmup(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.LinearWarmup,
learning_rate=0.5,
warmup_steps=2,
start_lr=0,
end_lr=0.5,
)
self.optimizer_creator = partial(paddle.optimizer.SGD)
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRExponentialDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.ExponentialDecay, learning_rate=0.5, gamma=0.9
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRMultiStepDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.MultiStepDecay,
learning_rate=0.5,
milestones=[2, 4, 6],
gamma=0.8,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRStepDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.StepDecay,
learning_rate=0.5,
step_size=5,
gamma=0.8,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRLambdaDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.LambdaDecay,
learning_rate=0.5,
lr_lambda=lambda x: 0.95**x,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRReduceOnPlateau(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.ReduceOnPlateau,
learning_rate=1.0,
factor=0.5,
patience=5,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRCosineAnnealingDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.CosineAnnealingDecay,
learning_rate=0.5,
T_max=10,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRMultiplicativeDecay(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.MultiplicativeDecay,
learning_rate=0.5,
lr_lambda=lambda x: 0.95,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLROneCycleLR(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.OneCycleLR, max_learning_rate=1.0, total_steps=3
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
class TestTrainStepTinyModelLRCyclicLR(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([10000, 10])
self.net_creator = TinyModel
self.lr_creator = partial(
paddle.optimizer.lr.CyclicLR,
base_learning_rate=0.5,
max_learning_rate=1.0,
step_size_up=15,
step_size_down=5,
)
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_train_step import (
TestTrainStepTinyModel,
loss_fn_tiny_model,
train_step_tiny_model,
)
import paddle
from paddle.vision.models import resnet18
class TestTrainStepResNet18Adam(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([64, 3, 224, 224])
self.net_creator = resnet18
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.Adam
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from test_train_step import (
TestTrainStepTinyModel,
loss_fn_tiny_model,
train_step_tiny_model,
)
import paddle
from paddle.vision.models import resnet18
class TestTrainStepResNet18Sgd(TestTrainStepTinyModel):
def setUp(self):
self.input = paddle.randn([64, 3, 224, 224])
self.net_creator = resnet18
self.lr_creator = lambda: 0.001
self.optimizer_creator = paddle.optimizer.SGD
self.loss_fn = loss_fn_tiny_model
self.train_step_func = train_step_tiny_model
self.steps = 3
if __name__ == "__main__":
unittest.main()