提交 c78a5e5d 编写于 作者: Y Yu Yang

Fix merge error before

上级 c5bc1267
import py_paddle.swig_paddle as swig_api
import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils import paddle.trainer_config_helpers.config_parser_utils as config_parser_utils
import paddle.trainer_config_helpers.optimizers as v1_optimizers import paddle.trainer_config_helpers.optimizers as v1_optimizers
""" """
...@@ -17,6 +16,7 @@ __all__ = [ ...@@ -17,6 +16,7 @@ __all__ = [
class Optimizer(object): class Optimizer(object):
def __init__(self, **kwargs): def __init__(self, **kwargs):
import py_paddle.swig_paddle as swig_api
if 'batch_size' in kwargs: if 'batch_size' in kwargs:
del kwargs['batch_size'] # not important for python library. del kwargs['batch_size'] # not important for python library.
...@@ -25,8 +25,6 @@ class Optimizer(object): ...@@ -25,8 +25,6 @@ class Optimizer(object):
self.__opt_conf_proto__ = config_parser_utils.parse_optimizer_config( self.__opt_conf_proto__ = config_parser_utils.parse_optimizer_config(
__impl__) __impl__)
if swig_api is None:
raise RuntimeError("paddle.v2 currently need swig_paddle")
self.__opt_conf__ = swig_api.OptimizationConfig.createFromProto( self.__opt_conf__ = swig_api.OptimizationConfig.createFromProto(
self.__opt_conf_proto__) self.__opt_conf_proto__)
...@@ -37,18 +35,22 @@ class Optimizer(object): ...@@ -37,18 +35,22 @@ class Optimizer(object):
For each optimizer(SGD, Adam), GradientMachine should enable different For each optimizer(SGD, Adam), GradientMachine should enable different
buffers. buffers.
""" """
import py_paddle.swig_paddle as swig_api
tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__) tmp = swig_api.ParameterOptimizer.create(self.__opt_conf__)
assert isinstance(tmp, swig_api.ParameterOptimizer) assert isinstance(tmp, swig_api.ParameterOptimizer)
return tmp.getParameterTypes() return tmp.getParameterTypes()
def __create_local_updater__(self): def __create_local_updater__(self):
import py_paddle.swig_paddle as swig_api
return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__) return swig_api.ParameterUpdater.createLocalUpdater(self.__opt_conf__)
def __create_remote_updater__(self, pass_num, use_sparse_updater): def __create_remote_updater__(self, pass_num, use_sparse_updater):
import py_paddle.swig_paddle as swig_api
return swig_api.ParameterUpdater.createRemoteUpdater( return swig_api.ParameterUpdater.createRemoteUpdater(
self.__opt_conf__, pass_num, use_sparse_updater) self.__opt_conf__, pass_num, use_sparse_updater)
def __create_new_remote_updater__(self, pserver_spec, use_etcd): def __create_new_remote_updater__(self, pserver_spec, use_etcd):
import py_paddle.swig_paddle as swig_api
return swig_api.ParameterUpdater.createNewRemoteUpdater( return swig_api.ParameterUpdater.createNewRemoteUpdater(
self.__opt_conf__, pserver_spec, use_etcd) self.__opt_conf__, pserver_spec, use_etcd)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册