提交 843b63bb 编写于 作者: Q qiaolongfei 提交者: Yu Yang

add config_parser in trainer_config_helpers to seperate trainer config

上级 3a802729
......@@ -9,11 +9,29 @@ The user api could be simpler and carefully designed.
import py_paddle.swig_paddle as api
from py_paddle import DataProviderConverter
import paddle.trainer.PyDataProvider2 as dp
import paddle.trainer.config_parser
import numpy as np
import random
from mnist_util import read_from_mnist
import paddle.trainer_config_helpers.config_parser as config_parser
from paddle.trainer_config_helpers import *
def optimizer_config():
settings(
learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000)
def network_config():
imgs = data_layer(name='pixel', size=784)
hidden1 = fc_layer(input=imgs, size=200)
hidden2 = fc_layer(input=hidden1, size=200)
inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
cost = classification_cost(
input=inference, label=data_layer(
name='label', size=10))
outputs(cost)
def init_parameter(network):
assert isinstance(network, api.GradientMachine)
......@@ -54,20 +72,20 @@ def input_order_converter(generator):
def main():
api.initPaddle("-use_gpu=false", "-trainer_count=4") # use 4 cpu cores
config = paddle.trainer.config_parser.parse_config(
'simple_mnist_network.py', '')
# get enable_types for each optimizer.
# enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers.
opt_config = api.OptimizationConfig.createFromProto(config.opt_config)
opt_config_proto = config_parser.parse_optimizer_config(optimizer_config)
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes()
# Create Simple Gradient Machine.
model_config = config_parser.parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto(
config.model_config, api.CREATE_MODE_NORMAL, enable_types)
model_config, api.CREATE_MODE_NORMAL, enable_types)
# This type check is not useful. Only enable type hint in IDE.
# Such as PyCharm
......
......@@ -3416,8 +3416,35 @@ def register_parse_config_hook(f):
_parse_config_hooks.add(f)
def parse_config(config_file, config_arg_str):
def update_g_config():
'''
Update g_config after execute config_file or config_functions.
'''
for k, v in settings.iteritems():
if v is None:
continue
g_config.opt_config.__setattr__(k, v)
for k, v in trainer_settings.iteritems():
if v is None:
continue
g_config.__setattr__(k, v)
for name in g_config.model_config.input_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
'The type of input layer "%s" is not "data"' % name
for name in g_config.model_config.output_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
return g_config
def parse_config(trainer_config, config_arg_str):
'''
@param trainer_config: can be a string of config file name or a function name
with config logic
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
passed to config script as a dictionary CONFIG_ARGS
'''
......@@ -3451,45 +3478,20 @@ def parse_config(config_file, config_arg_str):
g_root_submodel.is_recurrent_layer_group = False
g_current_submodel = g_root_submodel
# for paddle on spark, need support non-file config.
# you can use parse_config like below:
#
# from paddle.trainer.config_parser import parse_config
# def configs():
# #your paddle config code, which is same as config file.
#
# config = parse_config(configs, "is_predict=1")
# # then you get config proto object.
if hasattr(config_file, '__call__'):
config_file.func_globals.update(
if hasattr(trainer_config, '__call__'):
trainer_config.func_globals.update(
make_config_environment("", config_args))
config_file()
trainer_config()
else:
execfile(config_file, make_config_environment(config_file, config_args))
for k, v in settings.iteritems():
if v is None:
continue
g_config.opt_config.__setattr__(k, v)
for k, v in trainer_settings.iteritems():
if v is None:
continue
g_config.__setattr__(k, v)
execfile(trainer_config,
make_config_environment(trainer_config, config_args))
for name in g_config.model_config.input_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
assert (g_layer_map[name].type == "data" or g_layer_map[name].type == "data_trim"), \
'The type of input layer "%s" is not "data"' % name
for name in g_config.model_config.output_layer_names:
assert name in g_layer_map, \
'input name "%s" does not correspond to a layer name' % name
return g_config
return update_g_config()
def parse_config_and_serialize(config_file, config_arg_str):
def parse_config_and_serialize(trainer_config, config_arg_str):
try:
config = parse_config(config_file, config_arg_str)
config = parse_config(trainer_config, config_arg_str)
#logger.info(config)
return config.SerializeToString()
except:
......
......@@ -20,6 +20,7 @@ from layers import *
from networks import *
from optimizers import *
from attrs import *
from config_parser import *
# This will enable operator overload for LayerOutput
import math as layer_math
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.trainer.config_parser as config_parser
'''
This file is a wrapper of formal config_parser. The main idea of this file is to
separete different config logic into different function, such as network configuration
and optimizer configuration.
'''
__all__ = [
"parse_trainer_config", "parse_network_config", "parse_optimizer_config"
]
def parse_trainer_config(trainer_conf, config_arg_str):
return config_parser.parse_config(trainer_conf, config_arg_str)
def parse_network_config(network_conf):
config = config_parser.parse_config(network_conf, '')
return config.model_config
def parse_optimizer_config(optimizer_conf):
config = config_parser.parse_config(optimizer_conf, '')
return config.opt_config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册