提交 763a30fd 编写于 作者: Q qiaolongfei 提交者: Yu Yang

add config_parser_utils

上级 843b63bb
...@@ -12,14 +12,17 @@ import paddle.trainer.PyDataProvider2 as dp ...@@ -12,14 +12,17 @@ import paddle.trainer.PyDataProvider2 as dp
import numpy as np import numpy as np
import random import random
from mnist_util import read_from_mnist from mnist_util import read_from_mnist
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.trainer_config_helpers.config_parser as config_parser
from paddle.trainer_config_helpers import * from paddle.trainer_config_helpers import *
def optimizer_config(): def optimizer_config():
settings( settings(
learning_rate=1e-4, learning_method=AdamOptimizer(), batch_size=1000) learning_rate=1e-4,
learning_method=AdamOptimizer(),
batch_size=1000,
model_average=ModelAverage(average_window=0.5),
regularization=L2Regularization(rate=0.5))
def network_config(): def network_config():
...@@ -77,13 +80,14 @@ def main(): ...@@ -77,13 +80,14 @@ def main():
# enable_types = [value, gradient, momentum, etc] # enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different # For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers. # buffers.
opt_config_proto = config_parser.parse_optimizer_config(optimizer_config) opt_config_proto = config_parser_utils.parse_optimizer_config(
optimizer_config)
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto) opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config) _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes() enable_types = _temp_optimizer_.getParameterTypes()
# Create Simple Gradient Machine. # Create Simple Gradient Machine.
model_config = config_parser.parse_network_config(network_config) model_config = config_parser_utils.parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto( m = api.GradientMachine.createFromConfigProto(
model_config, api.CREATE_MODE_NORMAL, enable_types) model_config, api.CREATE_MODE_NORMAL, enable_types)
......
from paddle.trainer_config_helpers import *
settings(
learning_rate=1e-4,
learning_method=AdamOptimizer(),
batch_size=1000,
model_average=ModelAverage(average_window=0.5),
regularization=L2Regularization(rate=0.5))
imgs = data_layer(name='pixel', size=784)
hidden1 = fc_layer(input=imgs, size=200)
hidden2 = fc_layer(input=hidden1, size=200)
inference = fc_layer(input=hidden2, size=10, act=SoftmaxActivation())
cost = classification_cost(
input=inference, label=data_layer(
name='label', size=10))
outputs(cost)
...@@ -20,7 +20,7 @@ from layers import * ...@@ -20,7 +20,7 @@ from layers import *
from networks import * from networks import *
from optimizers import * from optimizers import *
from attrs import * from attrs import *
from config_parser import * from config_parser_utils import *
# This will enable operator overload for LayerOutput # This will enable operator overload for LayerOutput
import math as layer_math import math as layer_math
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle.trainer.config_parser as config_parser
'''
This file is a wrapper of formal config_parser. The main idea of this file is to
separete different config logic into different function, such as network configuration
and optimizer configuration.
'''
__all__ = [
"parse_trainer_config", "parse_network_config", "parse_optimizer_config"
]
def parse_trainer_config(trainer_conf, config_arg_str):
return config_parser.parse_config(trainer_conf, config_arg_str)
def parse_network_config(network_conf, config_arg_str=''):
config = config_parser.parse_config(network_conf, config_arg_str)
return config.model_config
def parse_optimizer_config(optimizer_conf, config_arg_str=''):
config = config_parser.parse_config(optimizer_conf, config_arg_str)
return config.opt_config
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册