提交 12e9c003 编写于 作者: Q qiaolongfei

add optimizer

上级 a3f0aed0
......@@ -13,15 +13,7 @@ import numpy as np
import random
from mnist_util import read_from_mnist
from paddle.trainer_config_helpers import *
def optimizer_config():
settings(
learning_rate=1e-4,
learning_method=AdamOptimizer(),
batch_size=1000,
model_average=ModelAverage(average_window=0.5),
regularization=L2Regularization(rate=0.5))
import paddle.v2
def network_config():
......@@ -75,19 +67,23 @@ def input_order_converter(generator):
def main():
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores
# get enable_types for each optimizer.
# enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers.
opt_config_proto = parse_optimizer_config(optimizer_config)
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes()
optimizer = paddle.v2.optimizer.Adam(
learning_rate=1e-4,
batch_size=1000,
model_average=ModelAverage(average_window=0.5),
regularization=L2Regularization(rate=0.5))
# Create Local Updater. Local means not run in cluster.
# For a cluster training, here we can change to createRemoteUpdater
# in future.
updater = optimizer.create_local_updater()
assert isinstance(updater, api.ParameterUpdater)
# Create Simple Gradient Machine.
model_config = parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto(
model_config, api.CREATE_MODE_NORMAL, enable_types)
m = api.GradientMachine.createFromConfigProto(model_config,
api.CREATE_MODE_NORMAL,
optimizer.enable_types())
# This type check is not useful. Only enable type hint in IDE.
# Such as PyCharm
......@@ -96,12 +92,6 @@ def main():
# Initialize Parameter by numpy.
init_parameter(network=m)
# Create Local Updater. Local means not run in cluster.
# For a cluster training, here we can change to createRemoteUpdater
# in future.
updater = api.ParameterUpdater.createLocalUpdater(opt_config)
assert isinstance(updater, api.ParameterUpdater)
# Initialize ParameterUpdater.
updater.init(m)
......
......@@ -11,3 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import optimizer
__all__ = ['optimizer']
import py_paddle.swig_paddle as swig_api
import paddle.trainer_config_helpers.optimizers as v1_optimizers
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.v2
__all__ = ['Adam', 'Adamax']
class Optimizer(object):
def __init__(self, **kwargs):
if 'batch_size' in kwargs:
del kwargs['batch_size'] # not important for python library.
def __impl__():
v1_optimizers.settings(batch_size=1, **kwargs)
self.__opt_conf_proto__ = config_parser_utils.parse_optimizer_config(
__impl__)
self.__opt_conf__ = swig_api.OptimizationConfig.createFromProto(
self.__opt_conf_proto__)
def enable_types(self):
"""
get enable_types for each optimizer.
enable_types = [value, gradient, momentum, etc]
For each optimizer(SGD, Adam), GradientMachine should enable different
buffers.
"""
tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__)
assert isinstance(tmp, swig_api.ParameterOptimizer)
return tmp.getParameterTypes()
def create_local_updater(self):
return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
def create_remote_updater(self, pass_num):
return swig_api.ParameterUpdater.createRemoteUpdater(self.__opt_conf__,
pass_num)
class Adam(Optimizer):
def __init__(self, beta1=0.9, beta2=0.999, epsilon=1e-8, **kwargs):
learning_method = v1_optimizers.AdamOptimizer(
beta1=beta1, beta2=beta2, epsilon=epsilon)
super(Adam, self).__init__(learning_method=learning_method, **kwargs)
class Adamax(Optimizer):
def __init__(self, beta1=0.9, beta2=0.999, **kwargs):
learning_method = v1_optimizers.AdamaxOptimizer(
beta1=beta1, beta2=beta2)
super(Adamax, self).__init__(learning_method=learning_method, **kwargs)
if __name__ == '__main__':
swig_api.initPaddle('--use_gpu=false')
opt = paddle.v2.optimizer.Adam()
print opt.enable_types()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册