diff --git a/paddle/fluid/operators/checkpoint_notify_op.cc b/paddle/fluid/operators/checkpoint_notify_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..1b922e08907dc9ff4205af2b11e5277d3c73e86f --- /dev/null +++ b/paddle/fluid/operators/checkpoint_notify_op.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include // NOLINT +#include + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/lod_tensor.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/macros.h" +#include "paddle/fluid/operators/send_recv_util.h" + +namespace paddle { +namespace operators { + +class CheckpointNotifyOp : public framework::OperatorBase { + public: + CheckpointNotifyOp(const std::string& type, + const framework::VariableNameMap& inputs, + const framework::VariableNameMap& outputs, + const framework::AttributeMap& attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + void RunImpl(const framework::Scope& scope, + const platform::Place& place) const override { + std::vector epmap = Attr>("epmap"); + std::string dir = Attr("dir"); + + detail::RPCClient* rpc_client = + detail::RPCClient::GetInstance(); + VLOG(3) << "sending " << ins[i] << " to " << epmap[i] << " to get " + << outs[i] << " back"; + rpc_client->AsyncCheckpointNotify(epmap[i], dir); + rpc_client->Wait(); + } +}; + +class CheckpointNotifyOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() { + AddAttr>( + "epmap", + "(string vector, default 127.0.0.1:6164)" + "Server endpoints in the order of input variables for mapping") + .SetDefault({"127.0.0.1:6164"}); + AddAttr( + "dir", "(string, default '') indicate the folder checkpoint will use"); + AddComment(R"DOC( +Prefetch operator + +This operator will send Ids variables to listen_and_serve op at +the parameter server and fetch result back. +)DOC"); + } +}; + +class CheckpointNotifyOpShapeInference : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext* ctx) const override {} +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OPERATOR(checkpointnotify, ops::CheckpointNotifyOp, + paddle::framework::EmptyGradOpMaker, + ops::CheckpointNotifyOpMaker, + ops::CheckpointNotifyOpShapeInference); diff --git a/paddle/fluid/operators/listen_and_serv_op.cc b/paddle/fluid/operators/listen_and_serv_op.cc index 0804a266d0f1e15fe8ccccc37f3981805d3926e2..088366dac7b2842dc06daa078e739fa42061a7d7 100644 --- a/paddle/fluid/operators/listen_and_serv_op.cc +++ b/paddle/fluid/operators/listen_and_serv_op.cc @@ -221,6 +221,7 @@ static void FillRequestCtx( std::unordered_map> *prefetch_ctx, + std::shared_ptr checkpoint_ctx, detail::RPCServer *rpc_server) { h->SetScope(scope); h->SetDevCtx(dev_ctx); @@ -228,6 +229,7 @@ static void FillRequestCtx( h->SetProgram(program); h->SetPrefetchPreparedCtx(prefetch_ctx); h->SetRPCServer(rpc_server); + h->SetCheckpointNotifyPreparedCtx(checkpoint_ctx); } void ListenAndServOp::RunImpl(const framework::Scope &scope, @@ -297,9 +299,14 @@ void ListenAndServOp::RunImpl(const framework::Scope &scope, prefetch_var_name_to_prepared_ctx[prefetch_var_name] = prefetch_prepared[i]; } - auto f = std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, - &dev_ctx, &executor, program, - &prefetch_var_name_to_prepared_ctx, rpc_service_.get()); + int checkpoint_point_block_id = Attr(kCheckpointBlockId); + std::shared_ptr ckpt_pre_context = + executor.Prepare(*program, checkpoint_point_block_id); + + auto f = + std::bind(FillRequestCtx, std::placeholders::_1, &recv_scope, &dev_ctx, + &executor, program, &prefetch_var_name_to_prepared_ctx, + &ckpt_pre_context, rpc_service_.get()); f(request_send_handler_.get()); f(request_get_handler_.get());