Commit 9508c726 authored by typhoonzero

wip: should fix variable recreate

Parent: b4cd7f3d
@@ -85,7 +85,7 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
 }
 
 void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
-                   bool create_local_scope) {
+                   bool create_local_scope, bool create_vars) {
   // TODO(tonyyang-svail):
   //   - only runs on the first device (i.e. no interdevice communication)
   //   - will change to use multiple blocks for RNN op and Cond Op
@@ -94,33 +94,35 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
   auto& device = device_contexts_[0];
   Scope* local_scope = scope;
-  if (create_local_scope) {
-    local_scope = &scope->NewScope();
-    for (auto& var : block.AllVars()) {
-      if (var->Name() == framework::kEmptyVarName) {
-        continue;
-      }
-      if (var->Persistable()) {
-        auto* ptr = scope->Var(var->Name());
-        CreateTensor(ptr, var->GetType());
-        VLOG(3) << "Create Variable " << var->Name()
-                << " global, which pointer is " << ptr;
-      } else {
-        auto* ptr = local_scope->Var(var->Name());
-        CreateTensor(ptr, var->GetType());
-        VLOG(3) << "Create Variable " << var->Name()
-                << " locally, which pointer is " << ptr;
-      }
-    }
-  } else {
-    for (auto& var : block.AllVars()) {
-      auto* ptr = local_scope->Var(var->Name());
-      CreateTensor(ptr, var->GetType());
-      VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
-              << ptr;
-    }
-  }
+  if (create_vars) {
+    if (create_local_scope) {
+      local_scope = &scope->NewScope();
+      for (auto& var : block.AllVars()) {
+        if (var->Name() == framework::kEmptyVarName) {
+          continue;
+        }
+        if (var->Persistable()) {
+          auto* ptr = scope->Var(var->Name());
+          CreateTensor(ptr, var->GetType());
+          VLOG(3) << "Create Variable " << var->Name()
+                  << " global, which pointer is " << ptr;
+        } else {
+          auto* ptr = local_scope->Var(var->Name());
+          CreateTensor(ptr, var->GetType());
+          VLOG(3) << "Create Variable " << var->Name()
+                  << " locally, which pointer is " << ptr;
+        }
+      }
+    } else {
+      for (auto& var : block.AllVars()) {
+        auto* ptr = local_scope->Var(var->Name());
+        CreateTensor(ptr, var->GetType());
+        VLOG(3) << "Create variable " << var->Name() << ", which pointer is "
+                << ptr;
+      }
+    }  // if (create_local_scope)
+  }    // if (create_vars)
 
   for (auto& op_desc : block.AllOps()) {
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
...
@@ -35,7 +35,8 @@ class Executor {
    *  ProgramDesc
    *  Scope
    */
-  void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true);
+  void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true,
+           bool create_vars = true);
 
  private:
   std::vector<const platform::DeviceContext*> device_contexts_;
...
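Note on the new flag: `create_vars` defaults to `true`, so existing callers of `Executor::Run` keep their behavior; passing `false` skips variable creation entirely and assumes the scope was populated by an earlier run, which is exactly what the reworked RecvOp below relies on. A minimal, hypothetical call-site sketch — `dev_ctx`, `program`, and `scope` are assumed to already exist and are not part of this commit:

```cpp
// Sketch only: re-run block 0 of a program against a scope whose
// variables were created on a previous run, so nothing is re-created.
framework::Executor executor(dev_ctx);
executor.Run(program, &scope, 0 /*block_id*/,
             false /*create_local_scope*/, false /*create_vars*/);
```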
@@ -20,7 +20,7 @@ namespace detail {
 
 Status SendRecvServerImpl::SendVariable(ServerContext *context,
                                         const VariableMessage *in_var,
-                                        VariableMessage *out_var) {
+                                        VoidMessage *out_var) {
   // TODO(typhoonzero): support different variable types.
   std::istringstream iss(in_var->serialized());
   framework::LoDTensor t;
@@ -29,6 +29,12 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
       std::make_pair(in_var->varname(), std::move(t));
   var_recv_queue_.Push(std::move(tensor_with_name));
+  return Status::OK;
+}
+
+Status SendRecvServerImpl::GetVariable(ServerContext *context,
+                                       const VoidMessage *in_var,
+                                       VariableMessage *out_var) {
   // Block util the sub graph is done.
   auto out_tensor_with_name = var_return_queue_.Pop();
   std::ostringstream oss;
@@ -36,10 +42,9 @@ Status SendRecvServerImpl::SendVariable(ServerContext *context,
                                  platform::CPUDeviceContext());
 
   std::string *varname = out_var->mutable_varname();
-  *varname = in_var->varname();
+  *varname = out_tensor_with_name.first;
   std::string *serialized = out_var->mutable_serialized();
   *serialized = oss.str();
-
   return Status::OK;
 }
...
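With this split, `SendVariable` only enqueues the received gradient on `var_recv_queue_` and returns immediately, while `GetVariable` blocks on `var_return_queue_` until the optimizer sub-graph has pushed the updated parameter back. A self-contained sketch of that producer/consumer handoff — a stand-in for whatever blocking-queue type the server actually uses, not code from this commit:

```cpp
#include <condition_variable>
#include <deque>
#include <mutex>
#include <utility>

// Minimal blocking queue: Push() wakes one waiter, Pop() blocks until an
// element is available. The server uses one such queue per direction
// (trainer -> optimizer and optimizer -> trainer).
template <typename T>
class BlockingQueue {
 public:
  void Push(T v) {
    {
      std::lock_guard<std::mutex> guard(mu_);
      queue_.push_back(std::move(v));
    }
    cv_.notify_one();
  }

  T Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T v = std::move(queue_.front());
    queue_.pop_front();
    return v;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<T> queue_;
};
```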
@@ -19,10 +19,10 @@ namespace operators {
 namespace detail {
 
 bool RPCClient::SendVariable(const framework::Scope& scope,
-                             const std::string& inname,
-                             const std::string& outname) {
+                             const std::string& inname) {
   ClientContext context;
-  VariableMessage msg, out_msg;
+  VariableMessage msg;
+  VoidMessage out_msg;
   // FIXME(typhoonzero): pass device context to here.
   auto ctx = platform::CPUDeviceContext();
   auto* var = scope.FindVar(inname);
@@ -40,7 +40,22 @@ bool RPCClient::SendVariable(const framework::Scope& scope,
     LOG(ERROR) << "gRPC error: " << status.error_message();
     return false;
   }
-  std::istringstream iss(out_msg.serialized());
+  return true;
+}
+
+bool RPCClient::GetVariable(const framework::Scope& scope) {
+  ClientContext context;
+  VariableMessage msg;
+  VoidMessage void_msg;
+  auto ctx = platform::CPUDeviceContext();
+  Status status = stub_->GetVariable(&context, void_msg, &msg);
+  if (!status.ok()) {
+    LOG(ERROR) << "gRPC error: " << status.error_message();
+    return false;
+  }
+  std::istringstream iss(msg.serialized());
+  auto outname = msg.varname();
   framework::LoDTensor ret_tensor;
   framework::DeserializeFromStream(iss, &ret_tensor);
   auto* outvar = scope.FindVar(outname);
...
@@ -20,7 +20,9 @@ service SendRecvService {
   // For parameter server round-robin like hashing, do not split tensors.
   // Send and recv only one tensor
   // TODO(typhoonzero): add streaming API
-  rpc SendVariable(VariableMessage) returns (VariableMessage) {}
+  rpc SendVariable(VariableMessage) returns (VoidMessage) {}
+  // Argument VariableMessage for GetVariable should only contain varname.
+  rpc GetVariable(VoidMessage) returns (VariableMessage) {}
 }
 
 // VariableMessage is serialized paddle variable message.
...
@@ -55,7 +55,9 @@ class SendRecvServerImpl final : public SendRecvService::Service {
   explicit SendRecvServerImpl() {}
 
   Status SendVariable(ServerContext *context, const VariableMessage *in_var,
-                      VariableMessage *out_var) override;
+                      VoidMessage *out_var) override;
+  Status GetVariable(ServerContext *context, const VoidMessage *in_var,
+                     VariableMessage *out_var) override;
 
   const TensorWithName Get() { return this->var_recv_queue_.Pop(); }
@@ -75,8 +77,8 @@ class RPCClient {
   RPCClient(std::shared_ptr<Channel> channel)
       : stub_(SendRecvService::NewStub(channel)) {}
 
-  bool SendVariable(const framework::Scope &scope, const std::string &inname,
-                    const std::string &outname);
+  bool SendVariable(const framework::Scope &scope, const std::string &inname);
+  bool GetVariable(const framework::Scope &scope);
 
  private:
   std::unique_ptr<SendRecvService::Stub> stub_;
...
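With these declarations the client-side protocol becomes send-then-fetch: the caller no longer passes an output name, because the updated parameter's name comes back inside the `VariableMessage` that `GetVariable` returns. A hedged usage sketch — the endpoint and gradient name are placeholders, `scope` is assumed to be an existing `framework::Scope`, and the channel setup is ordinary gRPC rather than anything introduced by this commit:

```cpp
// Hypothetical trainer-side usage of the reworked RPCClient API.
auto channel = grpc::CreateChannel("127.0.0.1:6174",
                                   grpc::InsecureChannelCredentials());
detail::RPCClient client(channel);

// Push one gradient to the parameter server...
if (!client.SendVariable(scope, "w@GRAD" /* placeholder variable name */)) {
  LOG(ERROR) << "send variable error";
}
// ...then block until an updated parameter is returned; the variable to
// overwrite is named by the reply itself, so no output name is passed.
if (!client.GetVariable(scope)) {
  LOG(ERROR) << "GetVariable error";
}
```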
@@ -66,37 +66,25 @@ class RecvOp : public framework::OperatorBase {
            const platform::DeviceContext &dev_ctx) const override {
     // FIXME(typhoonzero): no new scopes for every run.
     framework::Scope &recv_scope = scope.NewScope();
-    // blocking get one var from client.
-    const detail::TensorWithName &v = rpc_service_->Get();
-    auto grad_var_name = v.first;
     auto param_list = Attr<std::vector<std::string>>("ParamList");
     auto grad_list = Attr<std::vector<std::string>>("GradList");
-    auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
-    std::string param_var_name;
-    if (it != grad_list.end()) {
-      param_var_name = param_list[it - grad_list.begin()];
-    }
-    // find input by "grad_var_name"
-    // auto inputs = Inputs("RX");
-
-    // FIXME(typhoonzero): Find the parameter name from input grad name
-    // rename X -> Param
-    // rename RX -> Grad
-    LOG(ERROR) << "recved grad: " << grad_var_name
-               << " param: " << param_var_name;
-    auto *var = recv_scope.Var(grad_var_name);
-    auto *tensor = var->GetMutable<framework::LoDTensor>();
-
-    // Param is in parent scope, put it in current scope.
-    auto *param_var = recv_scope.FindVar(param_var_name);
-    auto param_scope = recv_scope.FindScope(param_var);
-    param_scope->Rename(param_var_name, "Param");
-    recv_scope.Rename(grad_var_name, "Grad");
-
-    // FIXME(typhoonzero): do not copy
-    framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
+    size_t param_count = param_list.size();
+    for (size_t i = 0; i < param_count; ++i) {
+      // blocking get one var from client.
+      const detail::TensorWithName &v = rpc_service_->Get();
+      auto grad_var_name = v.first;
+      auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
+      std::string param_var_name;
+      if (it != grad_list.end()) {
+        param_var_name = param_list[it - grad_list.begin()];
+      }
+      VLOG(10) << "recved grad: " << grad_var_name
+               << " updating param: " << param_var_name;
+      auto *var = recv_scope.Var(grad_var_name);
+      auto *tensor = var->GetMutable<framework::LoDTensor>();
+      // FIXME(typhoonzero): do not copy
+      framework::CopyFrom(v.second, dev_ctx.GetPlace(), dev_ctx, tensor);
+    }
 
     std::string program_str = Attr<std::string>("OptimizeProgram");
     framework::ProgramDesc program_desc;
@@ -104,17 +92,20 @@ class RecvOp : public framework::OperatorBase {
     framework::ProgramDescBind program(program_desc);
     framework::Executor executor(dev_ctx);
     // Run sub graph to get optimized tensor
-    executor.Run(program, &recv_scope, 0, /*global_block*/
-                 false /*create_local_scope*/);
-
-    auto *out_var = recv_scope.FindVar("ParamOut");
-    detail::TensorWithName out;
-    out.first = param_var_name;
-    out.second = out_var->Get<framework::LoDTensor>();
-    rpc_service_->Push(out);
-    // rename back the params
-    param_scope.Rename("Param", param_var_name);
-    recv_scope.Rename("Grad", grad_var_name);
+    try {
+      executor.Run(program, &recv_scope, 0, /*global_block*/
+                   false /*create_local_scope*/, false /*create_vars*/);
+    } catch (std::exception &e) {
+      LOG(ERROR) << "run sub program error " << e.what();
+    }
+
+    for (size_t i = 0; i < param_count; ++i) {
+      auto *out_var = recv_scope.FindVar(param_list[i]);
+      detail::TensorWithName out;
+      out.first = param_list[i];
+      out.second = out_var->Get<framework::LoDTensor>();
+      rpc_service_->Push(out);
+    }
   }
 
  protected:
...
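The grad-to-param lookup above relies on `ParamList` and `GradList` being parallel attributes: the i-th gradient name corresponds to the i-th parameter name. Restated as a small standalone helper (hypothetical, for illustration only):

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Returns the parameter name paired with `grad_var_name`, or an empty
// string if the gradient is not listed. Mirrors the std::find-based
// lookup inside RecvOp::Run.
std::string ParamForGrad(const std::vector<std::string>& param_list,
                         const std::vector<std::string>& grad_list,
                         const std::string& grad_var_name) {
  auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
  if (it == grad_list.end()) return std::string();
  return param_list[it - grad_list.begin()];
}
```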
@@ -48,11 +48,18 @@ class SendOp : public framework::OperatorBase {
     // should block until server responds.
     for (auto in : ins) {
       LOG(ERROR) << "sending grad: " << in;
-      bool ret = client_->SendVariable(scope, in, in);
+      bool ret = client_->SendVariable(scope, in);
       if (!ret) {
         LOG(ERROR) << "send variable error";
       }
     }
+    for (auto in : ins) {
+      LOG(ERROR) << "updating from server...";
+      bool ret = client_->GetVariable(scope);
+      if (!ret) {
+        LOG(ERROR) << "GetVariable error";
+      }
+    }
   }
 
  protected:
...
@@ -138,7 +138,6 @@ class Executor(object):
                 inputs=opt_op.inputs,
                 outputs=opt_op.outputs,
                 attrs=opt_op.attrs)
-        print("optimize program: ", optimize_sub_program)
 
         pserver_program.global_block().append_op(
             type="recv",
@@ -248,7 +247,7 @@ class Executor(object):
                 outputs={'Out': [fetch_var]},
                 attrs={'col': i})
 
-        self.executor.run(program.desc, scope, 0, True)
+        self.executor.run(program.desc, scope, 0, True, True)
         outs = [
             core.get_fetch_variable(scope, fetch_var_name, i)
             for i in xrange(len(fetch_list))
...
@@ -44,6 +44,7 @@ exe.optimize(optimize_ops, params_grads, pservers="127.0.0.1:6174", trainers=1)
 pserver_endpoint = os.getenv("PSERVER")
 if pserver_endpoint:
     pserver_prog = exe.get_pserver_program(pserver_endpoint, optimize_ops)
+    print("pserver startup: ", fluid.default_startup_program())
     exe.run(fluid.default_startup_program())
     while True:
         exe.run(pserver_prog)
...