Commit 9508c726 authored by typhoonzero

wip: should fix variable recreate

Parent b4cd7f3d
@@ -85,7 +85,7 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
}
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
bool create_local_scope) {
bool create_local_scope, bool create_vars) {
// TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication)
// - will change to use multiple blocks for RNN op and Cond Op
@@ -94,6 +94,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
auto& device = device_contexts_[0];
Scope* local_scope = scope;
if (create_vars) {
if (create_local_scope) {
local_scope = &scope->NewScope();
for (auto& var : block.AllVars()) {
@@ -120,7 +121,8 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
<< ptr;
}
}
} // if (create_local_scope)
} // if (create_vars)
for (auto& op_desc : block.AllOps()) {
auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
......
@@ -35,7 +35,8 @@ class Executor {
* ProgramDesc
* Scope
*/
void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true);
void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true,
bool create_vars = true);
private:
std::vector<const platform::DeviceContext*> device_contexts_;
......
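For context, the new flag is what allows a caller to re-run the same block without re-creating its variables, which is the "variable recreate" problem the commit message refers to. A minimal usage sketch, assuming program, scope, and dev_ctx are already constructed elsewhere (it mirrors the RecvOp call further down in this diff):

// Sketch only, not part of the diff; program/scope/dev_ctx are assumed.
framework::Executor executor(dev_ctx);
// First run: create the block's variables in the scope as before.
executor.Run(program, &scope, 0 /*block_id*/,
             true /*create_local_scope*/, true /*create_vars*/);
// Later runs on the same scope: skip scope and variable creation so the
// tensors produced above are reused instead of being re-created.
executor.Run(program, &scope, 0 /*block_id*/,
             false /*create_local_scope*/, false /*create_vars*/);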
@@ -20,7 +20,7 @@ namespace detail {
Status SendRecvServerImpl::SendVariable(ServerContext *context,
const VariableMessage *in_var,
VariableMessage *out_var) {
VoidMessage *out_var) {
// TODO(typhoonzero): support different variable types.
std::istringstream iss(in_var->serialized());
framework::LoDTensor t;
@@ -29,6 +29,12 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
std::make_pair(in_var->varname(), std::move(t));
var_recv_queue_.Push(std::move(tensor_with_name));
return Status::OK;
}
Status SendRecvServerImpl::GetVariable(ServerContext *context,
const VoidMessage *in_var,
VariableMessage *out_var) {
// Block util the sub graph is done.
auto out_tensor_with_name = var_return_queue_.Pop();
std::ostringstream oss;
@@ -36,10 +42,9 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
platform::CPUDeviceContext());
std::string *varname = out_var->mutable_varname();
*varname = in_var->varname();
*varname = out_tensor_with_name.first;
std::string *serialized = out_var->mutable_serialized();
*serialized = oss.str();
return Status::OK;
}
......
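SendVariable and GetVariable now form a split request/response pair: the first pushes the received gradient onto var_recv_queue_ and returns a VoidMessage right away, while the second blocks on var_return_queue_ until the optimize sub-graph has produced an updated parameter. The queue type itself is not shown in this diff, so the following is only a sketch of the Push/Pop semantics the handlers are assumed to rely on:

#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>

// Sketch of a blocking queue with the Push/Pop behaviour used above; the real
// member queues of SendRecvServerImpl may differ in detail.
template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      q_.push_back(std::move(v));
    }
    cv_.notify_one();
  }
  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    // GetVariable blocks here until RecvOp pushes an updated parameter.
    cv_.wait(lock, [this] { return !q_.empty(); });
    T v = std::move(q_.front());
    q_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> q_;
};

// Handoff implemented by this commit:
//   SendVariable -> var_recv_queue_.Push({grad_name, grad_tensor})
//   RecvOp::Run  -> Pop gradients, run OptimizeProgram,
//                   var_return_queue_.Push({param_name, updated_tensor})
//   GetVariable  -> var_return_queue_.Pop(), serialize, reply to the trainer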
@@ -19,10 +19,10 @@ namespace operators {
namespace detail {
bool RPCClient::SendVariable(const framework::Scope& scope,
const std::string& inname,
const std::string& outname) {
const std::string& inname) {
ClientContext context;
VariableMessage msg, out_msg;
VariableMessage msg;
VoidMessage out_msg;
// FIXME(typhoonzero): pass device context to here.
auto ctx = platform::CPUDeviceContext();
auto* var = scope.FindVar(inname);
@@ -40,7 +40,22 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
LOG(ERROR) << "gRPC error: " << status.error_message();
return false;
}
std::istringstream iss(out_msg.serialized());
return true;
}
bool RPCClient::GetVariable(const framework::Scope& scope) {
ClientContext context;
VariableMessage msg;
VoidMessage void_msg;
auto ctx = platform::CPUDeviceContext();
Status status = stub_->GetVariable(&context, void_msg, &msg);
if (!status.ok()) {
LOG(ERROR) << "gRPC error: " << status.error_message();
return false;
}
std::istringstream iss(msg.serialized());
auto outname = msg.varname();
framework::LoDTensor ret_tensor;
framework::DeserializeFromStream(iss, &ret_tensor);
auto* outvar = scope.FindVar(outname);
......
@@ -20,7 +20,9 @@ service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VariableMessage) {}
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VoidMessage) returns (VariableMessage) {}
}
// VariableMessage is serialized paddle variable message.
......
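Under the updated service definition, a trainer-side update round is one fire-and-forget SendVariable per gradient followed by one blocking GetVariable per parameter it expects back, which is exactly the pattern SendOp adopts below. A hypothetical helper, sketched using only the RPCClient methods introduced in this commit (the function name and gradient names are placeholders, not part of the diff):

// Illustration only; assumes the RPCClient declared in the header below.
void UpdateParamsOnce(detail::RPCClient *client, const framework::Scope &scope,
                      const std::vector<std::string> &grad_names) {
  // 1. Push all gradients; SendVariable now returns a VoidMessage, so there
  //    is nothing to deserialize on the trainer side any more.
  for (const auto &name : grad_names) {
    if (!client->SendVariable(scope, name)) {
      LOG(ERROR) << "send variable error: " << name;
    }
  }
  // 2. Block until the server has run its optimize sub-graph; each reply
  //    carries the parameter's varname and its updated tensor.
  for (size_t i = 0; i < grad_names.size(); ++i) {
    if (!client->GetVariable(scope)) {
      LOG(ERROR) << "get variable error";
    }
  }
}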
@@ -55,6 +55,8 @@ class SendRecvServerImpl final : public SendRecvService::Service {
explicit SendRecvServerImpl() {}
Status SendVariable(ServerContext *context, const VariableMessage *in_var,
VoidMessage *out_var) override;
Status GetVariable(ServerContext *context, const VoidMessage *in_var,
VariableMessage *out_var) override;
const TensorWithName Get() { return this->var_recv_queue_.Pop(); }
@@ -75,8 +77,8 @@ class RPCClient {
RPCClient(std::shared_ptr<Channel> channel)
: stub_(SendRecvService::NewStub(channel)) {}
bool SendVariable(const framework::Scope &scope, const std::string &inname,
const std::string &outname);
bool SendVariable(const framework::Scope &scope, const std::string &inname);
bool GetVariable(const framework::Scope &scope);
private:
std::unique_ptr<SendRecvService::Stub> stub_;
......
@@ -66,37 +66,25 @@ class RecvOp : public framework::OperatorBase {
const platform::DeviceContext &dev_ctx) const override {
// FIXME(typhoonzero): no new scopes for every run.
framework::Scope &recv_scope = scope.NewScope();
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
size_t param_count = param_list.size();
for (size_t i = 0; i < param_count; ++i) {
// blocking get one var from client.
const detail::TensorWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
param_var_name = param_list[it - grad_list.begin()];
}
// find input by "grad_var_name"
// auto inputs = Inputs("RX");
// FIXME(typhoonzero): Find the parameter name from input grad name
// rename X -> Param
// rename RX -> Grad
LOG(ERROR) << "recved grad: " << grad_var_name
<< " param: " << param_var_name;
VLOG(10) << "recved grad: " << grad_var_name
<< " updating param: " << param_var_name;
auto *var = recv_scope.Var(grad_var_name);
auto *tensor = var->GetMutable<framework::LoDTensor>();
// Param is in parent scope, put it in current scope.
auto *param_var = recv_scope.FindVar(param_var_name);
auto param_scope = recv_scope.FindScope(param_var);
param_scope->Rename(param_var_name, "Param");
recv_scope.Rename(grad_var_name, "Grad");
// FIXME(typhoonzero): do not copy
framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
}
std::string program_str = Attr<std::string>("OptimizeProgram");
framework::ProgramDesc program_desc;
@@ -104,17 +92,20 @@ class RecvOp : public framework::OperatorBase {
framework::ProgramDescBind program(program_desc);
framework::Executor executor(dev_ctx);
// Run sub graph to get optimized tensor
try {
executor.Run(program, &recv_scope, 0, /*global_block*/
false /*create_local_scope*/);
false /*create_local_scope*/, false /*create_vars*/);
} catch (std::exception &e) {
LOG(ERROR) << "run sub program error " << e.what();
}
auto *out_var = recv_scope.FindVar("ParamOut");
for (size_t i = 0; i < param_count; ++i) {
auto *out_var = recv_scope.FindVar(param_list[i]);
detail::TensorWithName out;
out.first = param_var_name;
out.first = param_list[i];
out.second = out_var->Get<framework::LoDTensor>();
rpc_service_->Push(out);
// rename back the params
param_scope.Rename("Param", param_var_name);
recv_scope.Rename("Grad", grad_var_name);
}
}
protected:
......
@@ -48,11 +48,18 @@ class SendOp : public framework::OperatorBase {
// should block until server responds.
for (auto in : ins) {
LOG(ERROR) << "sending grad: " << in;
bool ret = client_->SendVariable(scope, in, in);
bool ret = client_->SendVariable(scope, in);
if (!ret) {
LOG(ERROR) << "send variable error";
}
}
for (auto in : ins) {
LOG(ERROR) << "updating from server...";
bool ret = client_->GetVariable(scope);
if (!ret) {
LOG(ERROR) << "GetVariable error";
}
}
}
protected:
......
@@ -138,7 +138,6 @@ class Executor(object):
inputs=opt_op.inputs,
outputs=opt_op.outputs,
attrs=opt_op.attrs)
print("optimize program: ", optimize_sub_program)
pserver_program.global_block().append_op(
type="recv",
@@ -248,7 +247,7 @@ class Executor(object):
outputs={'Out': [fetch_var]},
attrs={'col': i})
self.executor.run(program.desc, scope, 0, True)
self.executor.run(program.desc, scope, 0, True, True)
outs = [
core.get_fetch_variable(scope, fetch_var_name, i)
for i in xrange(len(fetch_list))
......
@@ -44,6 +44,7 @@ exe.optimize(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=1)
pserver_endpoint = os.getenv("PSERVER")
if pserver_endpoint:
pserver_prog = exe.get_pserver_program(pserver_endpoint, optimize_ops)
print("pserver startup: ", fluid.default_startup_program())
exe.run(fluid.default_startup_program())
while True:
exe.run(pserver_prog)
......