提交 9b41b08e 编写于 作者: Y Yu Yang

Remove unnecessary import in api_train.py

上级 763a30fd
...@@ -12,7 +12,6 @@ import paddle.trainer.PyDataProvider2 as dp ...@@ -12,7 +12,6 @@ import paddle.trainer.PyDataProvider2 as dp
import numpy as np import numpy as np
import random import random
from mnist_util import read_from_mnist from mnist_util import read_from_mnist
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
from paddle.trainer_config_helpers import * from paddle.trainer_config_helpers import *
...@@ -80,14 +79,13 @@ def main(): ...@@ -80,14 +79,13 @@ def main():
# enable_types = [value, gradient, momentum, etc] # enable_types = [value, gradient, momentum, etc]
# For each optimizer(SGD, Adam), GradientMachine should enable different # For each optimizer(SGD, Adam), GradientMachine should enable different
# buffers. # buffers.
opt_config_proto = config_parser_utils.parse_optimizer_config( opt_config_proto = parse_optimizer_config(optimizer_config)
optimizer_config)
opt_config = api.OptimizationConfig.createFromProto(opt_config_proto) opt_config = api.OptimizationConfig.createFromProto(opt_config_proto)
_temp_optimizer_ = api.ParameterOptimizer.create(opt_config) _temp_optimizer_ = api.ParameterOptimizer.create(opt_config)
enable_types = _temp_optimizer_.getParameterTypes() enable_types = _temp_optimizer_.getParameterTypes()
# Create Simple Gradient Machine. # Create Simple Gradient Machine.
model_config = config_parser_utils.parse_network_config(network_config) model_config = parse_network_config(network_config)
m = api.GradientMachine.createFromConfigProto( m = api.GradientMachine.createFromConfigProto(
model_config, api.CREATE_MODE_NORMAL, enable_types) model_config, api.CREATE_MODE_NORMAL, enable_types)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册