Commit 88cb47bd authored by yi.wu

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_paraexe_bcast

...@@ -114,7 +114,12 @@ INCLUDE_DIRECTORIES(${CBLAS_INC_DIR})
SET(dummyfile ${CMAKE_CURRENT_BINARY_DIR}/cblas_dummy.c)
FILE(WRITE ${dummyfile} "const char *dummy_cblas = \"${dummyfile}\";")
ADD_LIBRARY(cblas STATIC ${dummyfile})
IF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
TARGET_LINK_LIBRARIES(cblas dynload_mklml)
ELSE()
TARGET_LINK_LIBRARIES(cblas ${CBLAS_LIBRARIES})
ENDIF("${CBLAS_PROVIDER}" STREQUAL "MKLML")
IF(NOT ${CBLAS_FOUND})
ADD_DEPENDENCIES(cblas extern_openblas)
......
...@@ -195,6 +195,15 @@ function(cc_library TARGET_NAME)
list(REMOVE_ITEM cc_library_DEPS warpctc)
add_dependencies(${TARGET_NAME} warpctc)
endif()
# Only depend on libmklml.so, do not link it
if("${cc_library_DEPS};" MATCHES "mklml;")
list(REMOVE_ITEM cc_library_DEPS mklml)
if(NOT "${TARGET_NAME}" MATCHES "dynload_mklml")
list(APPEND cc_library_DEPS dynload_mklml)
endif()
add_dependencies(${TARGET_NAME} mklml)
target_link_libraries(${TARGET_NAME} "-L${MKLML_LIB_DIR} -liomp5 -Wl,--as-needed")
endif()
target_link_libraries(${TARGET_NAME} ${cc_library_DEPS})
add_dependencies(${TARGET_NAME} ${cc_library_DEPS})
endif()
......
# Inference High-level APIs
This document describes the high-level inference APIs that can be used to quickly deploy a Paddle model for an application.
The APIs are declared in a single header file, `paddle_inference_api.h`; only two libraries, `libpaddle_fluid.so` and `libpaddle_fluid_api.so`, are needed for a deployment.
## PaddleTensor
We provide the `PaddleTensor` data structure to give a general tensor interface.
The definition is
...@@ -17,18 +17,19 @@ struct PaddleTensor {
};
```
The data is stored in a continuous memory block described by `PaddleBuf`, and a `PaddleDType` specifies the tensor's data type.
The `name` field is used to specify the name of an input variable,
which is important when there are multiple inputs and we need to distinguish which variable to set.
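Below is a minimal sketch of preparing an input tensor. It assumes the `PaddleTensor` fields are `name`, `shape`, `data` (a `PaddleBuf` holding a pointer and a byte length), and `dtype`, matching the elided definition above; check `paddle_inference_api.h` for the exact layout.
```c++
// Sketch only: field names are assumptions based on the definition elided above.
std::vector<float> input(1 * 3 * 224 * 224, 0.f);  // hypothetical input buffer

PaddleTensor tensor;
tensor.name = "data";  // hypothetical name of the model's input variable
tensor.shape = std::vector<int>({1, 3, 224, 224});
tensor.data.data = input.data();                    // raw pointer to the buffer
tensor.data.length = input.size() * sizeof(float);  // length in bytes
tensor.dtype = PaddleDType::FLOAT32;
```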
## engine
The inference APIs have two different underlying engines:
- the native engine, which consists of the native operators and framework,
- the Anakin engine, which has an Anakin library embedded.
The native engine takes a native Paddle model as input and supports any model trained by Paddle. The Anakin engine is faster for some models, but it can only take an Anakin model as input (the user needs to transform the format first manually), and currently not all Paddle models are supported.
```c++
enum class PaddleEngineKind {
...@@ -38,10 +39,10 @@ enum class PaddleEngineKind {
```
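For illustration, the engine is selected through the template parameter of `CreatePaddlePredictor` (declared in the next section). `PaddleEngineKind::kNative` appears in that declaration; `kAnakin`, `NativeConfig`, and `AnakinConfig` below are assumed names for the corresponding engine kind and config types.
```c++
// Sketch only: kAnakin, NativeConfig and AnakinConfig are assumptions, and
// native_config / anakin_config stand for already-filled config objects.
auto native_predictor =
    CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(native_config);
auto anakin_predictor =
    CreatePaddlePredictor<AnakinConfig, PaddleEngineKind::kAnakin>(anakin_config);
```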
## PaddlePredictor and how to create one
The main interface is `PaddlePredictor`, which provides the following methods:
- `bool Run(const std::vector<PaddleTensor>& inputs, std::vector<PaddleTensor>* output_data)`
  - takes the inputs and fills `output_data` with the results.
- `Clone`, which clones a predictor from an existing one, with the model parameters shared.
There is a factory method to help create a predictor, and the user takes ownership of the returned object.
...@@ -51,9 +52,9 @@ template <typename ConfigT, PaddleEngineKind engine = PaddleEngineKind::kNative>
std::unique_ptr<PaddlePredictor> CreatePaddlePredictor(const ConfigT& config);
```
By specifying the engine kind and config, one can get a specific implementation.
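A minimal end-to-end sketch follows. `NativeConfig` and its `model_dir` / `use_gpu` fields are assumptions about the native engine's config type, and the model path is hypothetical; only `CreatePaddlePredictor`, `PaddleEngineKind::kNative`, and `Run` come from the declarations above.
```c++
// Sketch under assumptions: NativeConfig and its fields are not shown in this document.
NativeConfig config;
config.model_dir = "./my_model";  // hypothetical path to a Paddle-trained model
config.use_gpu = false;           // assumed field: run inference on CPU

auto predictor =
    CreatePaddlePredictor<NativeConfig, PaddleEngineKind::kNative>(config);

std::vector<PaddleTensor> inputs{tensor};  // the input tensor prepared earlier
std::vector<PaddleTensor> outputs;
if (predictor->Run(inputs, &outputs)) {
  // outputs[0].data now holds the inference result.
}
```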
## Reference
- [paddle_inference_api.h](./paddle_inference_api.h)
- [some demos](./demo)
...@@ -207,14 +207,16 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
is_forwarding = false;
} else {
int op_dev_id = GetOpDeviceID(*op);
if (op_dev_id != -1) {  // This op only runs on one specific device.
CreateComputationalOp(&result, *op, op_dev_id);
for (auto &var_name : op->OutputArgumentNames()) {
var_name_on_devices_.emplace(var_name, op_dev_id);
}
} else {
// This op runs on all devices, and its output may have parameter's
// gradients.
CreateComputationalOps(&result, *op, places_.size());
if (!is_forwarding && places_.size() > 1) {
// Currently, we assume that once gradient is generated, it can be
// broadcast, and each gradient is only broadcast once.
...@@ -259,6 +261,7 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
}
}
}
// Insert BCast Ops
for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
......
...@@ -19,8 +19,8 @@ limitations under the License. */
#include "gflags/gflags.h"
#include "gtest/gtest.h"
#include "paddle/fluid/inference/tests/test_helper.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_MKLML
#include <omp.h>
#endif
...@@ -164,7 +164,7 @@ TEST(inference, nlp) {
// only use 1 thread number per std::thread
omp_set_dynamic(0);
omp_set_num_threads(1);
paddle::operators::math::SetNumThreads(1);
#endif
double start_ms = 0, stop_ms = 0;
......
...@@ -195,7 +195,7 @@ if(WITH_DISTRIBUTE)
endif()
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
foreach(dist_op "prefetch_op" "checkpoint_notify_op" "listen_and_serv_op" "send_op" "recv_op" "send_barrier_op" "fetch_barrier_op")
op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS})
set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
endforeach()
...@@ -216,7 +216,7 @@ if(WITH_DISTRIBUTE)
set(DEPS_OPS ${DEPS_OPS} gen_nccl_id_op)
endif()
else()
set(DEPS_OPS ${DEPS_OPS} checkpoint_notify_op prefetch_op recv_op listen_and_serv_op send_op send_barrier_op fetch_barrier_op gen_nccl_id_op)
endif()
op_library(cross_entropy_op DEPS cross_entropy)
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <future> // NOLINT
#include <ostream>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/macros.h"
#include "paddle/fluid/operators/send_recv_util.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
class CheckpointNotifyOp : public framework::OperatorBase {
public:
CheckpointNotifyOp(const std::string& type,
const framework::VariableNameMap& inputs,
const framework::VariableNameMap& outputs,
const framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope& scope,
const platform::Place& place) const override {
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::string dir = Attr<std::string>("dir");
std::string lookup_table_name = Attr<std::string>("lookup_table");
distributed::RPCClient* rpc_client =
distributed::RPCClient::GetInstance<RPCCLIENT_T>();
for (size_t i = 0; i < epmap.size(); i++) {
auto lookup_table_save_dir =
string::Sprintf("%s/%s_%d", dir, lookup_table_name, i);
rpc_client->AsyncCheckpointNotify(epmap[i], lookup_table_save_dir);
VLOG(3) << "checkpoint notify sending lookup table: " << lookup_table_name
<< " and dir:" << dir << " to " << epmap[i];
}
rpc_client->Wait();
}
};
class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() {
AddAttr<std::vector<std::string>>("epmap",
"(string vector, default 127.0.0.1:6164)"
"Parameter Server endpoints in the order")
.SetDefault({"127.0.0.1:6164"});
AddAttr<std::string>(
"dir", "(string, default '') indicates the folder the checkpoint will use");
AddAttr<std::string>("lookup_table",
"(string, default '') the lookup table name");
AddComment(R"DOC(
CheckpointNotify operator
This operator will send the lookup table and its checkpoint directory to the listen_and_serv op at
the parameter server.
)DOC");
}
};
class CheckpointNotifyOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext* ctx) const override {}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(checkpoint_notify, ops::CheckpointNotifyOp,
paddle::framework::EmptyGradOpMaker,
ops::CheckpointNotifyOpMaker,
ops::CheckpointNotifyOpShapeInference);
...@@ -55,26 +55,24 @@ class BRPCClient : public RPCClient {
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
......
...@@ -239,6 +239,23 @@ void GRPCClient::AsyncSendComplete(const std::string& ep, int64_t time_out) {
req_count_++;
}
void GRPCClient::AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out) {
const auto ch = GetChannel(ep);
CheckpointNotifyProcessor* s = new CheckpointNotifyProcessor(ch);
s->Prepare(time_out);
sendrecv::VariableMessage req;
req.set_varname(CHECKPOINT_SAVE_MESSAGE);
req.set_out_varname(dir);
auto rpc = s->stub_->AsyncCheckpointNotify(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, reinterpret_cast<void*>(s));
req_count_++;
}
void GRPCClient::Wait() {
std::unique_lock<std::mutex> lk(sync_mutex_);
sync_cond_.wait(lk, [this] { return req_count_ == 0; });
......
...@@ -171,6 +171,20 @@ class FetchBarrierProcessor : public BaseProcessor {
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class CheckpointNotifyProcessor : public BaseProcessor {
public:
explicit CheckpointNotifyProcessor(std::shared_ptr<grpc::Channel> ch)
: BaseProcessor(ch) {
stub_ = sendrecv::SendRecvService::NewStub(ch);
}
virtual ~CheckpointNotifyProcessor() {}
virtual void Process() {}
sendrecv::VoidMessage reply_;
std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
};
class GRPCClient : public RPCClient {
public:
GRPCClient() {}
...@@ -178,24 +192,27 @@ class GRPCClient : public RPCClient {
bool AsyncSendVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncGetVar(const std::string& ep, const platform::DeviceContext& ctx,
const framework::Scope& scope, const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) override;
void AsyncCheckpointNotify(const std::string& ep, const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) override;
void Wait() override;
...@@ -211,7 +228,7 @@ class GRPCClient : public RPCClient {
void Proceed();
void AsyncSendComplete(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline);
std::shared_ptr<grpc::Channel> GetChannel(const std::string& ep);
......
...@@ -200,6 +200,45 @@ class RequestPrefetch final : public RequestBase {
framework::Scope* local_scope_;
};
class RequestCheckpointNotify final : public RequestBase {
public:
explicit RequestCheckpointNotify(GrpcService::AsyncService* service,
::grpc::ServerCompletionQueue* cq,
RequestHandler* request_handler, int req_id)
: RequestBase(service, cq, request_handler, req_id), responder_(&ctx_) {
request_.reset(new VariableResponse(request_handler->scope(),
request_handler->dev_ctx()));
int method_id =
static_cast<int>(distributed::GrpcMethod::kCheckpointNotify);
service_->RequestAsyncUnary(
method_id, &ctx_, request_.get(), &responder_, cq_, cq_,
reinterpret_cast<void*>(static_cast<intptr_t>(req_id)));
}
virtual ~RequestCheckpointNotify() {}
std::string GetReqName() override { return request_->Varname(); }
void Process() override {
auto scope = request_->GetMutableLocalScope();
std::string checkpoint_notify = request_->Varname();
std::string checkpoint_dir = request_->OutVarname();
VLOG(4) << "RequestCheckpointNotify notify: " << checkpoint_notify
<< ", dir: " << checkpoint_dir;
request_handler_->Handle(checkpoint_notify, scope, nullptr, nullptr,
checkpoint_dir);
Finish(reply_, &responder_);
}
protected:
std::shared_ptr<VariableResponse> request_;
sendrecv::VoidMessage reply_;
ServerAsyncResponseWriter<sendrecv::VoidMessage> responder_;
};
void AsyncGRPCServer::WaitServerReady() {
VLOG(4) << "AsyncGRPCServer is wait server ready";
std::unique_lock<std::mutex> lock(this->mutex_ready_);
...@@ -237,6 +276,7 @@ void AsyncGRPCServer::StartServer() {
reqs.reserve(kRequestBufSize);
for (int i = 0; i < kRequestBufSize; i++) {
VLOG(6) << "TryToRegisterNewOne on RPC NAME: " << rpc_name << " I: " << i;
TryToRegisterNewOne(rpc_name, i);
}
...@@ -289,8 +329,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
return;
}
VLOG(4) << "TryToRegisterNewOne on RPC NAME: " << rpc_name
<< " REQ ID: " << req_id;
auto& reqs = rpc_reqs_[rpc_name];
auto& handler = rpc_call_map_[rpc_name];
...@@ -303,6 +343,8 @@ void AsyncGRPCServer::TryToRegisterNewOne(const std::string& rpc_name,
b = new RequestGet(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestPrefetch) {
b = new RequestPrefetch(&service_, cq.get(), handler, req_id);
} else if (rpc_name == kRequestCheckpoint) {
b = new RequestCheckpointNotify(&service_, cq.get(), handler, req_id);
} else {
PADDLE_ENFORCE(false, "not supported rpc");
}
...@@ -321,7 +363,7 @@ void AsyncGRPCServer::HandleRequest(
while (true) {
VLOG(4) << "HandleRequest " << rpc_name << " wait next";
if (!cq->Next(&tag, &ok)) {
VLOG(3) << "CompletionQueue " << rpc_name << " shutdown!";
break;
}
......
...@@ -80,10 +80,11 @@ enum class GrpcMethod {
kSendVariable,
kGetVariable,
kPrefetchVariable,
kCheckpointNotify,
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kCheckpointNotify) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
...@@ -93,6 +94,8 @@ inline const char* GrpcMethodName(GrpcMethod id) {
return "/sendrecv.SendRecvService/GetVariable";
case GrpcMethod::kPrefetchVariable:
return "/sendrecv.SendRecvService/PrefetchVariable";
case GrpcMethod::kCheckpointNotify:
return "/sendrecv.SendRecvService/CheckpointNotify";
}
// Shouldn't be reached.
......
...@@ -36,12 +36,16 @@ namespace distributed {
constexpr char kRequestSend[] = "RequestSend";
constexpr char kRequestGet[] = "RequestGet";
constexpr char kRequestPrefetch[] = "RequestPrefetch";
constexpr char kRequestCheckpoint[] = "RequestCheckpoint";
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"
#define FETCH_BARRIER_MESSAGE "FETCH_BARRIER@RECV"
#define COMPLETE_MESSAGE "COMPLETE@RECV"
#define CHECKPOINT_SAVE_MESSAGE "SAVE@CHECKPOINTNOTIFY"
#define CHECKPOINT_LOAD_MESSAGE "LOAD@CHECKPOINTNOTIFY"
class RPCServer;
class RequestHandler {
...@@ -69,6 +73,11 @@ class RequestHandler {
prefetch_var_name_to_prepared_ctx_ = g;
}
void SetCheckpointNotifyPreparedCtx(
std::shared_ptr<framework::ExecutorPrepareContext> g) {
checkpoint_prepared_ctx_ = g;
}
// Used for async.
void SetGradToPreparedCtx(
std::unordered_map<
...@@ -115,6 +124,8 @@ class RequestHandler {
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>*
prefetch_var_name_to_prepared_ctx_;
// used for checkpoint notify
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_prepared_ctx_;
// Used for async.
std::unordered_map<std::string,
......
...@@ -22,11 +22,16 @@
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/request_handler_impl.h"
#include "paddle/fluid/operators/distributed/rpc_server.h"
#include "paddle/fluid/string/printf.h"
namespace paddle {
namespace operators {
namespace distributed {
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to the specified directory.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
bool RequestSendHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
...@@ -119,6 +124,24 @@ bool RequestPrefetchHandler::Handle(const std::string& varname,
return true;
}
bool RequestCheckpointHandler::Handle(const std::string& varname,
framework::Scope* scope,
framework::Variable* invar,
framework::Variable** outvar,
const std::string& out_var_name) {
PADDLE_ENFORCE(
checkpoint_notify_id != -1,
"when checkpoint_notify_id = -1, there should be no RPC invoke.");
auto* lt_var = scope->FindVar(LOOKUP_TABLE_PATH)->GetMutable<std::string>();
lt_var->clear();
lt_var->append(out_var_name);
VLOG(4) << "RequestCheckpointHandler update var kLookupTablePath to: "
<< out_var_name;
executor_->RunPreparedContext(checkpoint_prepared_ctx_.get(), scope);
return true;
}
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -66,6 +66,21 @@ class RequestPrefetchHandler final : public RequestHandler {
const std::string& out_var_name = "") override;
};
class RequestCheckpointHandler final : public RequestHandler {
public:
explicit RequestCheckpointHandler(bool sync_mode, int checkpoint_notify_id)
: RequestHandler(sync_mode) {
this->checkpoint_notify_id = checkpoint_notify_id;
}
virtual ~RequestCheckpointHandler() {}
bool Handle(const std::string& varname, framework::Scope* scope,
framework::Variable* var, framework::Variable** outvar,
const std::string& out_var_name = "") override;
private:
int checkpoint_notify_id;
};
} // namespace distributed
} // namespace operators
} // namespace paddle
...@@ -16,7 +16,7 @@
#include "gflags/gflags.h"
// default to 3min to avoid temporary network failures.
DEFINE_int32(rpc_deadline, 180000, "deadline timeouts for rpc");
namespace paddle {
namespace operators {
......
...@@ -21,7 +21,7 @@
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
DECLARE_int32(rpc_deadline);
namespace paddle {
namespace operators {
...@@ -35,26 +35,30 @@ class RPCClient {
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncGetVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual bool AsyncPrefetchVar(const std::string& ep,
const platform::DeviceContext& ctx,
const framework::Scope& scope,
const std::string& in_var_name,
const std::string& out_var_name,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncSendFetchBarrier(const std::string& ep,
int64_t time_out = FLAGS_rpc_deadline) = 0;
virtual void AsyncCheckpointNotify(const std::string& ep,
const std::string& dir,
int64_t time_out = FLAGS_rpc_deadline) = 0;
// SendComplete tells all the servers that the current trainer has no more data
// to train, so that the pserver can reduce its barrier count, and continue
......
...@@ -25,6 +25,8 @@ service SendRecvService {
rpc GetVariable(VariableMessage) returns (VariableMessage) {}
// pre-fetch variable by given variable name and Ids
rpc PrefetchVariable(VariableMessage) returns (VariableMessage) {}
rpc CheckpointNotify(VariableMessage) returns (VoidMessage) {}
}
// VariableMessage is serialized paddle variable message.
......
...@@ -99,7 +99,8 @@ static int64_t GetTimestamp() {
void ListenAndServOp::RunSyncLoop(
framework::Executor *executor, framework::ProgramDesc *program,
framework::Scope *recv_scope,
const std::vector<int> &prefetch_block_id_list,
const int checkpoint_point_block_id) const {
size_t num_blocks = program->Size();
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
...@@ -163,7 +164,8 @@ void ListenAndServOp::RunSyncLoop(
}
void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
framework::ProgramDesc *program,
framework::Scope *recv_scope) const {
// grad name to block id
std::unordered_map<std::string, int32_t> grad_to_block_id;
std::unordered_map<int32_t, std::string> id_to_grad;
...@@ -190,6 +192,10 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
block_list.push_back(blkid);
}
auto optimize_prepared = executor->Prepare(*program, block_list);
// execute global block if needed
if (block_list[0] == 1 && id_to_grad.count(1) == 0) {
executor->RunPreparedContext(optimize_prepared[0].get(), recv_scope);
}
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
grad_to_prepared_ctx;
...@@ -203,7 +209,7 @@ void ListenAndServOp::RunAsyncLoop(framework::Executor *executor,
while (true) {
if (rpc_service_->IsExit()) {
VLOG(4) << "get exit!rpc_processor break!";
break;
}
...@@ -218,6 +224,7 @@ static void FillRequestCtx(
std::unordered_map<std::string,
std::shared_ptr<framework::ExecutorPrepareContext>>
*prefetch_ctx,
std::shared_ptr<framework::ExecutorPrepareContext> checkpoint_ctx,
distributed::RPCServer *rpc_server) {
h->SetScope(scope);
h->SetDevCtx(dev_ctx);
...@@ -225,6 +232,7 @@ static void FillRequestCtx(
h->SetProgram(program);
h->SetPrefetchPreparedCtx(prefetch_ctx);
h->SetRPCServer(rpc_server);
h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx);
}
void ListenAndServOp::RunImpl(const framework::Scope &scope,
...@@ -240,9 +248,11 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
PADDLE_ENFORCE(!rpc_service_);
std::string endpoint = Attr<std::string>("endpoint");
int checkpoint_block_id = Attr<int>(kCheckpointBlockId);
VLOG(4) << "sync_mode:" << sync_mode << ", fan_in:" << fan_in
<< ", end_point:" << endpoint
<< ", checkpoint_block_id: " << checkpoint_block_id;
rpc_service_.reset(new RPCSERVER_T(endpoint, fan_in));
...@@ -250,6 +260,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_.reset(new distributed::RequestGetHandler(sync_mode));
request_prefetch_handler_.reset(
new distributed::RequestPrefetchHandler(sync_mode));
request_checkpoint_handler_.reset(new distributed::RequestCheckpointHandler(
sync_mode, checkpoint_block_id));
rpc_service_->RegisterRPC(distributed::kRequestSend,
request_send_handler_.get());
...@@ -257,6 +269,8 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
request_get_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestPrefetch,
request_prefetch_handler_.get());
rpc_service_->RegisterRPC(distributed::kRequestCheckpoint,
request_checkpoint_handler_.get());
auto optimize_blocks =
Attr<std::vector<framework::BlockDesc *>>(kOptimizeBlocks);
...@@ -265,6 +279,13 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
auto *program = optimize_blocks[0]->Program();
framework::Executor executor(dev_place);
std::shared_ptr<framework::ExecutorPrepareContext> ckpt_pre_context = nullptr;
if (checkpoint_block_id != -1) {
auto ctx = executor.Prepare(*program, checkpoint_block_id);
// see: https://stackoverflow.com/a/14856553
ckpt_pre_context = std::move(ctx);
}
// prepare for prefetch
std::vector<int> prefetch_block_id_list;
std::unordered_map<int, std::string> block_id_to_prefetch_var_name;
...@@ -295,13 +316,15 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i];
}
auto f =
std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx,
&executor, program, &prefetch_var_name_to_prepared_ctx,
ckpt_pre_context, rpc_service_.get());
f(request_send_handler_.get());
f(request_get_handler_.get());
f(request_prefetch_handler_.get());
f(request_checkpoint_handler_.get());
// start the server listening after all member initialized.
server_thread_.reset(new std::thread(RunServer, rpc_service_));
...@@ -315,9 +338,10 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope,
// Write to a file of server selected port for python use.
SavePort();
if (sync_mode) {
RunSyncLoop(&executor, program, &recv_scope, prefetch_block_id_list,
checkpoint_block_id);
} else {
RunAsyncLoop(&executor, program, &recv_scope);
}
}
...@@ -347,6 +371,9 @@ class ListenAndServOpMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault({});
AddAttr<int>("Fanin", "How many clients send to this server.")
.SetDefault(1);
AddAttr<int>(kCheckpointBlockId,
"BlockID to run save checkpoint on pserver.")
.SetDefault(-1);
}
};
......
...@@ -32,6 +32,7 @@ namespace operators {
constexpr char kOptimizeBlocks[] = "optimize_blocks";
constexpr char kPrefetchVarNameToBlockId[] = "prefetch_var_name_to_block_id";
constexpr char kCheckpointBlockId[] = "checkpint_block_id";
void RunServer(std::shared_ptr<distributed::RPCServer> service);
...@@ -47,10 +48,12 @@ class ListenAndServOp : public framework::OperatorBase {
void RunSyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope,
const std::vector<int>& prefetch_block_id_list,
const int checkpoint_point_block_id) const;
void RunAsyncLoop(framework::Executor* executor,
framework::ProgramDesc* program,
framework::Scope* recv_scope) const;
void SavePort() const;
...@@ -67,6 +70,8 @@ class ListenAndServOp : public framework::OperatorBase {
mutable std::shared_ptr<distributed::RequestHandler> request_get_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_prefetch_handler_;
mutable std::shared_ptr<distributed::RequestHandler>
request_checkpoint_handler_;
mutable std::shared_ptr<std::thread> server_thread_;
};
......
...@@ -34,6 +34,8 @@ class LoadOp : public framework::OperatorBase {
auto *dev_ctx = platform::DeviceContextPool::Instance().Get(place);
platform::RecordEvent record_event(Type(), dev_ctx);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = Attr<std::string>("file_path");
std::ifstream fin(filename);
PADDLE_ENFORCE(static_cast<bool>(fin), "Cannot open file %s for load op",
...@@ -44,9 +46,25 @@ class LoadOp : public framework::OperatorBase {
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var);
} else if (out_var->IsType<framework::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_ENFORCE(
false,
"Load only support LoDTensor and SelectedRows, %s has wrong type",
out_var_name);
}
}
void LoadLodTensor(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
DeserializeFromStream(fin, tensor, dev_ctx);
auto load_as_fp16 = Attr<bool>("load_as_fp16");
auto in_dtype = framework::ToDataType(tensor->type());
...@@ -63,18 +81,27 @@ class LoadOp : public framework::OperatorBase {
&fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin, const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
}
};
class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The LoDTensor / SelectedRows need to be loaded");
AddAttr<bool>(
"load_as_fp16",
"If true, the tensor will be first loaded and then "
...@@ -85,7 +112,9 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
R"(Variable will be loaded from "file_path")")
.AddCustomChecker(
[](const std::string &path) { return !path.empty(); });
AddComment(
"Load operator will load a LoDTensor / SelectedRows variable from disk "
"file.");
}
};
} // namespace operators
......
...@@ -18,10 +18,7 @@
#include "paddle/fluid/framework/tensor.h"
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_USE_OPENBLAS
...@@ -55,7 +52,7 @@ static void SetNumThreads(int num_threads) {
openblas_set_num_threads(real_num_threads);
#elif defined(PADDLE_WITH_MKLML)
int real_num_threads = num_threads > 1 ? num_threads : 1;
platform::dynload::MKL_Set_Num_Threads(real_num_threads);
#else
PADDLE_ENFORCE(false, "To be implemented.");
#endif
......
...@@ -22,61 +22,109 @@ namespace math {
template <typename T>
struct CBlas;
#ifdef PADDLE_WITH_MKLML
template <>
struct CBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
platform::dynload::cblas_sgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
platform::dynload::cblas_saxpy(args...);
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
platform::dynload::cblas_scopy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
platform::dynload::cblas_sgemv(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_sgemm_batch(args...);
}
template <typename... ARGS>
static void VADD(ARGS... args) {
platform::dynload::vsAdd(args...);
}
};
template <>
struct CBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
platform::dynload::cblas_dgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
platform::dynload::cblas_daxpy(args...);
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
platform::dynload::cblas_dcopy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
platform::dynload::cblas_dgemv(args...);
}
template <typename... ARGS>
static void GEMM_BATCH(ARGS... args) {
platform::dynload::cblas_dgemm_batch(args...);
}
template <typename... ARGS>
static void VADD(ARGS... args) {
platform::dynload::vdAdd(args...);
}
};
#else
template <>
struct CBlas<float> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_sgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_saxpy(args...);
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
cblas_scopy(args...);
}
template <typename... ARGS>
static void GEMV(ARGS... args) {
cblas_sgemv(args...);
}
};
template <>
struct CBlas<double> {
template <typename... ARGS>
static void GEMM(ARGS... args) {
cblas_dgemm(args...);
}
template <typename... ARGS>
static void AXPY(ARGS... args) {
cblas_daxpy(args...);
}
template <typename... ARGS>
static void VCOPY(ARGS... args) {
...@@ -87,15 +135,8 @@ struct CBlas<double> {
static void GEMV(ARGS... args) {
cblas_dgemv(args...);
}
};
#endif
template <>
struct CBlas<platform::float16> {
static void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); }
......
...@@ -14,9 +14,7 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_MKLML
#include "paddle/fluid/platform/dynload/mklml.h"
#endif
#ifdef PADDLE_USE_OPENBLAS
......
...@@ -139,6 +139,7 @@ TEST(LoadFP16Op, CPU) {
save_op->Run(scope, place);
auto load_var = scope.Var("out_var");
load_var->GetMutable<paddle::framework::LoDTensor>();
auto load_op = paddle::framework::OpRegistry::CreateOp(
"load", {}, {{"Out", {"out_var"}}}, attrs);
load_op->Run(scope, place);
......
...@@ -22,11 +22,17 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
// define LOOKUP_TABLE_PATH for checkpoint notify to save lookup table variables
// to the specified directory.
constexpr char LOOKUP_TABLE_PATH[] = "kLookupTablePath";
// TODO(yuyang18): If the functions below are needed by other files, move them
// to paddle::filesystem namespace.
constexpr char kSEP = '/';
...@@ -67,9 +73,27 @@ class SaveOp : public framework::OperatorBase {
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override {
auto iname = Input("X");
auto *var = scope.FindVar(iname);
PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s for save_op",
iname);
if (var->IsType<framework::LoDTensor>()) {
SaveLodTensor(place, var);
} else if (var->IsType<framework::SelectedRows>()) {
SaveSelectedRows(scope, place, var);
} else {
PADDLE_ENFORCE(
false,
"SaveOp only support LoDTensor and SelectedRows, %s has wrong type",
iname);
}
}
void SaveLodTensor(const platform::Place &place,
framework::Variable *var) const {
auto filename = Attr<std::string>("file_path");
auto overwrite = Attr<bool>("overwrite");
if (FileExists(filename) && !overwrite) {
PADDLE_THROW("%s is existed, cannot save to it when overwrite=false",
...@@ -78,26 +102,19 @@ class SaveOp : public framework::OperatorBase {
MkDirRecursively(DirName(filename).c_str());
auto &tensor = var->Get<framework::LoDTensor>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
auto save_as_fp16 = Attr<bool>("save_as_fp16");
auto in_dtype = framework::ToDataType(tensor.type());
auto out_dtype = save_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
...@@ -112,17 +129,43 @@ class SaveOp : public framework::OperatorBase {
} else {
framework::SerializeToStream(fout, tensor, dev_ctx);
}
fout.close();
}
void SaveSelectedRows(const framework::Scope &scope,
const platform::Place &place,
framework::Variable *var) const {
auto *path_var = scope.FindVar(LOOKUP_TABLE_PATH);
PADDLE_ENFORCE(path_var != nullptr,
"Can not find variable kLookupTablePath for SaveSelectedRows");
auto *lt_var = path_var->GetMutable<std::string>();
std::string filename = lt_var->data();
VLOG(4) << "SaveSelectedRows get File name: " << filename;
auto &selectedRows = var->Get<framework::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ofstream fout(filename);
PADDLE_ENFORCE(static_cast<bool>(fout), "Cannot open %s to write",
filename);
framework::SerializeToStream(fout, selectedRows, dev_ctx);
fout.close();
} }
}; };
class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker { class SaveOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override { void Make() override {
AddInput("X", "(Tensor ) Input tensor to be saved"); AddInput("X", "(Tensor ) Input LoDTensor and SelectedRows to be saved");
AddComment(R"DOC( AddComment(R"DOC(
Save operator Save operator
This operator will serialize and write a tensor variable to file on disk. This operator will serialize and write LoDTensor / SelectedRows variable to file on disk.
)DOC"); )DOC");
AddAttr<bool>("overwrite", AddAttr<bool>("overwrite",
"(boolean, default true)" "(boolean, default true)"
...@@ -142,9 +185,26 @@ This operator will serialize and write a tensor variable to file on disk. ...@@ -142,9 +185,26 @@ This operator will serialize and write a tensor variable to file on disk.
} }
}; };
class SaveOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output(LOOKUP_TABLE_PATH).front();
auto &out_var = block->FindRecursiveOrCreateVar(out_var_name);
auto var_type = framework::proto::VarType::RAW;
out_var.SetType(var_type);
}
};
class SaveOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(save, ops::SaveOp, paddle::framework::EmptyGradOpMaker,
ops::SaveOpProtoMaker, ops::SaveOpVarTypeInference,
ops::SaveOpShapeInference);
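
The dispatch above routes a LoDTensor input to SaveLodTensor and a SelectedRows input (the distributed lookup table) to SaveSelectedRows. As a rough Python-side sketch of exercising the op directly — the variable name, shape, and file path below are made up for illustration, and real models would normally go through fluid.io.save_persistables or save_checkpoint instead:

```
# Minimal sketch (hypothetical names): append the extended `save` op to a block.
# A LoDTensor input goes through SaveLodTensor and honors file_path/save_as_fp16;
# a SelectedRows input would instead be written to the path stored in the scope
# variable "kLookupTablePath" filled in by checkpoint_notify.
import paddle.fluid as fluid

prog = fluid.Program()
block = prog.global_block()
w = block.create_var(
    name="fc_w", shape=[128, 64], dtype="float32", persistable=True)
block.append_op(
    type='save',
    inputs={'X': [w]},
    outputs={},
    attrs={'file_path': "./ckpt/fc_w", 'overwrite': True,
           'save_as_fp16': True})
# After fc_w is initialized (e.g. by a startup program), running prog with an
# Executor would write ./ckpt/fc_w in FP16.
```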
...@@ -17,3 +17,7 @@ if (CUPTI_FOUND) ...@@ -17,3 +17,7 @@ if (CUPTI_FOUND)
endif(CUPTI_FOUND) endif(CUPTI_FOUND)
nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader) nv_library(dynload_cuda SRCS ${CUDA_SRCS} DEPS dynamic_loader)
cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc) cc_library(dynload_warpctc SRCS warpctc.cc DEPS dynamic_loader warpctc)
if (WITH_MKLML)
cc_library(dynload_mklml SRCS mklml.cc DEPS dynamic_loader mklml)
endif()
# TODO(TJ): add iomp, mkldnn?
...@@ -49,6 +49,8 @@ DEFINE_string( ...@@ -49,6 +49,8 @@ DEFINE_string(
tensorrt_dir, "", tensorrt_dir, "",
"Specify path for loading tensorrt library, such as libnvinfer.so."); "Specify path for loading tensorrt library, such as libnvinfer.so.");
DEFINE_string(mklml_dir, "", "Specify path for loading libmklml_intel.so.");
namespace paddle { namespace paddle {
namespace platform { namespace platform {
namespace dynload { namespace dynload {
...@@ -76,6 +78,7 @@ static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path, ...@@ -76,6 +78,7 @@ static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path,
VLOG(3) << "Try to find library: " << dso_path VLOG(3) << "Try to find library: " << dso_path
<< " from default system path."; << " from default system path.";
// default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH // default search from LD_LIBRARY_PATH/DYLD_LIBRARY_PATH
// and /usr/local/lib path
void* dso_handle = dlopen(dso_path.c_str(), dynload_flags); void* dso_handle = dlopen(dso_path.c_str(), dynload_flags);
// DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to // DYLD_LIBRARY_PATH is disabled after Mac OS 10.11 to
...@@ -97,6 +100,10 @@ static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path, ...@@ -97,6 +100,10 @@ static inline void* GetDsoHandleFromDefaultPath(const std::string& dso_path,
} }
#endif #endif
if (nullptr == dso_handle) {
LOG(WARNING) << "Can not find library: " << dso_path
<< ". Please try to add the lib path to LD_LIBRARY_PATH.";
}
return dso_handle; return dso_handle;
} }
...@@ -206,6 +213,14 @@ void* GetTensorRtDsoHandle() { ...@@ -206,6 +213,14 @@ void* GetTensorRtDsoHandle() {
#endif #endif
} }
void* GetMKLMLDsoHandle() {
#if defined(__APPLE__) || defined(__OSX__)
return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.dylib");
#else
return GetDsoHandleFromSearchPath(FLAGS_mklml_dir, "libmklml_intel.so");
#endif
}
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -26,6 +26,7 @@ void* GetWarpCTCDsoHandle(); ...@@ -26,6 +26,7 @@ void* GetWarpCTCDsoHandle();
void* GetLapackDsoHandle(); void* GetLapackDsoHandle();
void* GetNCCLDsoHandle(); void* GetNCCLDsoHandle();
void* GetTensorRtDsoHandle(); void* GetTensorRtDsoHandle();
void* GetMKLMLDsoHandle();
} // namespace dynload } // namespace dynload
} // namespace platform } // namespace platform
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/dynload/mklml.h"
namespace paddle {
namespace platform {
namespace dynload {
std::once_flag mklml_dso_flag;
void* mklml_dso_handle = nullptr;
#define DEFINE_WRAP(__name) DynLoad__##__name __name
MKLML_ROUTINE_EACH(DEFINE_WRAP);
} // namespace dynload
} // namespace platform
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <dlfcn.h>
#include <mkl.h>
#include <mutex> // NOLINT
#include "paddle/fluid/platform/dynload/dynamic_loader.h"
namespace paddle {
namespace platform {
namespace dynload {
extern std::once_flag mklml_dso_flag;
extern void* mklml_dso_handle;
/**
* The following macro definitions generate a struct
* (for each function) that dynamically loads the mklml routine
* via operator overloading.
*/
#define DYNAMIC_LOAD_MKLML_WRAP(__name) \
struct DynLoad__##__name { \
template <typename... Args> \
auto operator()(Args... args) -> decltype(__name(args...)) { \
using mklmlFunc = decltype(&::__name); \
std::call_once(mklml_dso_flag, []() { \
mklml_dso_handle = paddle::platform::dynload::GetMKLMLDsoHandle(); \
}); \
static void* p_##__name = dlsym(mklml_dso_handle, #__name); \
return reinterpret_cast<mklmlFunc>(p_##__name)(args...); \
} \
}; \
extern DynLoad__##__name __name
#define DECLARE_DYNAMIC_LOAD_MKLML_WRAP(__name) DYNAMIC_LOAD_MKLML_WRAP(__name)
#define MKLML_ROUTINE_EACH(__macro) \
__macro(cblas_sgemm); \
__macro(cblas_saxpy); \
__macro(cblas_scopy); \
__macro(cblas_sgemv); \
__macro(cblas_sgemm_batch); \
__macro(cblas_dgemm); \
__macro(cblas_daxpy); \
__macro(cblas_dcopy); \
__macro(cblas_dgemv); \
__macro(cblas_dgemm_batch); \
__macro(vsAdd); \
__macro(vdAdd); \
__macro(MKL_Set_Num_Threads)
MKLML_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_MKLML_WRAP);
#undef DYNAMIC_LOAD_MKLML_WRAP
} // namespace dynload
} // namespace platform
} // namespace paddle
...@@ -78,6 +78,8 @@ def as_numpy(tensor): ...@@ -78,6 +78,8 @@ def as_numpy(tensor):
Returns: Returns:
numpy.ndarray numpy.ndarray
""" """
if isinstance(tensor, core.LoDTensorArray):
return [as_numpy(t) for t in tensor]
if isinstance(tensor, list): if isinstance(tensor, list):
return [as_numpy(t) for t in tensor] return [as_numpy(t) for t in tensor]
assert isinstance(tensor, core.LoDTensor) assert isinstance(tensor, core.LoDTensor)
......
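
With this addition `as_numpy` also recurses into a `LoDTensorArray`, not just Python lists, so a fetched array of per-step tensors can be converted in one call. A minimal sketch, assuming a hypothetical fetch target named "step_array" built elsewhere (e.g. by `array_write` in a loop):

```
# Sketch: convert a fetched LoDTensorArray to a list of numpy arrays.
# "step_array" and main_prog are placeholders for a real program's artifacts.
import paddle.fluid as fluid
from paddle.fluid.executor import as_numpy

exe = fluid.Executor(fluid.CPUPlace())
# arr = exe.run(main_prog, fetch_list=["step_array"], return_numpy=False)[0]
# ndarrays = as_numpy(arr)   # -> [numpy.ndarray, numpy.ndarray, ...]
```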
...@@ -454,7 +454,7 @@ class Operator(object): ...@@ -454,7 +454,7 @@ class Operator(object):
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'recv',
'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine', 'listen_and_serv', 'parallel_do', 'save_combine', 'load_combine',
'ncclInit', 'channel_create', 'channel_close', 'channel_send', 'ncclInit', 'channel_create', 'channel_close', 'channel_send',
'channel_recv', 'select', 'checkpoint_notify', 'gen_nccl_id'
} }
def __init__(self, def __init__(self,
...@@ -1214,6 +1214,9 @@ class Block(object): ...@@ -1214,6 +1214,9 @@ class Block(object):
if var.type == core.VarDesc.VarType.STEP_SCOPES: if var.type == core.VarDesc.VarType.STEP_SCOPES:
ret_var = self.create_var( ret_var = self.create_var(
name=var.name, persistable=var.persistable, type=var.type) name=var.name, persistable=var.persistable, type=var.type)
elif var.type == core.VarDesc.VarType.RAW:
ret_var = self.create_var(
name=var.name, persistable=var.persistable, type=var.type)
elif var.type == core.VarDesc.VarType.SELECTED_ROWS: elif var.type == core.VarDesc.VarType.SELECTED_ROWS:
ret_var = self.create_var( ret_var = self.create_var(
name=var.name, name=var.name,
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import os import os
import errno
import time import time
import shutil import shutil
...@@ -25,7 +26,8 @@ __all__ = [ ...@@ -25,7 +26,8 @@ __all__ = [
'load_persistables', 'save_inference_model', 'load_inference_model', 'load_persistables', 'save_inference_model', 'load_inference_model',
'get_inference_program', 'save_checkpoint', 'load_checkpoint', 'get_inference_program', 'save_checkpoint', 'load_checkpoint',
'clean_checkpoint', 'load_persist_vars_without_grad', 'clean_checkpoint', 'load_persist_vars_without_grad',
'load_lookup_table_vars', 'save_persist_vars_without_grad',
'get_latest_checkpoint_serial'
] ]
...@@ -795,6 +797,7 @@ def get_parameter_value_by_name(name, executor, program=None): ...@@ -795,6 +797,7 @@ def get_parameter_value_by_name(name, executor, program=None):
SUCCESS_MARK_FILENAME = "_SUCCESS" SUCCESS_MARK_FILENAME = "_SUCCESS"
CHECKPOINT_PREFIX = "checkpoint" CHECKPOINT_PREFIX = "checkpoint"
MODEL_DIR = "__model__" MODEL_DIR = "__model__"
LOOKUP_TABLE_DIR = "__lookup_table__"
TRAINER_PREFIX = "trainer" TRAINER_PREFIX = "trainer"
CHECKPOINT_SEPARATOR = "_" CHECKPOINT_SEPARATOR = "_"
...@@ -804,7 +807,9 @@ def save_checkpoint(executor, ...@@ -804,7 +807,9 @@ def save_checkpoint(executor,
trainer_id, trainer_id,
trainer_args=None, trainer_args=None,
main_program=None, main_program=None,
max_num_checkpoints=3,
lookup_table=None,
ps_endpoint_list=None):
""" """
This function filters out all checkpoint variables from the given
main_program and then saves these variables to the `checkpoint_dir`
...@@ -836,6 +841,12 @@ def save_checkpoint(executor, ...@@ -836,6 +841,12 @@ def save_checkpoint(executor,
max_num_checkpoints(int): The max number of total number of existing max_num_checkpoints(int): The max number of total number of existing
checkpoints. checkpoints.
Default: 3 Default: 3
lookup_table(string|None): the lookup table name. When the distributed
lookup table is used, the name can be obtained from
DistributeTranspiler.table_name.
ps_endpoint_list(list|None): the parameter server ip:port list. When the
distributed lookup table is used, it can be obtained from the
transpiler (pserver endpoints) arguments.
Returns: Returns:
None None
...@@ -852,30 +863,40 @@ def save_checkpoint(executor, ...@@ -852,30 +863,40 @@ def save_checkpoint(executor,
prog = fluid.default_main_program() prog = fluid.default_main_program()
trainer_args = {"epoch_id": 200, trainer_args = {"epoch_id": 200,
"step_id": 20} # just an example "step_id": 20} # just an example
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_checkpoint(executor=exe, fluid.io.save_checkpoint(executor=exe,
checkpoint_dir=path, checkpoint_dir=path,
trainer_id=0, trainer_id=0,
trainer_args=trainer_args, trainer_args=trainer_args,
main_program=prog, main_program=prog,
max_num_checkpoints=3,
lookup_table=table_name,
ps_endpoint_list = ps_endpoints)
""" """
if checkpoint_dir is None: if checkpoint_dir is None:
raise ValueError("'checkpoint_dir' should not be None") raise ValueError("'checkpoint_dir' should not be None")
assert checkpoint_dir
if trainer_args: if trainer_args:
assert isinstance(trainer_args, dict) assert isinstance(trainer_args, dict)
is_chief = trainer_id == 0
_make_chekcpoint_dirs(checkpoint_dir)
serial = get_latest_checkpoint_serial(checkpoint_dir) + 1
cur_dir = _get_serial_dir(checkpoint_dir, serial)
save_trainer_args(cur_dir, trainer_id, trainer_args)
if is_chief:
save_persist_vars_without_grad(executor, cur_dir, main_program)
if is_chief and lookup_table and ps_endpoint_list:
save_pserver_vars_by_notify(executor, cur_dir, lookup_table,
ps_endpoint_list)
_scroll_delete(checkpoint_dir, max_num_checkpoints) _scroll_delete(checkpoint_dir, max_num_checkpoints)
...@@ -942,7 +963,8 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program): ...@@ -942,7 +963,8 @@ def load_checkpoint(executor, checkpoint_dir, serial, main_program):
def clean_checkpoint(checkpoint_dir, delete_dir=False): def clean_checkpoint(checkpoint_dir, delete_dir=False):
""" """
clean the checkpoint dir; when the training exits normally,
the trainer will call clean_checkpoint to delete the checkpoint directory saved before.
delete_dir only works when the directory is empty, otherwise, OSError is raised. delete_dir only works when the directory is empty, otherwise, OSError is raised.
: param checkpoint_dir : param checkpoint_dir
...@@ -1009,6 +1031,56 @@ def load_persist_vars_without_grad(executor, ...@@ -1009,6 +1031,56 @@ def load_persist_vars_without_grad(executor,
filename=None) filename=None)
def load_lookup_table_vars(executor, dirname, program, pserver_id, table_name):
"""
The parameter server loads the lookup table's local file into a
SelectedRows variable.
Args:
executor(Executor): The executor to run for loading persistable variables
dirname(str): The directory path
program(Program): Find the variable named table_name in this program
pserver_id(int): the index of this parameter server in the pserver_endpoints list
table_name(str): lookup table name
Returns:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
dirname = "./checkpoints/checkpoint_9/__model__"
prog = fluid.default_main_program()
pserver_id = 1
table_name = "share_w"
fluid.io.load_lookup_table_vars(executor=exe,
dirname=dirname, program=prog, pserver_id=pserver_id,
table_name=table_name)
"""
lookup_table_var = None
for var in program.list_vars():
if var.name == table_name:
lookup_table_var = var
break
assert lookup_table_var is not None
lookup_table_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
table_file = table_name + CHECKPOINT_SEPARATOR + str(pserver_id)
load_prog = Program()
load_block = load_prog.global_block()
load_block.append_op(
type='load',
inputs={},
outputs={'Out': [lookup_table_var]},
attrs={'file_path': os.path.join(lookup_table_dir, table_file)})
executor.run(load_prog)
def save_persist_vars_without_grad(executor, dirname, program): def save_persist_vars_without_grad(executor, dirname, program):
""" """
This function filters out all checkpoint variables from the given
...@@ -1055,6 +1127,54 @@ def save_persist_vars_without_grad(executor, dirname, program): ...@@ -1055,6 +1127,54 @@ def save_persist_vars_without_grad(executor, dirname, program):
_write_success(cur_dir) _write_success(cur_dir)
def save_pserver_vars_by_notify(executor, dirname, lookup_table,
ps_endpoint_list):
"""
This function sends a checkpoint notify message from Trainer 0
to all the pservers.
The checkpoint notify message contains the lookup table name and
the absolute path on the pserver where the lookup table should be saved.
Args:
executor(Executor): The executor to run for sending the checkpoint notify.
dirname(str): The folder where to save checkpoints.
lookup_table(string): the lookup table name. When the distributed
lookup table is used, the name can be obtained from
DistributeTranspiler.table_name.
ps_endpoint_list(list): the parameter server ip:port list. When the
distributed lookup table is used, it can be obtained from the
transpiler (pserver endpoints) arguments.
Return:
None
Examples:
.. code-block:: python
exe = fluid.Executor(fluid.CPUPlace())
param_path = "./my_paddle_model"
prog = fluid.default_main_program()
table_name = "share_w"
ps_endpoints = ["127.0.0.1:6000","127.0.0.1:6001"]
fluid.io.save_pserver_vars_by_notify(executor=exe,
dirname=param_path, lookup_table=table_name,
ps_endpoint_list=ps_endpoints)
"""
cur_dir = _get_lookuptable_dir(dirname)
checkpoint_notify_program = Program()
checkpoint_notify_block = checkpoint_notify_program.global_block()
attrs = {}
attrs['epmap'] = ps_endpoint_list
attrs['dir'] = cur_dir
attrs['lookup_table'] = lookup_table
checkpoint_notify_block.append_op(
type='checkpoint_notify', inputs={}, outputs={}, attrs=attrs)
executor.run(checkpoint_notify_program)
def save_trainer_args(dirname, trainer_id, trainer_args): def save_trainer_args(dirname, trainer_id, trainer_args):
assert isinstance(trainer_args, dict) assert isinstance(trainer_args, dict)
...@@ -1068,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args): ...@@ -1068,6 +1188,29 @@ def save_trainer_args(dirname, trainer_id, trainer_args):
def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args): def load_trainer_args(checkpoint_dir, serial, trainer_id, trainer_args):
"""
The trainer loads some args from its own directory,
such as epoch_id and step_id.
Args:
checkpoint_dir(str): The folder where all checkpoints are.
serial(int): The serial of the checkpoint you would like to load.
trainer_id(int): current trainer id.
trainer_args(list): the names of the trainer args to load.
Return:
None
Examples:
.. code-block:: python
param_path = "./checkpoint/"
serial = 7
trainer_id = 2
trainer_args = ["epoch_id", "step_id"]
fluid.io.load_trainer_args(checkpoint_dir=param_path, serial=serial,
trainer_id=trainer_id, trainer_args=trainer_args)
"""
assert isinstance(trainer_args, list) assert isinstance(trainer_args, list)
cur_dir = _get_serial_dir(checkpoint_dir, serial) cur_dir = _get_serial_dir(checkpoint_dir, serial)
...@@ -1088,7 +1231,7 @@ def _is_checkpoint_var(var): ...@@ -1088,7 +1231,7 @@ def _is_checkpoint_var(var):
the checkpoint will not save or load all the variables. the checkpoint will not save or load all the variables.
var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded. var type is FEED_MINIBATCH/FETCH_LIST/RAW or var name ends with @GRAD are discarded.
: param var(Variable)
""" """
if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \ if var.desc.type() == core.VarDesc.VarType.FEED_MINIBATCH or \
var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \ var.desc.type() == core.VarDesc.VarType.FETCH_LIST or \
...@@ -1108,6 +1251,23 @@ def _is_checkpoint_var(var): ...@@ -1108,6 +1251,23 @@ def _is_checkpoint_var(var):
return var.persistable return var.persistable
def _make_chekcpoint_dirs(dirs):
"""
_make_chekcpoint_dirs will makdir local directory directly, when the directory is exist, it will igore it.
"""
assert dirs is not None
if os.path.isfile(dirs):
raise OSError(errno.ENOTDIR, "dirs path shoule be a Directory.", dirs)
if not os.path.isdir(dirs):
try:
os.makedirs(dirs)
except OSError as err:
if err.errno != errno.EEXIST:
raise err
def _get_dir_serial(dirname): def _get_dir_serial(dirname):
_, serial = dirname.split(CHECKPOINT_SEPARATOR) _, serial = dirname.split(CHECKPOINT_SEPARATOR)
...@@ -1121,29 +1281,27 @@ def _get_dir_serial(dirname): ...@@ -1121,29 +1281,27 @@ def _get_dir_serial(dirname):
def _get_serial_dir(dirname, serial):
serial_folder = CHECKPOINT_PREFIX + CHECKPOINT_SEPARATOR + str(serial)
serial_dir = os.path.join(dirname, serial_folder)
_make_chekcpoint_dirs(serial_dir)
return serial_dir
def _get_model_dir(dirname):
model_dir = os.path.join(dirname, MODEL_DIR)
_make_chekcpoint_dirs(model_dir)
return model_dir
def _get_lookuptable_dir(dirname):
lookuptable_dir = os.path.join(dirname, LOOKUP_TABLE_DIR)
_make_chekcpoint_dirs(lookuptable_dir)
return lookuptable_dir
def _get_trainer_dir(dirname, trainer_id):
trainer_folder = TRAINER_PREFIX + CHECKPOINT_SEPARATOR + str(trainer_id)
trainer_dir = os.path.join(dirname, trainer_folder)
_make_chekcpoint_dirs(trainer_dir)
return trainer_dir
...@@ -1162,7 +1320,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3): ...@@ -1162,7 +1320,11 @@ def _scroll_delete(dirname, max_num_checkpoints=3):
serials = serials[max_num_checkpoints:] serials = serials[max_num_checkpoints:]
for serial in serials:
cur_dir = _get_serial_dir(dirname, serial)
try:
shutil.rmtree(cur_dir)
except OSError as err:
if err.errno != errno.ENOENT:
raise err
def _write_success(dirname): def _write_success(dirname):
......
...@@ -160,7 +160,7 @@ class ParallelExecutor(object): ...@@ -160,7 +160,7 @@ class ParallelExecutor(object):
build_strategy, num_trainers, trainer_id) build_strategy, num_trainers, trainer_id)
self.scope = scope self.scope = scope
def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=False):
""" """
Run a parallel executor with fetch_list. Run a parallel executor with fetch_list.
...@@ -196,6 +196,8 @@ class ParallelExecutor(object): ...@@ -196,6 +196,8 @@ class ParallelExecutor(object):
to each device. Default None. to each device. Default None.
feed_dict: Alias for feed parameter, for backward compatibility. feed_dict: Alias for feed parameter, for backward compatibility.
This parameter has been deprecated. Default None. This parameter has been deprecated. Default None.
return_numpy(bool): Whether to convert the fetched tensors to numpy.
Default: False.
Returns: Returns:
List: The fetched result list. List: The fetched result list.
...@@ -270,6 +272,9 @@ class ParallelExecutor(object): ...@@ -270,6 +272,9 @@ class ParallelExecutor(object):
if self.is_dist: if self.is_dist:
self.bcast_params() self.bcast_params()
if return_numpy:
return executor.as_numpy(arr)
return [arr[i] for i in range(len(arr))] return [arr[i] for i in range(len(arr))]
def bcast_params(self): def bcast_params(self):
......
...@@ -75,7 +75,9 @@ class TestFetchOp(unittest.TestCase): ...@@ -75,7 +75,9 @@ class TestFetchOp(unittest.TestCase):
fetch_list.append(k) fetch_list.append(k)
for data in train_inputs: for data in train_inputs:
ret = pe.run(fetch_list,
feed=feeder.feed(data),
return_numpy=True)
for i in range(len(fetch_list)): for i in range(len(fetch_list)):
assert not math.isnan(np.sum(ret[i])) and \ assert not math.isnan(np.sum(ret[i])) and \
not math.isinf(np.sum(ret[i])) not math.isinf(np.sum(ret[i]))
......
...@@ -119,27 +119,20 @@ class CheckpointConfig(object): ...@@ -119,27 +119,20 @@ class CheckpointConfig(object):
max_num_checkpoints=3, max_num_checkpoints=3,
epoch_interval=1, epoch_interval=1,
step_interval=10): step_interval=10):
assert epoch_interval >= 1
assert step_interval >= 1
self.checkpoint_dir = checkpoint_dir \
if checkpoint_dir is not None else os.getcwd()
self.max_num_checkpoints = max_num_checkpoints
self.epoch_interval = epoch_interval
self.step_interval = step_interval
self.epoch_id = 0
self.step_id = 0
self.load_serial = None
self.pserver_id = None
self.lookup_table_name = None
def check_and_get_place(place): def check_and_get_place(place):
...@@ -290,13 +283,20 @@ class Trainer(object): ...@@ -290,13 +283,20 @@ class Trainer(object):
self.checkpoint_cfg.load_serial, self.checkpoint_cfg.load_serial,
self.startup_program) self.startup_program)
if not self.checkpoint_cfg.pserver_id:
epoch_id, step_id = io.load_trainer_args( epoch_id, step_id = io.load_trainer_args(
self.checkpoint_cfg.checkpoint_dir, self.checkpoint_cfg.checkpoint_dir,
self.checkpoint_cfg.load_serial, self.trainer_id, self.checkpoint_cfg.load_serial, self.trainer_id,
self._get_checkpoint_load_args()) self._get_checkpoint_load_args())
self.checkpoint_cfg.epoch_id = int(epoch_id) self.checkpoint_cfg.epoch_id = int(epoch_id)
self.checkpoint_cfg.step_id = int(step_id) self.checkpoint_cfg.step_id = int(step_id)
else:
if self.checkpoint_cfg.lookup_table_name:
io.load_lookup_table_vars(
exe, self.checkpoint_cfg.checkpoint_dir,
self.startup_program,
self.checkpoint_cfg.pserver_id,
self.checkpoint_cfg.lookup_table_name)
if param_path and os.path.isdir(param_path): if param_path and os.path.isdir(param_path):
# load params from param_path into scope # load params from param_path into scope
...@@ -366,7 +366,10 @@ class Trainer(object): ...@@ -366,7 +366,10 @@ class Trainer(object):
self.trainer_id, pservers=pserver_endpoints, trainers=trainers) self.trainer_id, pservers=pserver_endpoints, trainers=trainers)
if training_role == "PSERVER": if training_role == "PSERVER":
if self.checkpoint_cfg: if self.checkpoint_cfg:
pserver_id = eplist.index(current_endpoint)
self.checkpoint_cfg.pserver_id = pserver_id
if t.has_distributed_lookup_table:
self.checkpoint_cfg.lookup_table_name = t.table_name
self.train_program = t.get_pserver_program(current_endpoint) self.train_program = t.get_pserver_program(current_endpoint)
self.startup_program = t.get_startup_program(current_endpoint, self.startup_program = t.get_startup_program(current_endpoint,
...@@ -566,7 +569,8 @@ class Trainer(object): ...@@ -566,7 +569,8 @@ class Trainer(object):
def _save_checkpoint(self, epoch_id, step_id): def _save_checkpoint(self, epoch_id, step_id):
assert self.checkpoint_cfg assert self.checkpoint_cfg
if epoch_id % self.checkpoint_cfg.epoch_interval == 0 \
and step_id % self.checkpoint_cfg.step_interval == 0:
exe = executor.Executor(self.place) exe = executor.Executor(self.place)
io.save_checkpoint( io.save_checkpoint(
executor=exe, executor=exe,
......
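
For orientation, a hedged sketch of how the new CheckpointConfig fields get exercised from user code; only the constructor arguments below appear in this diff, the `checkpoint_config` keyword of `fluid.Trainer` and the train/optimizer functions are assumptions, and `pserver_id` / `lookup_table_name` are filled in by the trainer itself as shown above:

```
# Minimal sketch: configure periodic checkpointing for the trainer flow.
# The Trainer wiring is an assumption and therefore left commented out.
from paddle.fluid.trainer import CheckpointConfig

config = CheckpointConfig(
    checkpoint_dir="./checkpoints",   # falls back to os.getcwd() when None
    max_num_checkpoints=3,
    epoch_interval=1,
    step_interval=10)

# trainer = fluid.Trainer(train_func=train_fn,            # assumed user code
#                         optimizer_func=optimizer_fn,    # assumed user code
#                         place=fluid.CPUPlace(),
#                         checkpoint_config=config)       # keyword assumed
# trainer.train(num_epochs=1, event_handler=handler, reader=train_reader)
```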
...@@ -471,6 +471,8 @@ class DistributeTranspiler(object): ...@@ -471,6 +471,8 @@ class DistributeTranspiler(object):
pserver_index, pserver_program, pre_block_idx, grad_to_block_id) pserver_index, pserver_program, pre_block_idx, grad_to_block_id)
prefetch_var_name_to_block_id = self._create_prefetch_block( prefetch_var_name_to_block_id = self._create_prefetch_block(
pserver_index, pserver_program, table_opt_block) pserver_index, pserver_program, table_opt_block)
checkpoint_block_id = self._create_checkpoint_save_block(
pserver_program, table_opt_block.idx)
# NOTE: if has_distributed_lookup_table is False, then prefetch_block will # NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place # not be executed, so it's safe to use optimize_block to hold the place
...@@ -489,6 +491,7 @@ class DistributeTranspiler(object): ...@@ -489,6 +491,7 @@ class DistributeTranspiler(object):
if len(prefetch_var_name_to_block_id) > 0: if len(prefetch_var_name_to_block_id) > 0:
attrs['prefetch_var_name_to_block_id'] \ attrs['prefetch_var_name_to_block_id'] \
= prefetch_var_name_to_block_id = prefetch_var_name_to_block_id
attrs['checkpint_block_id'] = checkpoint_block_id
# step5 append the listen_and_serv op # step5 append the listen_and_serv op
pserver_program.global_block().append_op( pserver_program.global_block().append_op(
...@@ -910,6 +913,27 @@ class DistributeTranspiler(object): ...@@ -910,6 +913,27 @@ class DistributeTranspiler(object):
return table_opt_block return table_opt_block
def _create_checkpoint_save_block(self, pserver_program, pre_block_idx):
"""
create a new block to handle save checkpoint.
"""
import os
pserver_program.global_block().create_var(
name="kLookupTablePath",
persistable=True,
type=core.VarDesc.VarType.RAW)
checkpoint_save_block = pserver_program.create_block(pre_block_idx)
# this 'file_path' attr is not used when saving the lookup table variable;
# the real path is taken from the kLookupTablePath variable (see SaveSelectedRows)
checkpoint_save_block.append_op(
type='save',
inputs={'X': [self.table_name]},
outputs={},
attrs={'file_path': "none"})
return checkpoint_save_block.idx
def _create_vars_from_blocklist(self, def _create_vars_from_blocklist(self,
program, program,
block_list, block_list,
...@@ -1299,16 +1323,6 @@ class DistributeTranspiler(object): ...@@ -1299,16 +1323,6 @@ class DistributeTranspiler(object):
ufind.union(op1, op2) ufind.union(op1, op2)
return ufind return ufind
def _is_opt_role_op(self, op):
# NOTE: depend on oprole to find out whether this op is for
# optimize
op_maker = core.op_proto_and_checker_maker
optimize_role = core.op_proto_and_checker_maker.OpRole.Optimize
if op_maker.kOpRoleAttrName() in op.attrs and \
int(op.attrs[op_maker.kOpRoleAttrName()]) == int(optimize_role):
return True
return False
def _is_optimizer_op(self, op): def _is_optimizer_op(self, op):
if "Param" in op.input_names and \ if "Param" in op.input_names and \
"LearningRate" in op.input_names: "LearningRate" in op.input_names:
...@@ -1399,7 +1413,10 @@ class DistributeTranspiler(object): ...@@ -1399,7 +1413,10 @@ class DistributeTranspiler(object):
params_grads = [] params_grads = []
origin_var_dict = self.origin_program.global_block().vars origin_var_dict = self.origin_program.global_block().vars
for op in block.ops: for op in block.ops:
# NOTE(Yancey1989): we can not use the op role to distinguish an optimizer op
# from other ops, because every op in the optimizer sub-graph is tagged
# with the optimizer op role
if self._is_optimizer_op(op):
opt_ops.append(op) opt_ops.append(op)
# HACK(wuyi): if we find grad vars from input of optimize # HACK(wuyi): if we find grad vars from input of optimize
# ops, we may get the output of clip op. Use syntax "@GRAD" # ops, we may get the output of clip op. Use syntax "@GRAD"
......