提交 308491a9 编写于 作者: T typhoonzero

Update for simple distributed training

上级 1c1fae60
......@@ -43,13 +43,14 @@ class SendOp : public framework::OperatorBase {
}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto iname = Input("X");
auto oname = Output("Out");
auto ins = Inputs("X");
// TODO(typhoonzero): currently it's non-blocking,
// should block until server responds.
bool ret = client_->SendVariable(scope, iname, oname);
if (!ret) {
LOG(ERROR) << "send variable error";
for (auto in : ins) {
bool ret = client_->SendVariable(scope, in, in);
if (!ret) {
LOG(ERROR) << "send variable error";
}
}
}
......@@ -61,8 +62,7 @@ class SendOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SendOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "(Tensor) Input tensor to be saved");
AddOutput("Out", "(Tensor) Output fetched from server");
AddInput("X", "(Tensor) Input tensor to be send").AsDuplicable();
AddComment(R"DOC(
Recv operator
......
......@@ -30,7 +30,7 @@ def hash_name_to_server(params_grads, pserver_endpoints):
def round_robin(parameters, pserver_endpoints):
assert (len(parameters) < len(pserver_endpoints))
assert (len(parameters) > len(pserver_endpoints))
param_grad_map = dict()
pserver_idx = 0
......@@ -44,6 +44,6 @@ def round_robin(parameters, pserver_endpoints):
param_grad_map[server_for_param]["grads"].append(param)
pserver_idx += 1
if pserver_idx > len(pserver_endpoints):
if pserver_idx >= len(pserver_endpoints):
pserver_idx = 0
return param_grad_map
......@@ -50,7 +50,7 @@ class Executor(object):
self.executor = core.Executor(act_places)
self.places = places
def optimize(self, optimize_ops, program=None, **kwargs):
def optimize(self, optimize_ops, params_grads, program=None, **kwargs):
"""
optimize the program for different runtime environment
......@@ -67,7 +67,8 @@ class Executor(object):
program = default_main_program()
if kwargs.has_key("pservers"):
return self._optimize_distributed(optimize_ops, program, **kwargs)
return self._optimize_distributed(optimize_ops, program,
params_grads, **kwargs)
def _optimize_distributed(self, optimize_ops, program, params_and_grads,
**kwargs):
......@@ -92,7 +93,7 @@ class Executor(object):
type="send",
inputs={"X": self.param_grad_map[ep]["params"]
}, # inputs is a list of tensors to be send
outputs={"Out": self.param_grad_map[ep]["params"]},
outputs={},
attrs={"endpoint": ep})
# -------------- generate optimize sub program --------------
self.optimize_sub_program = Program()
......
......@@ -304,7 +304,8 @@ class Operator(object):
self.desc.check_attrs()
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while'
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
......@@ -202,7 +202,7 @@ class Optimizer(object):
params_grads = append_regularization_ops(params_grads)
optimize_ops = self.create_optimization_pass(params_grads, loss,
startup_program)
return optimize_ops
return optimize_ops, params_grads
class SGDOptimizer(Optimizer):
......
......@@ -2,6 +2,7 @@ from __future__ import print_function
import numpy as np
import paddle.v2 as paddle
import paddle.v2.fluid as fluid
import os
images = fluid.layers.data(name='pixel', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
......@@ -24,7 +25,7 @@ predict = fluid.layers.fc(input=conv_pool_2, size=10, act="softmax")
cost = fluid.layers.cross_entropy(input=predict, label=label)
avg_cost = fluid.layers.mean(x=cost)
optimizer = fluid.optimizer.Adam(learning_rate=0.01)
optimizer.minimize(avg_cost)
optimize_ops, params_grads = optimizer.minimize(avg_cost)
accuracy = fluid.evaluator.Accuracy(input=predict, label=label)
......@@ -38,10 +39,10 @@ train_reader = paddle.batch(
place = fluid.CPUPlace()
exe = fluid.Executor(place)
exe.optimize(pservers="127.0.0.1:6174", trainers=1)
exe.optimize(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=1)
pserver_endpoint = os.getenv("PSERVER")
if is_pserver:
if pserver_endpoint:
pserver_prog = exe.get_pserver_program(pserver_endpoint)
exe.run(fluid.default_startup_program())
exe.run(pserver_prog)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册