diff --git a/demo/mnist/api_train.py b/demo/mnist/api_train.py
index f301da382ff8a5bc16d9c18b956f78566ed4894f..8573d8143a085b8d2e0bcf7df17b1abe177029df 100644
--- a/demo/mnist/api_train.py
+++ b/demo/mnist/api_train.py
@@ -13,15 +13,7 @@ import numpy as np
 import random
 from mnist_util import read_from_mnist
 from paddle.trainer_config_helpers import *
-
-
-def optimizer_config():
-    settings(
-        learning_rate=1e-4,
-        learning_method=AdamOptimizer(),
-        batch_size=1000,
-        model_average=ModelAverage(average_window=0.5),
-        regularization=L2Regularization(rate=0.5))
+import paddle.v2
 
 
 def network_config():
@@ -75,19 +67,23 @@ def input_order_converter(generator):
 def main():
     api.initPaddle("-use_gpu=false", "-trainer_count=4")  # use 4 cpu cores
 
-    # get enable_types for each optimizer.
-    # enable_types = [value, gradient, momentum, etc]
-    # For each optimizer(SGD, Adam), GradientMachine should enable different
-    # buffers.
-    opt_config_proto = parse_optimizer_config(optimizer_config)
-    opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
-    _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
-    enable_types = _temp_optimizer_.getParameterTypes()
+    optimizer = paddle.v2.optimizer.Adam(
+        learning_rate=1e-4,
+        batch_size=1000,
+        model_average=ModelAverage(average_window=0.5),
+        regularization=L2Regularization(rate=0.5))
+
+    # Create Local Updater. Local means not run in cluster.
+    # For a cluster training, here we can change to createRemoteUpdater
+    # in future.
+    updater = optimizer.create_local_updater()
+    assert isinstance(updater, api.ParameterUpdater)
 
     # Create Simple Gradient Machine.
     model_config = parse_network_config(network_config)
-    m = api.GradientMachine.createFromConfigProto(
-        model_config, api.CREATE_MODE_NORMAL, enable_types)
+    m = api.GradientMachine.createFromConfigProto(model_config,
+                                                  api.CREATE_MODE_NORMAL,
+                                                  optimizer.enable_types())
 
     # This type check is not useful. Only enable type hint in IDE.
     # Such as PyCharm
@@ -96,12 +92,6 @@ def main():
     # Initialize Parameter by numpy.
     init_parameter(network=m)
 
-    # Create Local Updater. Local means not run in cluster.
-    # For a cluster training, here we can change to createRemoteUpdater
-    # in future.
-    updater = api.ParameterUpdater.createLocalUpdater(opt_config)
-    assert isinstance(updater, api.ParameterUpdater)
-
     # Initialize ParameterUpdater.
     updater.init(m)
 
diff --git a/python/paddle/v2/__init__.py b/python/paddle/v2/__init__.py
index f662d6826321eb840739382558f76327d27b5847..b2ea87b086101d71e89c33ce7c1f4eb21afade5a 100644
--- a/python/paddle/v2/__init__.py
+++ b/python/paddle/v2/__init__.py
@@ -11,3 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import optimizer
+
+__all__ = ['optimizer']
diff --git a/python/paddle/v2/optimizer.py b/python/paddle/v2/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa2942bc9faeb2a353459cd619886f56ea32f450
--- /dev/null
+++ b/python/paddle/v2/optimizer.py
@@ -0,0 +1,58 @@
+import py_paddle.swig_paddle as swig_api
+import paddle.trainer_config_helpers.optimizers as v1_optimizers
+import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
+import paddle.v2
+
+__all__ = ['Adam', 'Adamax']
+
+
+class Optimizer(object):
+    def __init__(self, **kwargs):
+        if 'batch_size' in kwargs:
+            del kwargs['batch_size']  # not important for python library.
+
+        def __impl__():
+            v1_optimizers.settings(batch_size=1, **kwargs)
+
+        self.__opt_conf_proto__ = config_parser_utils.parse_optimizer_config(
+            __impl__)
+        self.__opt_conf__ = swig_api.OptimizationConfig.createFromProto(
+            self.__opt_conf_proto__)
+
+    def enable_types(self):
+        """
+        get enable_types for each optimizer.
+        enable_types = [value, gradient, momentum, etc]
+        For each optimizer(SGD, Adam), GradientMachine should enable different
+        buffers.
+        """
+        tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__)
+        assert isinstance(tmp, swig_api.ParameterOptimizer)
+        return tmp.getParameterTypes()
+
+    def create_local_updater(self):
+        return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
+
+    def create_remote_updater(self, pass_num):
+        return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
+                                                             pass_num)
+
+
+class Adam(Optimizer):
+    def __init__(self, beta1=0.9, beta2=0.999, epsilon=1e-8, **kwargs):
+        learning_method = v1_optimizers.AdamOptimizer(
+            beta1=beta1, beta2=beta2, epsilon=epsilon)
+        super(Adam, self).__init__(learning_method=learning_method, **kwargs)
+
+
+class Adamax(Optimizer):
+    def __init__(self, beta1=0.9, beta2=0.999, **kwargs):
+        learning_method = v1_optimizers.AdamaxOptimizer(
+            beta1=beta1, beta2=beta2)
+        super(Adamax, self).__init__(learning_method=learning_method, **kwargs)
+
+
+if __name__ == '__main__':
+    swig_api.initPaddle('--use_gpu=false')
+    opt = paddle.v2.optimizer.Adam()
+    print opt.enable_types()
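
Usage note (editorial, not part of the patch): the sketch below strings together the new v2 calls exactly as the updated demo uses them (paddle.v2.optimizer.Adam, enable_types(), create_local_updater(), updater.init()). The small network_config() is only a stand-in for the one already defined in demo/mnist/api_train.py, and the layer sizes, trainer_count and learning_rate values are illustrative assumptions, not part of the change; the import layout mirrors api_train.py.

import py_paddle.swig_paddle as api
from paddle.trainer_config_helpers import *
import paddle.v2


def network_config():
    # Stand-in for the network_config() that api_train.py already defines.
    img = data_layer(name='pixel', size=784)
    hidden = fc_layer(input=img, size=200)
    prediction = fc_layer(input=hidden, size=10, act=SoftmaxActivation())
    cost = classification_cost(
        input=prediction, label=data_layer(name='label', size=10))
    outputs(cost)


api.initPaddle("-use_gpu=false", "-trainer_count=1")

# One object now carries the whole optimization setting. batch_size is
# accepted for compatibility but dropped inside the wrapper, as noted in
# optimizer.py above.
optimizer = paddle.v2.optimizer.Adam(learning_rate=1e-4, batch_size=1000)

# enable_types() replaces the old parse_optimizer_config /
# ParameterOptimizer.create sequence: it reports which buffers (value,
# gradient, momentum, ...) the GradientMachine must allocate.
model_config = parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto(model_config,
                                              api.CREATE_MODE_NORMAL,
                                              optimizer.enable_types())

# The updater comes from the same optimizer object, so the
# OptimizationConfig proto is parsed exactly once.
updater = optimizer.create_local_updater()
updater.init(m)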