Commit 71c2b296 authored by typhoonzero

update

Parent 23785584
...
@@ -101,6 +101,10 @@ class Parameters(object):
         self.__param_conf__[param_conf.name] = param_conf

+    def update_param_conf(self, model_config):
+        for p in model_config.parameters:
+            self.__param_conf__[p.name] = p
+
     def keys(self):
         """
         keys are the names of each parameter.
...
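The new `update_param_conf` helper re-caches every `ParameterConfig` proto from a model config, so a `Parameters` pool can be refreshed after the topology's protos have been amended. A minimal sketch of the intended call pattern, assuming `cost` is an already-constructed paddle.v2 cost layer (the setup itself is not part of this commit):

```python
# Sketch only: refresh cached ParameterConfig protos after the topology
# back-fills optimizer defaults. Assumes `cost` is a built v2 cost layer.
import paddle.v2 as paddle
from paddle.v2.topology import Topology

topology = Topology(cost)
parameters = paddle.parameters.create(cost)

topology.update_from_default()                  # amend the protos in place
parameters.update_param_conf(topology.proto())  # re-cache the amended protos
```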
...
@@ -19,6 +19,7 @@ import paddle.trainer_config_helpers as conf_helps
 import layer as v2_layer
 import config_base
 import cPickle
+from paddle.trainer import config_parser as cp

 __all__ = ['Topology']

...
@@ -50,6 +51,32 @@ class Topology(object):
         assert isinstance(self.__model_config__, ModelConfig)

+    def update_from_default(self):
+        # HACK(typhoonzero): update the ParameterConfig protos in case the
+        # optimizer is defined after the layers, or between layers.
+        # Must be called from trainer.__init__().
+        for parameter in self.__model_config__.parameters:
+            if parameter.momentum == 0.0 and cp.g_default_momentum:
+                parameter.momentum = cp.g_default_momentum
+            if parameter.decay_rate == 0.0 and cp.g_default_decay_rate:
+                parameter.decay_rate = cp.g_default_decay_rate
+            if parameter.initial_mean == 0.0:
+                parameter.initial_mean = cp.g_default_initial_mean
+            if parameter.initial_std == 0.01:
+                parameter.initial_std = cp.g_default_initial_std
+            if parameter.initial_strategy == 0:
+                parameter.initial_strategy = cp.g_default_initial_strategy
+            if not parameter.initial_smart:
+                parameter.initial_smart = cp.g_default_initial_smart
+            if parameter.num_batches_regularization == 1 and \
+                    cp.g_default_num_batches_regularization:
+                parameter.num_batches_regularization = \
+                    cp.g_default_num_batches_regularization
+            if parameter.gradient_clipping_threshold == 0.0 and \
+                    cp.g_default_gradient_clipping_threshold:
+                parameter.gradient_clipping_threshold = \
+                    cp.g_default_gradient_clipping_threshold
+            if parameter.device == -1 and cp.g_default_device:
+                parameter.device = cp.g_default_device
+        # FIXME(typhoonzero): update_hooks and g_default_compact_func are
+        # deliberately not handled here.
+
     def use_sparse_updater(self):
         """
         check if any parameter requires sparse_update
...
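Note that the back-fill above only overwrites fields that still carry their proto default value (0.0, 0.01, -1, ...), so per-parameter attributes set explicitly in the layer config survive. A self-contained toy version of that rule, with made-up names rather than the real `ParameterConfig` proto:

```python
# Toy illustration of the "only fill fields still at their proto default"
# rule used by Topology.update_from_default(). All names here are made up.
PROTO_DEFAULTS = {'momentum': 0.0, 'decay_rate': 0.0, 'device': -1}


def fill_from_defaults(param, global_defaults):
    for field, proto_default in PROTO_DEFAULTS.items():
        global_value = global_defaults.get(field)
        # Overwrite only if the field was never set explicitly and a
        # global default (from a later optimizer definition) exists.
        if getattr(param, field) == proto_default and global_value:
            setattr(param, field, global_value)


class FakeParam(object):
    momentum, decay_rate, device = 0.0, 0.0, -1


p = FakeParam()
fill_from_defaults(p, {'momentum': 0.9, 'decay_rate': 1e-4})
assert p.momentum == 0.9 and p.decay_rate == 1e-4
```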
...
@@ -64,6 +64,11 @@ class SGD(object):
                             "paddle.v2.optimizer.Optimizer")
         import py_paddle.swig_paddle as api
         topology = Topology(cost, extra_layers=extra_layers)
+        # HACK(typhoonzero): update the ParameterConfig protos in case the
+        # optimizer is defined after the layers, or between layers.
+        topology.update_from_default()
+        parameters.update_param_conf(topology.proto())
         self.__optimizer__ = update_equation
         self.__topology__ = topology
         self.__parameters__ = parameters
...
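For context, the ordering problem this commit works around shows up in scripts like the following, where the optimizer (and hence the `g_default_*` globals in `config_parser`) is created after the layers have already been parsed. The sketch uses the standard v2 MNIST-style API and is illustrative, not taken from this commit:

```python
# Illustrative v2 training script: the optimizer is defined *after* the
# layers, so without the two new calls in SGD.__init__ the parameters
# would keep stale proto defaults (e.g. momentum == 0.0).
import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

images = paddle.layer.data(
    name='pixel', type=paddle.data_type.dense_vector(784))
predict = paddle.layer.fc(
    input=images, size=10, act=paddle.activation.Softmax())
label = paddle.layer.data(
    name='label', type=paddle.data_type.integer_value(10))
cost = paddle.layer.classification_cost(input=predict, label=label)

parameters = paddle.parameters.create(cost)
optimizer = paddle.optimizer.Momentum(momentum=0.9)  # defined after layers

# SGD.__init__ now runs topology.update_from_default() and
# parameters.update_param_conf(topology.proto()) before training starts.
trainer = paddle.trainer.SGD(
    cost=cost, parameters=parameters, update_equation=optimizer)
```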