Commit f6128777 authored by dongdaxiang

add incubate for unified API

Parent 317eb0aa
......@@ -1358,6 +1358,7 @@ All parameter, weight, gradient are variables in Paddle.
BindRecordIOWriter(&m);
BindAsyncExecutor(&m);
BindFleetWrapper(&m);
BindGraph(&m);
BindNode(&m);
BindInferenceApi(&m);
......
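For reference, the binding registered here is what the Python-side changes further down rely on. A hedged sketch of how it is consumed (assuming BindFleetWrapper exposes the C++ wrapper under the name Fleet on paddle.fluid.core, which is what the __init__.py change below constructs):

# Hypothetical sketch; the exposed class name `Fleet` is inferred from the
# Python change `self._fleet_ptr = fluid.core.Fleet()` later in this commit.
import paddle.fluid as fluid

fleet_ptr = fluid.core.Fleet()  # C++ fleet handle reachable from Python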
......@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .helper import MPIHelper
class RoleMakerBase(object):
......@@ -46,6 +45,7 @@ class MPIRoleMaker(RoleMakerBase):
from mpi4py import MPI
self.comm_ = MPI.COMM_WORLD
self.MPI = MPI
self.ips_ = None
def get_rank(self):
self.rank_ = self.comm_.Get_rank()
......
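MPIRoleMaker discovers each process's identity through mpi4py. A minimal standalone sketch of the same pattern (assuming the processes are started under an MPI launcher such as mpirun):

# Standalone mpi4py sketch of the rank/size discovery used by MPIRoleMaker.
from mpi4py import MPI

comm = MPI.COMM_WORLD      # corresponds to self.comm_ above
rank = comm.Get_rank()     # corresponds to self.rank_ in get_rank()
size = comm.Get_size()     # total number of processes in the MPI job
print("process %d of %d" % (rank, size))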
......@@ -14,19 +14,10 @@
import sys
import os
from ..base.role_maker import MPISymetricRoleMaker
from paddle.fluid.optimizer import Optimizer
# this is a temporary solution
# TODO(guru4elephant)
# will make this more flexible for more Parameter Server Archs
fleet_instance = Fleet()
init = fleet_instance.init
stop = fleet_instance.stop
init_pserver = fleet_instance.init_pserver
init_worker = fleet_instance.init_worker
init_pserver_model = fleet_instance.init_pserver_model
save_pserver_model = fleet_instance.save_pserver_model
from .optimizer_factory import *
from google.protobuf import text_format
import paddle.fluid.optimizer as local_optimizer
import paddle.fluid as fluid
class Fleet(object):
......@@ -35,7 +26,7 @@ class Fleet(object):
"""
def __init__(self):
self.opt_info = None # for fleet only
self._opt_info = None # for fleet only
self.role_maker_ = None
def init(self):
......@@ -44,7 +35,7 @@ class Fleet(object):
# we will support more configurable RoleMaker for users in the future
self.role_maker_ = MPISymetricRoleMaker()
self.role_maker_.generate_role()
self._fleet_ptr = core.FleetWrapper()
self._fleet_ptr = fluid.core.Fleet()
def stop(self):
self.role_maker_.barrier_worker()
......@@ -91,6 +82,12 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def is_worker(self):
return self.role_maker_.is_worker()
def is_server(self):
return self.role_maker_.is_server()
def init_pserver_model(self):
if self.role_maker_.is_first_worker():
self._fleet_ptr.init_model()
......@@ -103,7 +100,7 @@ class Fleet(object):
self._opt_info = opt_info
class DistributedOptimizer(paddle.fluid.Optimizer):
class DistributedOptimizer(object):
def __init__(self, optimizer, dist_config={}):
super(DistributedOptimizer, self).__init__()
self._optimizer = optimizer
......@@ -115,7 +112,7 @@ class DistributedOptimizer(paddle.fluid.Optimizer):
sys.stderr)
self._optimizer_name = "DistributedAdam"
self._distributed_optimizer = globals()[self._optimizer_name]()
self._distributed_optimizer = globals()[self._optimizer_name](optimizer)
def backward(self,
loss,
......@@ -135,7 +132,6 @@ class DistributedOptimizer(paddle.fluid.Optimizer):
no_grad_set=None):
optimize_ops, param_grads, opt_info = \
self._distributed_optimizer.minimize(
self._optimizer,
loss,
startup_program,
parameter_list,
......@@ -143,3 +139,18 @@ class DistributedOptimizer(paddle.fluid.Optimizer):
fleet_instance._set_opt_info(opt_info)
return [optimize_ops, param_grads]
# this is a temporary solution
# TODO(guru4elephant)
# will make this more flexible for more Parameter Server Archs
fleet_instance = Fleet()
init = fleet_instance.init
stop = fleet_instance.stop
init_pserver = fleet_instance.init_pserver
init_worker = fleet_instance.init_worker
is_worker = fleet_instance.is_worker
is_server = fleet_instance.is_server
init_pserver_model = fleet_instance.init_pserver_model
save_pserver_model = fleet_instance.save_pserver_model
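With these module-level aliases in place, a trainer script can drive the whole parameter-server flow through the unified API. A hedged usage sketch follows; the import path, the Adam optimizer, and the model-save path are assumptions for illustration, while the function and class names come from this module:

# Hypothetical trainer script using only the names exported above.
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.parameter_server as fleet  # path assumed from setup.py below

fleet.init()                       # builds the MPISymetricRoleMaker and the core Fleet handle

optimizer = fluid.optimizer.Adam(learning_rate=0.01)  # DistributedAdam is the backend wired up here
optimizer = fleet.DistributedOptimizer(optimizer)
# optimize_ops, param_grads = optimizer.minimize(loss)  # loss comes from the user's network

if fleet.is_server():
    fleet.init_pserver()           # parameter-server processes serve here
elif fleet.is_worker():
    fleet.init_worker()
    fleet.init_pserver_model()     # only the first worker actually initializes the model
    # ... run training ...
    fleet.save_pserver_model("./model_dir")  # save-path argument is an assumption

fleet.stop()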
......@@ -120,7 +120,12 @@ packages=['paddle',
'paddle.fluid.contrib.slim.distillation',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
'paddle.fluid.transpiler.details',
'paddle.fluid.incubate',
'paddle.fluid.incubate.fleet',
'paddle.fluid.incubate.fleet.base',
'paddle.fluid.incubate.fleet.parameter_server',
'paddle.fluid.incubate.fleet.p2p']
with open('@PADDLE_SOURCE_DIR@/python/requirements.txt') as f:
setup_requires = f.read().splitlines()
......
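The new incubate entries ensure the fleet packages are shipped in the wheel; without them the modules exist in the source tree but are not installed. A quick hedged check (assuming a wheel built from this branch):

# Sanity check that the newly packaged modules import after installation.
import importlib

for pkg in ("paddle.fluid.incubate",
            "paddle.fluid.incubate.fleet",
            "paddle.fluid.incubate.fleet.base",
            "paddle.fluid.incubate.fleet.parameter_server",
            "paddle.fluid.incubate.fleet.p2p"):
    importlib.import_module(pkg)
    print("ok:", pkg)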