未验证 提交 8b50ad80 编写于 作者: T tangwei12 提交者: GitHub

checkpoint at distributed training (#14854)

checkpoint for distributed training.
上级 07dc5a15
......@@ -74,7 +74,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
SendProcessor* s = new SendProcessor(ch);
const std::string method = "SendRPC";
const std::string method = kSendRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
......@@ -107,7 +107,7 @@ VarHandlePtr GRPCClient::AsyncSendVar(const std::string& ep,
void ProcGetResponse(const VarHandle& var_h,
const ::grpc::ByteBuffer& ret_msg) {
VLOG(100) << "ProcGetResponse";
VLOG(4) << "ProcGetResponse";
framework::Variable* outvar = nullptr;
// get response's trainer_id is not used
int trainer_id;
......@@ -127,39 +127,54 @@ VarHandlePtr GRPCClient::AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name,
return _AsyncGetVar(ep, ctx, scope, kGetRPC, var_name, out_varname,
"/sendrecv.SendRecvService/GetVariable", time_out);
}
VarHandlePtr GRPCClient::AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname, int64_t time_out) {
std::string var_name_no_barrier =
string::Sprintf("%s%s", var_name, WITHOUT_BARRIER_MESSAGE);
return _AsyncGetVar(
ep, ctx, scope, kGetNoBarrierRPC, var_name_no_barrier, out_varname,
"/sendrecv.SendRecvService/GetVariableNoBarrier", time_out);
}
VarHandlePtr GRPCClient::AsyncGetMonomerVariable(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out) {
return _AsyncGetVar(ep, ctx, scope, var_name,
return _AsyncGetVar(ep, ctx, scope, kGetMonomerRPC, var_name, var_name,
"/sendrecv.SendRecvService/GetMonomerVariable", time_out);
}
VarHandlePtr GRPCClient::_AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& rpc_path,
int64_t time_out) {
VarHandlePtr GRPCClient::_AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out) {
const platform::DeviceContext* p_ctx = &ctx;
const std::string ep_val = ep;
const std::string var_name_val = var_name;
const std::string out_varname_val = out_varname;
const framework::Scope* p_scope = &scope;
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = "GetRPC";
VarHandlePtr h(new VarHandle(ep, method, var_name_val, p_ctx, p_scope));
VarHandlePtr h(new VarHandle(ep, method, out_varname_val, p_ctx, p_scope));
s->Prepare(h, time_out);
framework::AsyncIO([var_name_val, s, method, p_ctx, h, rpc_path, this] {
framework::AsyncIO(
[var_name_val, out_varname_val, s, method, p_ctx, h, rpc_path, this] {
// prepare input
sendrecv::VariableMessage req;
req.set_varname(var_name_val);
req.set_out_varname(out_varname_val);
req.set_trainer_id(trainer_id_);
::grpc::ByteBuffer buf;
RequestToByteBuffer<sendrecv::VariableMessage>(req, &buf);
......@@ -202,7 +217,7 @@ VarHandlePtr GRPCClient::AsyncPrefetchVar(const std::string& ep,
const auto ch = GetChannel(ep_val);
GetProcessor* s = new GetProcessor(ch);
const std::string method = "PrefetchRPC";
const std::string method = kPrefetchRPC;
VarHandlePtr h(new VarHandle(ep, method, out_var_name_val, p_ctx, p_scope));
s->Prepare(h, time_out);
......@@ -242,7 +257,7 @@ VarHandlePtr GRPCClient::AsyncSendBatchBarrier(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "BatchBarrierRPC";
const std::string method = kBatchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, BATCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -267,7 +282,7 @@ VarHandlePtr GRPCClient::AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
FetchBarrierProcessor* s = new FetchBarrierProcessor(ch);
const std::string method = "FetchBarrierRPC";
const std::string method = kFetchBarrierRPC;
VarHandlePtr h(
new VarHandle(ep, method, FETCH_BARRIER_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -293,7 +308,7 @@ VarHandlePtr GRPCClient::AsyncGetMonomerBarrier(const std::string& ep,
int64_t time_out) {
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendMonomerFetchBarrierRPC";
const std::string method = kSendMonomerFetchBarrierRPC;
VarHandlePtr h(new VarHandle(ep, method, var_name, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -320,7 +335,7 @@ VarHandlePtr GRPCClient::AsyncSendComplete(const std::string& ep,
const auto ch = GetChannel(ep);
BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
const std::string method = "SendCompleteRPC";
const std::string method = kSendCompleteRPC;
VarHandlePtr h(new VarHandle(ep, method, COMPLETE_MESSAGE, nullptr, nullptr));
s->Prepare(h, time_out);
......@@ -347,7 +362,7 @@ VarHandlePtr GRPCClient::AsyncCheckpointNotify(const std::string& ep,
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
const std::string method = "CheckPointNotifyRPC";
const std::string method = kCheckPointNotifyRPC;
VarHandlePtr h(
new VarHandle(ep, method, CHECKPOINT_SAVE_MESSAGE, nullptr, nullptr));
......
......@@ -186,6 +186,13 @@ class GRPCClient : public RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) override;
VarHandlePtr AsyncGetMonomerVariable(
......@@ -228,11 +235,11 @@ class GRPCClient : public RPCClient {
void Proceed();
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
VarHandlePtr _AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name, const std::string& rpc,
int64_t time_out);
VarHandlePtr _AsyncGetVar(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& method,
const std::string& var_name, const std::string& out_varname,
const std::string& rpc_path, int64_t time_out = FLAGS_rpc_deadline);
private:
grpc::CompletionQueue cq_;
......
......@@ -136,17 +136,65 @@ class RequestGet final : public RequestBase {
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGet " << varname;
VLOG(4) << "RequestGet " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
auto invar = scope->FindVar(varname);
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id);
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(varname, outvar, *request_handler_->dev_ctx(),
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
}
protected:
sendrecv::VariableMessage request_;
::grpc::ByteBuffer reply_;
ServerAsyncResponseWriter<::grpc::ByteBuffer> responder_;
};
class RequestGetNoBarrier final : public RequestBase {
public:
explicit RequestGetNoBarrier(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
auto method_id =
static_cast<int>(distributed::GrpcMethod::kGetVariableNoBarrier);
service_->RequestAsyncUnary(
method_id, &ctx_, &request_, &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestGetNoBarrier() {}
std::string GetReqName() override { return request_.varname(); }
void Process() override {
// proc request.
std::string varname = request_.varname();
std::string out_varname = request_.out_varname();
int trainer_id = request_.trainer_id();
VLOG(4) << "RequestGetNoBarrier " << out_varname << " from " << varname;
auto scope = request_handler_->scope();
framework::Variable* invar = nullptr;
framework::Variable* outvar = nullptr;
request_handler_->Handle(varname, scope, invar, &outvar, trainer_id,
out_varname);
if (outvar) {
SerializeToByteBuffer(out_varname, outvar, *request_handler_->dev_ctx(),
&reply_);
}
Finish(reply_, &responder_);
......@@ -460,6 +508,9 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestSend(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGet) {
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetNoBarrier) {
b = new RequestGetNoBarrier(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestGetMonomerVariable) {
b = new RequestGetMonomerVariable(&service_, cq.get(), handler, req_id,
this);
......
......@@ -81,6 +81,7 @@ enum class GrpcMethod {
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
kGetVariableNoBarrier,
kGetMonomerVariable,
kGetMonomerBarrier,
};
......@@ -94,6 +95,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/SendVariable";
case GrpcMethod::kGetVariable:
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kGetVariableNoBarrier:
return "/sendrecv.SendRecvService/GetVariableNoBarrier";
case GrpcMethod::kGetMonomerVariable:
return "/sendrecv.SendRecvService/GetMonomerVariable";
case GrpcMethod::kGetMonomerBarrier:
......
......@@ -42,11 +42,24 @@ constexpr char kRequestGetMonomerBarrier[] = "RequestGetMonomerBarrier";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
constexpr char kRequestPassBarrier[] = "RequestPassBarrier";
constexpr char kRequestGetNoBarrier[] = "GetVariableNoBarrier";
constexpr char kSendRPC[] = "SendRPC";
constexpr char kGetRPC[] = "GetRPC";
constexpr char kGetNoBarrierRPC[] = "GetNoBarrierRPC";
constexpr char kGetMonomerRPC[] = "GetMonomerRPC";
constexpr char kPrefetchRPC[] = "PrefetchRPC";
constexpr char kBatchBarrierRPC[] = "BatchBarrierRPC";
constexpr char kFetchBarrierRPC[] = "FetchBarrierRPC";
constexpr char kSendMonomerFetchBarrierRPC[] = "SendMonomerFetchBarrierRPC";
constexpr char kSendCompleteRPC[] = "SendCompleteRPC";
constexpr char kCheckPointNotifyRPC[] = "CheckPointNotifyRPC";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define WITHOUT_BARRIER_MESSAGE "@WITHOUT_BARRIER@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
......
......@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/piece.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
......@@ -81,7 +82,8 @@ bool RequestGetHandler::Handle(const std::string& varname,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetHandler:" << varname;
VLOG(4) << "RequestGetHandler:" << varname
<< " out_var_name: " << out_var_name;
if (sync_mode_) {
if (varname == FETCH_BARRIER_MESSAGE) {
......@@ -112,6 +114,32 @@ bool RequestGetHandler::Handle(const std::string& varname,
return true;
}
bool RequestGetNoBarrierHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const int trainer_id,
const std::string& out_var_name,
const std::string& table_name) {
VLOG(4) << "RequestGetNoBarrierHandler:" << varname
<< " out_var_name: " << out_var_name;
// get var from pserver immediately without barriers
string::Piece without_barrier_piece(WITHOUT_BARRIER_MESSAGE);
string::Piece var_name_piece = string::Piece(varname);
if (string::Contains(var_name_piece, without_barrier_piece)) {
var_name_piece = string::TrimSuffix(var_name_piece, without_barrier_piece);
VLOG(4) << "Get var " << var_name_piece << " with "
<< WITHOUT_BARRIER_MESSAGE;
*outvar = scope_->FindVar(var_name_piece.ToString());
return true;
} else {
PADDLE_THROW("GetNoBarrier must contain %s", WITHOUT_BARRIER_MESSAGE);
}
return true;
}
bool RequestPrefetchHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
......
......@@ -67,6 +67,16 @@ class RequestGetHandler final : public RequestHandler {
bool enable_dc_asgd_;
};
class RequestGetNoBarrierHandler final : public RequestHandler {
public:
RequestGetNoBarrierHandler() : RequestHandler(false) {}
virtual ~RequestGetNoBarrierHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const int trainer_id, const std::string& out_var_name = "",
const std::string& table_name = "") override;
};
static inline void BuildVar(const std::string& param_name,
std::initializer_list<const char*> arguments,
paddle::framework::proto::OpDesc::Var* var) {
......
......@@ -43,6 +43,13 @@ class RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetVarNoBarrier(
const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
const std::string& out_varname,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual VarHandlePtr AsyncGetMonomerVariable(
......
......@@ -17,8 +17,14 @@ package sendrecv;
option cc_generic_services = @cc_generic_services@;
service SendRecvService {
// For parameter server round-robin like hashing, do not split tensors.
// Send and recv only one tensor
// TODO(typhoonzero): add streaming API
rpc SendVariable(VariableMessage) returns (VoidMessage) {}
// Argument VariableMessage for GetVariable should only contain varname.
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
rpc GetVariableNoBarrier(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
......@@ -27,12 +33,17 @@ service SendRecvService {
rpc GetMonomerBarrier(VariableMessage) returns (VoidMessage) {}
}
// It can be: LoDTensorSelectedRows or NCCL_ID
enum VarType {
LOD_TENSOR = 0;
SELECTED_ROWS = 1;
NCCL_ID = 2;
}
// VariableMessage is serialized paddle variable message.
// NOTICE(gongwb):don't modify this proto if you are not
// not familar with how we serialize in sendrecvop_utils.h
// and deserilize it in variable_response.h.
message VariableMessage {
enum Type {
// Pod Types
......@@ -49,14 +60,21 @@ message VariableMessage {
string varname = 1;
// TODO(Yancey1989): reference framework::proto::VarDesc::VarType
VarType type = 2;
// bool persistable is not needed for sending.
// tensor info:
Type data_type = 3;
repeated int64 dims = 4;
// lod details:
int64 lod_level = 5;
repeated LodData lod = 6;
// selected_rows height, aka. original dim0
int64 slr_height = 7;
// tensor data
bytes serialized = 8;
// selected_rows data
bytes rows = 9;
// Look up table block execution output variable name.
string out_varname = 10;
// If 1, the ps server will start profiling, the ps
// server stops profiling and generates a profile to /tmp/profile_ps_*
......
......@@ -347,6 +347,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id));
request_get_no_barrier_handler_.reset(
new distributed::RequestGetNoBarrierHandler());
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get(),
......@@ -359,6 +361,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
FLAGS_rpc_prefetch_thread_num);
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestGetNoBarrier,
request_get_no_barrier_handler_.get());
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
......@@ -413,6 +417,7 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
f(request_get_no_barrier_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
......
......@@ -55,7 +55,6 @@ class ListenAndServOp : public framework::OperatorBase {
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs);
virtual ~ListenAndServOp();
void RunSyncLoop(framework::Executor* executor,
......@@ -89,6 +88,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RPCServer> rpc_service_;
mutable std::shared_ptr<distributed::RequestHandler> request_send_handler_;
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_get_no_barrier_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
......
......@@ -27,34 +27,54 @@ namespace operators {
class RecvOp : public framework::OperatorBase {
public:
RecvOp(const std::string& type, const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
RecvOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
auto outs = Outputs("Out");
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames");
int sync_mode = Attr<int>("sync_mode");
auto outs = Outputs("Out");
bool with_barrier = Attr<bool>("with_barrier");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &ctx = *pool.Get(place);
distributed::RPCClient* rpc_client =
distributed::RPCClient *rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>(
Attr<int>("trainer_id"));
if (with_barrier) {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rets.push_back(rpc_client->AsyncGetVar(epmap[i], ctx, scope, outs[i]));
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVar";
rets.push_back(
rpc_client->AsyncGetVar(epmap[i], ctx, scope, varname, outs[i]));
}
if (sync_mode) {
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
} else {
std::vector<distributed::VarHandlePtr> rets;
for (size_t i = 0; i < outs.size(); i++) {
std::string varname = varnames.size() == 0 ? outs[i] : varnames[i];
VLOG(4) << "recv " << outs[i] << " from " << epmap[i] << " with "
<< varname << " and with AsyncGetVarNoBarrier";
rets.push_back(rpc_client->AsyncGetVarNoBarrier(epmap[i], ctx, scope,
varname, outs[i]));
}
for (size_t i = 0; i < rets.size(); i++) {
PADDLE_ENFORCE(rets[i]->Wait(), "internal error in RPCClient");
}
}
}
};
......@@ -79,12 +99,23 @@ This operator can get variables from server side.
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
AddAttr<bool>("with_barrier",
"(bool, default True) if with_barrier=False, will use "
"AsyncGetVarNoBarrier get variable from pserver immediately")
.SetDefault(true);
AddAttr<std::vector<std::string>>(
"varnames",
"(string vector, default {}) "
"sometimes we need to put received var in another name "
"for example: we need var named 'moment_1@127.0.0.1:1001', "
"and it real name on parameter server is 'moment_1'. ")
.SetDefault({});
}
};
class RecvOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators
......
......@@ -1696,12 +1696,20 @@ class Program(object):
self._current_role = core.op_proto_and_checker_maker.OpRole.Forward
self._op_role_var = []
# for distribute
# for distribute training
# _is_distributed = True if under distributed training
self._is_distributed = False
# _is_chief = True if the trainer is the first one, usually No.0
self._is_chief = False
self._slice_vars_and_attrs = []
# _parameters_on_pservers records all the parameters distributed on parameter servers.
self._parameters_on_pservers = None
# _endpoints is a list about parameter servers ip:port, such as ["ip:port","ip:port"]
self._endpoints = []
# if current role is parameter server, the _ps_endpoint is its "ip:port"
self._ps_endpoint = None
# trainers_endpoints, it is used for distribution.
self._trainers_endpoints = []
# the distributed lookup table names
self._distributed_lookup_table = None
@property
......@@ -2232,8 +2240,9 @@ class Program(object):
"Program")
self._is_distributed = other._is_distributed
self._is_chief = other._is_chief
self._slice_vars_and_attrs = other._slice_vars_and_attrs
self._parameters_on_pservers = other._parameters_on_pservers
self._endpoints = other._endpoints
self._ps_endpoint = other._ps_endpoint
self._distributed_lookup_table = other._distributed_lookup_table
def _copy_data_info_from(self, other):
......
......@@ -19,6 +19,7 @@ import errno
import time
import shutil
import six
from functools import reduce
from paddle.fluid.executor import Executor
from paddle.fluid.evaluator import Evaluator
......@@ -183,8 +184,6 @@ def save_vars(executor,
# NOTE: don't save the variable which type is RAW
if each_var.type == core.VarDesc.VarType.RAW:
continue
if each_var.name == main_program._distributed_lookup_table:
continue
new_var = _clone_var_in_block_(save_block, each_var)
if filename is None:
save_block.append_op(
......@@ -206,16 +205,6 @@ def save_vars(executor,
outputs={},
attrs={'file_path': os.path.join(dirname, filename)})
# if there is lookup table, the trainer 0 will notify all pserver to save.
if main_program._is_distributed and main_program._is_chief and main_program._distributed_lookup_table:
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
attrs = {}
attrs['epmap'] = main_program._endpoints
attrs['dir'] = lookup_table_filename
attrs['lookup_table'] = main_program._distributed_lookup_table
save_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(save_program)
......@@ -267,6 +256,186 @@ def save_params(executor, dirname, main_program=None, filename=None):
filename=filename)
def _save_distributed_persistables(executor, dirname, main_program):
"""
save_persistables for distributed training.
the method will do things listed below:
1.save part of persistable variables on trainer.
2.receive "remote prefetch variables" from parameter servers and merge them.
3.save "distributed lookup table" on parameter servers.
4.receive "optimizer variables" from parameter servers and merge them.
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The saving directory path.
main_program(Program): The program whose parameters will be
saved. the main_program must be the trainer_program
get after transpiler.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
train_program = t.get_trainer_program()
_save_distributed_persistables(executor=exe, dirname=param_path, main_program=train_program)
"""
def __save_remote_params(executor, dirname, remote_params_map):
"""
recive params on pserver through rpc.
if the params are be sliced, will concat them to one, then save it.
"""
if not remote_params_map:
return
prog = Program()
block = prog.global_block()
# recv optimize vars from pserver
for name, remote_params in remote_params_map.items():
origin_var = None
is_slice = False
slice_vars = [0] * len(remote_params)
slice_var_names = [""] * len(remote_params)
endpoints = [""] * len(remote_params)
for idx, optimizer in enumerate(remote_params):
origin = optimizer.origin
slice = optimizer.slice
is_slice = optimizer.is_slice
block_id = optimizer.block_id
endpoint = optimizer.endpoint
if idx == 0:
origin_var = block.create_var(
name=origin.name,
type=origin.type,
shape=origin.shape,
dtype=origin.dtype,
persistable=True)
slice_var = block.create_var(
name="{}.slice.{}".format(slice.name, idx),
type=slice.type,
shape=slice.shape,
dtype=slice.dtype,
persistable=True)
index = block_id if is_slice else idx
slice_vars[index] = slice_var
slice_var_names[index] = slice.name
endpoints[index] = endpoint
if is_slice:
block.append_op(
type='recv',
inputs={"X": []},
outputs={"Out": slice_vars},
attrs={
"epmap": endpoints,
"with_barrier": False,
"varnames": slice_var_names,
"sync_mode": True
})
block.append_op(
type='concat',
inputs={'X': slice_vars},
outputs={'Out': origin_var},
attrs={})
else:
block.append_op(
type='recv',
inputs={"X": []},
outputs={"Out": [origin_var]},
attrs={
"epmap": endpoints[:1],
"with_barrier": False,
"varnames": slice_var_names,
"sync_mode": True
})
block.append_op(
type='save',
inputs={'X': [origin_var]},
outputs={},
attrs={'file_path': os.path.join(dirname, origin_var.name)})
block.append_op(type='delete_var', inputs={'X': slice_vars})
executor.run(prog)
def __save_distributed_lookup_tables(executor, dirname,
distributed_lookup_table, endpoints):
"""
because the distributed lookup table may too huge to merge and save at one place,
it will be saved at parameter server independent respectively.
the save directory is dirname/"__lookup_table__".
"""
prog = Program()
block = prog.global_block()
# if there is lookup table, the trainer 0 will notify all pserver to save.
lookup_table_filename = os.path.join(dirname, "__lookup_table__")
attrs = {}
attrs['epmap'] = endpoints
attrs['dir'] = lookup_table_filename
attrs['lookup_table'] = distributed_lookup_table
block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(prog)
def __exclude_vars(exclude_var_names=[]):
def is_valid(var):
if var.name in exclude_var_names:
return False
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
var.desc.type() == core.VarDesc.VarType.READER:
return False
return var.persistable
return is_valid
if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
if not main_program._is_distributed:
raise ValueError(
"'_save_distributed_persistables' just be designed for distributed training."
)
remote_params_map = main_program._parameters_on_pservers.get_distributed_vars_by_vtypes(
["Optimizer", "RemotePrefetch"], groupby=True)
exclude_var_names = []
if remote_params_map:
exclude_var_names.extend(remote_params_map.keys())
if main_program._distributed_lookup_table:
if isinstance(main_program._distributed_lookup_table, list):
exclude_var_names.extend(main_program._distributed_lookup_table)
else:
exclude_var_names.append(main_program._distributed_lookup_table)
local_vars = list(
filter(__exclude_vars(exclude_var_names), main_program.list_vars()))
save_vars(
executor, main_program=main_program, dirname=dirname, vars=local_vars)
if main_program._is_chief:
if remote_params_map:
__save_remote_params(executor, dirname, remote_params_map)
if main_program._distributed_lookup_table:
__save_distributed_lookup_tables(
executor, dirname, main_program._distributed_lookup_table,
main_program._endpoints)
def save_persistables(executor, dirname, main_program=None, filename=None):
"""
This function filters out all variables with `persistable==True` from the
......@@ -301,6 +470,12 @@ def save_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.save_persistables(executor=exe, dirname=param_path,
main_program=None)
"""
if main_program and main_program._is_distributed:
_save_distributed_persistables(
executor, dirname=dirname, main_program=main_program)
else:
save_vars(
executor,
dirname=dirname,
......@@ -402,17 +577,11 @@ def load_vars(executor,
if not isinstance(main_program, Program):
raise TypeError("program should be as Program type or None")
load_slice_vars = []
for each_var in main_program._slice_vars_and_attrs:
load_slice_vars.append(each_var[2].name)
load_var_map = {}
for each_var in vars:
assert isinstance(each_var, Variable)
if each_var.type == core.VarDesc.VarType.RAW:
continue
if each_var.name in load_slice_vars:
continue
new_var = _clone_var_in_block_(load_block, each_var)
if filename is None:
load_block.append_op(
......@@ -435,10 +604,6 @@ def load_vars(executor,
attrs={'file_path': os.path.join(dirname, filename)})
executor.run(load_prog)
# load slice vars on pserver, if have it.
_load_slice_up_vars(executor, dirname,
main_program._slice_vars_and_attrs)
def load_params(executor, dirname, main_program=None, filename=None):
"""
......@@ -521,6 +686,11 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
fluid.io.load_persistables(executor=exe, dirname=param_path,
main_program=None)
"""
if main_program and main_program._is_distributed:
_load_distributed_persistables(
executor, dirname=dirname, main_program=main_program)
else:
load_vars(
executor,
dirname=dirname,
......@@ -529,6 +699,123 @@ def load_persistables(executor, dirname, main_program=None, filename=None):
filename=filename)
def _load_distributed_persistables(executor, dirname, main_program=None):
"""
customized load_persistables for distributed training.
it should be used on parameter server,
Args:
executor(Executor): The executor to run for saving parameters.
dirname(str): The load directory path.
main_program(Program): The program whose parameters will be
loaded. the main_program must be the pserver_program
get after transpiler.
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
t = distribute_transpiler.DistributeTranspiler()
t.transpile(...)
pserver_prog = t.get_pserver_program(...)
_load_distributed_persistables(executor=exe, dirname=param_path, main_program=pserver_prog)
"""
def __is_distributed_part_var(varname):
trainer_idx = varname.find(".trainer_")
block_idx = varname.find(".block")
return trainer_idx or block_idx
def __load_persistable_vars(executor, dirname, need_load_vars):
load_prog = Program()
load_block = load_prog.global_block()
need_delete_vars = []
for param in need_load_vars:
origin_var = param.origin
slice_var = param.slice
is_slice = param.is_slice
offset = param.offset
if is_slice:
origin = load_block.create_var(
name="{}.load".format(origin_var.name),
type=origin_var.type,
shape=origin_var.shape,
dtype=origin_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [origin]},
attrs={
'file_path': os.path.join(dirname, origin_var.name)
})
slice = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
dim1_flatten = reduce(lambda x, y: x * y, slice.shape[1:])
start = int(offset / dim1_flatten)
end = int(offset / dim1_flatten + slice.shape[0])
load_block.append_op(
type="slice",
inputs={'Input': origin},
outputs={'Out': slice},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
need_delete_vars.append(origin)
else:
origin = load_block.create_var(
name="{}".format(origin_var.name),
type=origin_var.type,
shape=origin_var.shape,
dtype=origin_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [origin]},
attrs={
'file_path': os.path.join(dirname, origin_var.name)
})
load_block.append_op(
type='delete_var',
inputs={'X': need_delete_vars}, )
executor.run(load_prog)
if not isinstance(main_program, Program):
raise ValueError("'main_program' should be an instance of Program.")
if not main_program._is_distributed:
raise ValueError(
"'_load_distributed_persistables' just be designed for distributed training."
)
if not main_program._ps_endpoint:
raise ValueError(
"'_load_distributed_persistables' need current_endpoint set in DistributeTranspiler.transpile"
)
need_load_vars = main_program._parameters_on_pservers.get_distributed_vars_by_ep(
main_program._ps_endpoint)
__load_persistable_vars(executor, dirname, need_load_vars)
def prepend_feed_ops(inference_program,
feed_target_names,
feed_holder_name='feed'):
......@@ -795,52 +1082,6 @@ def load_inference_model(dirname,
return [program, feed_target_names, fetch_targets]
def _save_lookup_tables_by_notify(executor, dirname, lookup_table,
pserver_endpoints):
"""
This function will send checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains lookup table name,
the absolute path on pserver to save lookup_table.
Args:
executor(Executor): The executor to run for send checkpoint notify.
dirname(str): The folder where to save.
lookup_table(string): the lookup table name, when use distribute
lookup table, we can get lookup table name by DistributeTranspiler.
table_name
ps_endpoint_list(list): the parameter server ip:port list.
when use distribute lookup table, we can get ps_endpoint_list by
distribute arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
_save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
pserver_endpoints=ps_endpoints)
"""
pserver_notify_program = Program()
pserver_notify_block = pserver_notify_program.global_block()
attrs = {}
attrs['epmap'] = pserver_endpoints
attrs['dir'] = dirname
attrs['lookup_table'] = lookup_table
pserver_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(pserver_notify_program)
def _endpoints_replacement(program, endpoints):
ENDPOINT_MAP = "epmap"
for op in program.global_block().ops:
......@@ -911,54 +1152,3 @@ def get_parameter_value_by_name(name, executor, program=None):
program = default_main_program()
var = program.global_block().var(name)
return get_parameter_value(var, executor)
def _load_slice_up_vars(executor, dirname, slice_vars_and_attrs):
if not slice_vars_and_attrs:
return
load_prog = Program()
load_block = load_prog.global_block()
need_delete_vars = []
for var_tuple in slice_vars_and_attrs:
orig_var = var_tuple[0]
start = var_tuple[1]
slice_var = var_tuple[2]
end = start + slice_var.shape[0]
orig_var_name = orig_var.name
orig_var.name = "{}.origin".format(orig_var_name)
clone_orig_var = load_block.create_var(
name=orig_var.name,
type=orig_var.type,
shape=orig_var.shape,
dtype=orig_var.dtype,
persistable=True)
clone_slice_var = load_block.create_var(
name=slice_var.name,
type=slice_var.type,
shape=slice_var.shape,
dtype=slice_var.dtype,
persistable=True)
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [clone_orig_var]},
attrs={'file_path': os.path.join(dirname, orig_var_name)})
load_block.append_op(
type="slice",
inputs={'Input': clone_orig_var},
outputs={'Out': clone_slice_var},
attrs={'axes': [0],
'starts': [start],
'ends': [end]})
need_delete_vars.append(clone_orig_var)
load_block.append_op(
type='delete_var',
inputs={'X': need_delete_vars}, )
executor.run(load_prog)
......@@ -80,7 +80,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode)
args.trainers, args.sync_mode, False,
args.current_endpoint)
pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog)
......@@ -93,7 +94,8 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
exe.run(startup_prog)
if need_load and model_dir:
self._load_persistable_vars(exe, model_dir, startup_prog)
fluid.io.load_persistables(exe, model_dir, pserver_prog)
exe.run(pserver_prog)
def run_trainer(self, args):
......@@ -158,7 +160,9 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
need_save = bool(int(os.getenv("SAVE", "0")))
model_dir = os.getenv("MODEL_DIR", "")
save_mode = os.getenv("SAVE_MODE", "")
if save_mode == "LOCAL":
if need_save:
for _ in six.moves.xrange(RUN_STEP):
loss, = exe.run(fetch_list=[avg_cost.name],
......@@ -166,12 +170,37 @@ class TestDistSaveLoad2x2(TestDistSimnetBow2x2):
if need_save and model_dir:
io.save_persistables(startup_exe, model_dir, trainer_prog)
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor())
var = np.array(fluid.global_scope().find_var('__fc_b__').get_tensor(
))
if six.PY2:
print(pickle.dumps(np.ravel(var).tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(np.ravel(var).tolist()))
elif save_mode == "DIST":
skip_steps = int(os.getenv("SKIP_STEPS"))
loss = None
if need_save:
for idx in six.moves.xrange(8):
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(get_data()))
if need_save and model_dir and idx == skip_steps and args.trainer_id == 0:
io.save_persistables(startup_exe, model_dir,
trainer_prog)
else:
for idx in six.moves.xrange(8):
data = get_data()
if idx <= skip_steps:
continue
loss, = exe.run(fetch_list=[avg_cost.name],
feed=feeder.feed(data))
if six.PY2:
print(pickle.dumps(loss.tolist()))
else:
sys.stdout.buffer.write(pickle.dumps(loss.tolist()))
else:
raise Exception("save_mode must be LOCAL or DIST")
if __name__ == "__main__":
paddle.dataset.common.download(DATA_URL, 'simnet', DATA_MD5, "train")
......
......@@ -75,8 +75,12 @@ def get_loss(cos_q_pt, cos_q_nt):
return avg_cost
def get_optimizer():
# SGD optimizer
def get_optimizer(op="sgd"):
if op.upper() == "sgd".upper():
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
elif op.upper() == "adam".upper():
optimizer = fluid.optimizer.Adam(learning_rate=base_lr)
else:
optimizer = fluid.optimizer.SGD(learning_rate=base_lr)
return optimizer
......@@ -237,7 +241,8 @@ class TestDistSimnetBow2x2(TestDistRunnerBase):
inference_program = fluid.default_main_program().clone()
# Optimization
opt = get_optimizer()
opt = os.getenv('OPTIMIZER', 'sgd')
opt = get_optimizer(opt)
opt.minimize(avg_cost)
# Reader
......
......@@ -43,7 +43,8 @@ class TestDistRunnerBase(object):
pserver_endpoints,
trainers,
sync_mode,
dc_asgd=False):
dc_asgd=False,
current_endpoint=None):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd
......@@ -53,7 +54,8 @@ class TestDistRunnerBase(object):
program=main_program,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode)
sync_mode=sync_mode,
current_endpoint=current_endpoint)
return t
def run_pserver(self, args):
......
......@@ -33,7 +33,6 @@ class TestDistSaveLoadDense2x2(TestDistBase):
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
......@@ -77,7 +76,77 @@ class TestDistSaveLoadDense2x2(TestDistBase):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1'
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'LOCAL',
}
self.check_with_place(
"dist_save_load.py",
delta=0,
check_error_log=False,
need_envs=need_envs)
class TestDistSaveLoadWithPServerStateDense2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"http_proxy": ""
}
required_envs.update(need_envs)
if check_error_log:
required_envs["GLOG_v"] = "3"
required_envs["GLOG_logtostderr"] = "1"
model_dir = tempfile.mkdtemp()
save_env = {}
save_env["SAVE_MODE"] = "DIST"
save_env["SAVE"] = "1"
save_env["MODEL_DIR"] = model_dir
save_env.update(required_envs)
tr0_var_1, tr1_var_1 = self._run_cluster(model_file, save_env,
check_error_log)
load_env = {}
load_env["LOAD"] = "1"
load_env["MODEL_DIR"] = model_dir
load_env.update(required_envs)
tr0_var_2, tr1_var_2 = self._run_cluster(model_file, load_env,
check_error_log)
shutil.rmtree(model_dir)
train0_1_np = np.array(tr0_var_1)
train1_1_np = np.array(tr1_var_1)
train0_2_np = np.array(tr0_var_2)
train1_2_np = np.array(tr1_var_2)
self.assertAlmostEqual(
train0_1_np.all(), train0_2_np.all(), delta=delta)
self.assertAlmostEqual(
train1_1_np.all(), train1_2_np.all(), delta=delta)
def test_dist(self):
need_envs = {
"IS_DISTRIBUTED": '0',
"IS_SPARSE": '0',
'IS_SELF_CONTAINED_LR': '1',
'SAVE_MODE': 'DIST',
'OPTIMIZER': 'ADAM',
'SKIP_STEPS': str(np.random.randint(2, 6))
}
self.check_with_place(
"dist_save_load.py",
......
......@@ -741,21 +741,40 @@ class TestLoadSliceVar(TranspilerTest):
pserver, _ = self.get_pserver(self.pserver1_ep)
pserver2, _ = self.get_pserver(self.pserver2_ep)
self.assertTrue(pserver._slice_vars_and_attrs)
self.assertTrue(pserver2._slice_vars_and_attrs)
for idx in six.moves.xrange(len(pserver._slice_vars_and_attrs)):
self.assertEqual(pserver._slice_vars_and_attrs[idx][0],
pserver2._slice_vars_and_attrs[idx][0])
total_numel = six.moves.reduce(
lambda x, y: x * y, pserver._slice_vars_and_attrs[idx][0].shape)
self.assertEqual(
total_numel,
six.moves.reduce(lambda x, y: x * y,
pserver._slice_vars_and_attrs[idx][2].shape) +
six.moves.reduce(lambda x, y: x * y,
pserver2._slice_vars_and_attrs[idx][2].shape))
vars_ps1 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.pserver1_ep)
vars_ps2 = pserver._parameters_on_pservers.get_distributed_vars_by_ep(
self.pserver2_ep)
self.assertTrue(vars_ps1)
self.assertTrue(vars_ps2)
for idx in six.moves.xrange(len(vars_ps1)):
total_numel = 0
ps1_numel, ps2_numel = 0, 0
ps1_var = vars_ps1[idx]
if not ps1_var.is_slice:
total_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
vars_ps1[idx].slice.shape)
else:
ps2_var = None
for var in vars_ps2:
if var.origin.name == ps1_var.origin.name:
ps2_var = var
break
total_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.origin.shape)
ps1_numel = six.moves.reduce(lambda x, y: x * y,
ps1_var.slice.shape)
ps2_numel = six.moves.reduce(lambda x, y: x * y,
ps2_var.slice.shape)
self.assertEqual(total_numel, ps1_numel + ps2_numel)
class TestNCCL2Transpile(TranspilerTest):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册