提交 3a5bce77 编写于 作者: Q qiaolongfei

try to complete

上级 c3580eae
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/detail/grpc_server.h"
#include <paddle/fluid/operators/detail/send_recv.pb.h>
using ::grpc::ServerAsyncResponseWriter;
......@@ -156,6 +157,8 @@ class RequestPrefetch final : public RequestBase {
::grpc::ByteBuffer relay;
// TODO(Yancey1989): execute the Block which containers prefetch ops
VLOG(3) << "RequestPrefetch Process in";
responder_.Finish(relay, ::grpc::Status::OK, this);
status_ = FINISH;
}
......@@ -251,6 +254,7 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
void AsyncGRPCServer::TryToRegisterNewPrefetchOne() {
VLOG(4) << "TryToRegisterNewPrefetchOne in";
std::unique_lock<std::mutex> lock(cq_mutex_);
if (is_shut_down_) {
return;
......@@ -287,8 +291,8 @@ void AsyncGRPCServer::HandleRequest(::grpc::ServerCompletionQueue* cq,
// https://groups.google.com/forum/#!topic/grpc-io/xftlRy-IQwM
// https://groups.google.com/forum/#!topic/grpc-io/ywATt88Ef_I
if (!ok) {
LOG(WARNING) << cq_name << " recv no regular event:argument name"
<< base->GetReqName();
LOG(WARNING) << cq_name << " recv no regular event:argument name["
<< base->GetReqName() << "]";
TryToRegisterNewOne();
delete base;
continue;
......
......@@ -28,6 +28,7 @@ std::unique_ptr<detail::AsyncGRPCServer> rpc_service_;
void StartServer(const std::string& endpoint) {
rpc_service_.reset(new detail::AsyncGRPCServer(endpoint));
rpc_service_->RunSyncUpdate();
}
TEST(PREFETCH, CPU) {
......@@ -39,13 +40,23 @@ TEST(PREFETCH, CPU) {
platform::CPUPlace place;
platform::CPUDeviceContext ctx(place);
// create var on local scope
std::string var_name("tmp_0");
auto var = scope.Var(var_name);
auto tensor = var->GetMutable<framework::LoDTensor>();
tensor->Resize({10, 10});
std::string in_var_name("in");
std::string out_var_name("out");
auto* in_var = scope.Var(in_var_name);
auto* in_tensor = in_var->GetMutable<framework::LoDTensor>();
in_tensor->Resize({10, 10});
VLOG(3) << "before mutable_data";
in_tensor->mutable_data<int>(place);
scope.Var(out_var_name);
VLOG(3) << "before fetch";
detail::RPCClient client;
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, var_name, "");
client.AsyncPrefetchVariable("127.0.0.1:8889", ctx, scope, in_var_name,
out_var_name);
client.Wait();
rpc_service_->ShutDown();
server_thread.join();
rpc_service_.reset(nullptr);
}
......@@ -80,7 +80,7 @@ enum class GrpcMethod {
};
static const int kGrpcNumMethods =
static_cast<int>(GrpcMethod::kGetVariable) + 1;
static_cast<int>(GrpcMethod::kPrefetchVariable) + 1;
inline const char* GrpcMethodName(GrpcMethod id) {
switch (id) {
......
......@@ -112,6 +112,10 @@ class ListenAndServOp : public framework::OperatorBase {
framework::Executor executor(dev_place);
rpc_service_->SetExecutor(&executor);
rpc_service_->SetPrefetchBlkdId(0);
rpc_service_->SetProgram(program);
// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
// Record received sparse variables, so that
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册