Commit 489b9695 authored by typhoonzero

wip for testing

Parent 308491a9
@@ -21,16 +21,20 @@ namespace detail {

 Status SendRecvServerImpl::SendVariable(ServerContext *context,
                                         const VariableMessage *in_var,
                                         VariableMessage *out_var) {
-  framework::LoDTensor t;
-  // TODO(typhoonzero): deserialize in_tensor and run pserver network.
+  // TODO(typhoonzero): support different variable types.
   std::istringstream iss(in_var->serialized());
+  framework::LoDTensor t;
   framework::DeserializeFromStream(iss, &t);
-  lodtensor_queue_.Push(std::move(t));
+  TensorWithName tensor_with_name =
+      std::make_pair(in_var->varname(), std::move(t));
+  var_recv_queue_.Push(std::move(tensor_with_name));
   // Block until the sub graph is done.
-  t = lodtensor_return_queue_.Pop();
+  auto out_tensor_with_name = var_return_queue_.Pop();
   std::ostringstream oss;
   // FIXME(typhoonzero): get context from op.
-  framework::SerializeToStream(oss, t, platform::CPUDeviceContext());
+  framework::SerializeToStream(oss, out_tensor_with_name.second,
+                               platform::CPUDeviceContext());
   std::string *varname = out_var->mutable_varname();
   *varname = in_var->varname();
   std::string *serialized = out_var->mutable_serialized();
......
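The handler above parks the received (varname, tensor) pair on var_recv_queue_ and then blocks on var_return_queue_.Pop() until the recv op pushes the optimized result back. SimpleBlockQueue itself is not part of this diff; a minimal sketch of a queue with the same Push/Pop interface, assuming a plain mutex-plus-condition-variable design:

#include <condition_variable>
#include <deque>
#include <mutex>

// Illustrative only: a blocking FIFO matching the Push/Pop calls above.
template <typename T>
class SimpleBlockQueue {
 public:
  void Push(T item) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      queue_.push_back(std::move(item));
    }
    condition_.notify_one();  // wake one blocked Pop()
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mutex_);
    // Block until a producer has pushed an element.
    condition_.wait(lock, [this] { return !queue_.empty(); });
    T front = std::move(queue_.front());
    queue_.pop_front();
    return front;
  }

 private:
  std::mutex mutex_;
  std::condition_variable condition_;
  std::deque<T> queue_;
};

Push takes its argument by value, so both call sites above (the move from the RPC thread and the const reference from the op side) work unchanged.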
@@ -19,6 +19,7 @@ package sendrecv;

 service SendRecvService {
   // For parameter server round-robin like hashing, do not split tensors.
   // Send and recv only one tensor
+  // TODO(typhoonzero): add streaming API
   rpc SendVariable(VariableMessage) returns (VariableMessage) {}
 }
......
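Since the service exposes a single unary RPC, one client round-trip is one blocking call. A hedged sketch of the client side, assuming the standard gRPC C++ codegen for this service; the generated header name is an assumption, while the varname and serialized fields are the ones the server code above reads:

#include <memory>
#include <string>

#include <grpc++/grpc++.h>
#include "send_recv.grpc.pb.h"  // assumed name of the generated header

// Send one serialized tensor to a pserver; receive the updated tensor back.
bool SendOneVariable(const std::string &endpoint, const std::string &name,
                     const std::string &bytes, std::string *out_bytes) {
  auto channel =
      grpc::CreateChannel(endpoint, grpc::InsecureChannelCredentials());
  auto stub = sendrecv::SendRecvService::NewStub(channel);

  sendrecv::VariableMessage req, resp;
  req.set_varname(name);      // which variable this payload belongs to
  req.set_serialized(bytes);  // bytes produced by SerializeToStream

  grpc::ClientContext ctx;
  grpc::Status status = stub->SendVariable(&ctx, req, &resp);
  if (!status.ok()) return false;
  *out_bytes = resp.serialized();  // the optimized parameter, serialized
  return true;
}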
@@ -48,6 +48,8 @@ namespace paddle {
 namespace operators {
 namespace detail {

+typedef std::pair<std::string, framework::LoDTensor> TensorWithName;
+
 class SendRecvServerImpl final : public SendRecvService::Service {
  public:
   explicit SendRecvServerImpl() {}
@@ -55,17 +57,15 @@ class SendRecvServerImpl final : public SendRecvService::Service {
   Status SendVariable(ServerContext *context, const VariableMessage *in_var,
                       VariableMessage *out_var) override;

-  const framework::LoDTensor Get() { return this->lodtensor_queue_.Pop(); }
+  const TensorWithName Get() { return this->var_recv_queue_.Pop(); }

-  void Push(const framework::LoDTensor &tensor) {
-    this->lodtensor_return_queue_.Push(tensor);
-  }
+  void Push(const TensorWithName &var) { this->var_return_queue_.Push(var); }

  private:
-  SimpleBlockQueue<framework::LoDTensor> lodtensor_queue_;
-  SimpleBlockQueue<framework::LoDTensor> lodtensor_return_queue_;
-  SimpleBlockQueue<framework::SelectedRows> selected_rows_queue_;
-  SimpleBlockQueue<framework::SelectedRows> selected_rows_return_queue_;
+  // received variable from RPC, operators fetch variable from this queue.
+  SimpleBlockQueue<TensorWithName> var_recv_queue_;
+  // calculated variable should push to this queue.
+  SimpleBlockQueue<TensorWithName> var_return_queue_;
 };

 // RPCClient is a class to send tensors to pserver sub-network
......
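Get() and Push() together with the two queues form a rendezvous between the gRPC handler thread and the operator thread: every SendVariable call blocks until an operator has consumed the gradient and returned a result. A self-contained illustration of that hand-off, reusing the SimpleBlockQueue sketch above; FakeTensor and the variable name are hypothetical stand-ins, not framework types:

#include <string>
#include <thread>
#include <utility>

struct FakeTensor { float value; };
typedef std::pair<std::string, FakeTensor> TensorWithName;

int main() {
  SimpleBlockQueue<TensorWithName> recv_q;    // gradients flow in
  SimpleBlockQueue<TensorWithName> return_q;  // updated params flow out

  std::thread op_thread([&] {
    TensorWithName grad = recv_q.Pop();  // blocks until a gradient arrives
    grad.second.value -= 0.1f;           // stand-in for the optimize step
    return_q.Push(std::move(grad));      // hand the result back
  });

  // The RPC-handler side: park the gradient, then block for the update.
  recv_q.Push({"fc_0.w@GRAD", FakeTensor{1.0f}});  // hypothetical name
  TensorWithName updated = return_q.Pop();
  op_thread.join();
  return 0;
}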
@@ -14,6 +14,7 @@
 #include <stdint.h>
 #include <sys/stat.h>
 #include <iostream>
+#include <ostream>
 #include <thread>
@@ -63,14 +64,32 @@ class RecvOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
-    // blocking get one var from client.
-    const framework::LoDTensor &t = rpc_service_->Get();
     framework::Scope &recv_scope = scope.NewScope();
+    // blocking get one var from client.
+    const detail::TensorWithName &v = rpc_service_->Get();
+    auto grad_var_name = v.first;
+    // framework::Scope &recv_scope = scope.NewScope();
+    auto param_list = Attr<std::vector<std::string>>("ParamList");
+    auto grad_list = Attr<std::vector<std::string>>("GradList");
+    auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+    std::string param_var_name;
+    if (it != grad_list.end()) {
+      param_var_name = param_list[it - grad_list.begin()];
+    }
     // set graph input var
-    auto *var = recv_scope.Var(Input("RX"));
+    auto input_grad = Input("RX");
+    // FIXME(typhoonzero): Find the parameter name from input grad name
+    // rename X -> Param
+    // rename RX -> Grad
+    auto *var = recv_scope.FindVar(input_grad);
     auto *tensor = var->GetMutable<framework::LoDTensor>();
+    recv_scope.Rename(param_var_name, "Param");
+    recv_scope.Rename("RX", "Grad");
     // FIXME(typhoonzero): do not copy
-    framework::CopyFrom(t, dev_ctx.GetPlace(), dev_ctx, tensor);
+    framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
     std::string program_str = Attr<std::string>("OptimizeProgram");
     framework::ProgramDesc program_desc;
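The lookup just above is purely positional: ParamList and GradList are parallel lists, so a gradient's index in GradList is also its parameter's index in ParamList. The same logic as a standalone helper, for clarity (the op inlines it):

#include <algorithm>
#include <string>
#include <vector>

// Positional lookup: grad_list[i] names the gradient of param_list[i].
std::string LookupParamForGrad(const std::vector<std::string> &param_list,
                               const std::vector<std::string> &grad_list,
                               const std::string &grad_var_name) {
  auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
  if (it == grad_list.end()) return "";  // unknown gradient name
  return param_list[it - grad_list.begin()];
}

With the parameter name resolved, the op renames it to "Param" and the incoming gradient to "Grad", so the deserialized optimize sub-program, which refers only to those generic names, can run unmodified; the next hunk renames them back after pushing the result.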
@@ -81,9 +100,14 @@ class RecvOp : public framework::OperatorBase {
     executor.Run(program, &recv_scope, 0, /*global_block*/
                  false /*create_local_scope*/);

-    auto *out_var = recv_scope.FindVar("Out");
-    // push back
-    rpc_service_->Push(out_var->Get<framework::LoDTensor>());
+    auto *out_var = recv_scope.FindVar("Param");
+    detail::TensorWithName out;
+    out.first = param_var_name;
+    out.second = out_var->Get<framework::LoDTensor>();
+    rpc_service_->Push(out);
+    // rename back the params
+    recv_scope.Rename("Param", param_var_name);
+    recv_scope.Rename("Grad", "RX");
   }

  protected:
@@ -93,13 +117,14 @@ class RecvOp : public framework::OperatorBase {
   // grpc send/recv service implement to register.
   std::shared_ptr<detail::SendRecvServerImpl> rpc_service_;
   std::shared_ptr<std::thread> server_thread_;
+  framework::Scope const *recv_scope_{nullptr};
 };

 class RecvOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   RecvOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("RX", "(Tensor) Input tensor to be saved");
+    AddInput("RX", "(Tensor) Input tensor to be optimized").AsDuplicable();
     AddComment(R"DOC(
 Recv operator
@@ -112,6 +137,12 @@ This operator will recv tensor from send_op
         .AddCustomChecker([](const std::string &ip) { return !ip.empty(); });
     AddAttr<std::string>("OptimizeProgram", "type string",
                          "Serialized ProgramDesc string for recv to run.");
+    AddAttr<std::vector<std::string>>(
+        "ParamList", "type list of string",
+        "grad->param name mapping to find which param to optimize.");
+    AddAttr<std::vector<std::string>>(
+        "GradList", "type list of string",
+        "grad->param name mapping to find which param to optimize.");
   }
 };
......
 import numpy as np
 from . import core
-from framework import Program, default_main_program
+from framework import Program, default_main_program, Parameter, Variable
 import distribute_planner

 __all__ = ['Executor', 'g_scope']
@@ -91,7 +91,7 @@ class Executor(object):
             # FIXME(typhoonzero): sends to different servers can run in parallel.
             send_op = program.global_block().append_op(
                 type="send",
-                inputs={"X": self.param_grad_map[ep]["params"]
+                inputs={"X": self.param_grad_map[ep]["grads"]
                         },  # inputs is a list of tensors to be sent
                 outputs={},
                 attrs={"endpoint": ep})
@@ -102,9 +102,20 @@ class Executor(object):
     def get_pserver_program(self, endpoint):
         pserver_program = Program()
-        for param in self.param_grad_map[endpoint]["params"]:
-            pserver_program.global_block().create_parameter(**param.__dict__)
+        for v in self.param_grad_map[endpoint]["params"]:
+            assert isinstance(v, Parameter)
+            new_p = Parameter(
+                block=pserver_program.global_block(),
+                shape=v.shape,
+                dtype=v.dtype,
+                type=v.type,
+                lod_level=v.lod_level,
+                stop_gradient=v.stop_gradient,
+                trainable=v.trainable,
+                optimize_attr=v.optimize_attr,
+                regularizer=v.regularizer,
+                name=v.name)
+            pserver_program.global_block().vars[new_p.name] = new_p

         pserver_program.global_block().append_op(
             type="recv",
@@ -112,12 +123,12 @@ class Executor(object):
                     self.param_grad_map[endpoint]["grads"]},  # grads to recv
             outputs={},
             attrs={
-                "OptimizeProgram": self.optimize_sub_program.to_string(),
-                "endpoint": endpoint
+                "OptimizeProgram": self.optimize_sub_program.to_string(True),
+                "endpoint": endpoint,
+                "ParamList": self.param_grad_map[endpoint]["params"],
+                "GradList": self.param_grad_map[endpoint]["grads"]
             })
+        return pserver_program

     def get_trainer_program(self):
         return default_main_program()

     def aslodtensor(self, data):
         def accumulate(data):
......
@@ -45,7 +45,8 @@ pserver_endpoint = os.getenv("PSERVER")
 if pserver_endpoint:
     pserver_prog = exe.get_pserver_program(pserver_endpoint)
     exe.run(fluid.default_startup_program())
-    exe.run(pserver_prog)
+    while True:
+        exe.run(pserver_prog)
 else:
     feeder = fluid.DataFeeder(feed_list=[images, label], place=place)
     exe.run(fluid.default_startup_program())
......