Commit ad6c0142 authored by Yancey1989

clean up code

Parent 268e9dc1
......@@ -3,7 +3,6 @@ cc_library(op_handle_base SRCS op_handle_base.cc DEPS var_handle device_context
cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(send_op_handle SRCS send_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
......@@ -27,7 +26,7 @@ endif()
cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor)
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle send_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)
cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
......
......@@ -19,7 +19,6 @@
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/send_op_handle.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/scope.h"
......@@ -141,7 +140,6 @@ bool MultiDevSSAGraphBuilder::IsDistTrainOp(
return checker(op.OutputArgumentNames(), send_vars) ||
checker(op.InputArgumentNames(), recv_vars);
return false;
}
bool MultiDevSSAGraphBuilder::IsRPCOp(const OpDesc &op) const {
......@@ -471,17 +469,16 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
ConnectOp(result, result->ops_.back().get(), "send_barrier");
} else if (op.Type() == "fetch_barrier") {
ConnectOp(result, result->ops_.back().get(), "recv");
} else if (op.Type() == "send" || op.Type() == "send_vars") {
} else if (op.Type() == "send_vars") {
// do nothing
} else {
PADDLE_THROW(
"rpc op should be in [send,"
"rpc op should be in ["
"send_vars, send_barrier. recv, fetch_barrier]");
}
// FIXME(wuyi): send op always copy from GPU 0
// Create inputs for output on original place and no ssa output
// is created for send op.
// TODO(Yancey1989): scheduling rpc ops on different places may
// increase throughput
CreateOpHandleIOs(result, op, 0);
}
......
......@@ -31,6 +31,7 @@ void RPCOpHandle::RunImpl() {
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
// FIXME(Yancey1989): need a better solution instead of using DebugString()
if (in->DebugString() == "dummy") { // HACK
continue;
}
......
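The FIXME above flags the `DebugString() == "dummy"` comparison as a hack. A minimal standalone sketch of one alternative, under the assumption that dummy dependency edges use a distinct subclass (as `DummyVarHandle` does in this codebase): detect them by dynamic type rather than by string output. The mock types below are illustrative stand-ins, not the real handle classes.

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Stand-ins for VarHandleBase / DummyVarHandle / VarHandle.
struct VarHandleBase { virtual ~VarHandleBase() = default; };
struct DummyVarHandle : VarHandleBase {};  // ordering-only edge, no data
struct VarHandle : VarHandleBase { int place_ = 0; };

int main() {
  std::vector<std::unique_ptr<VarHandleBase>> inputs;
  inputs.emplace_back(new VarHandle{});
  inputs.emplace_back(new DummyVarHandle{});

  for (auto &in : inputs) {
    // The type check replaces the DebugString() string comparison.
    if (dynamic_cast<DummyVarHandle *>(in.get()) != nullptr) {
      std::cout << "skip dummy edge\n";
      continue;
    }
    auto *vh = static_cast<VarHandle *>(in.get());
    std::cout << "wait on input at place " << vh->place_ << "\n";
  }
  return 0;
}
```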
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/send_op_handle.h"
namespace paddle {
namespace framework {
namespace details {
SendOpHandle::SendOpHandle(const framework::OpDesc &op_desc,
const Scope *local_scope,
const platform::Place &place)
: op_(framework::OpRegistry::CreateOp(op_desc)),
local_scope_(local_scope),
place_(place) {}
void SendOpHandle::RunImpl() {
// TODO(wuyi): need further analysis on whether to wait for VarDummyHandle.
// Wait input done
for (auto *in : inputs_) {
auto &p = static_cast<VarHandle *>(in)->place_;
if (in->DebugString() == "dummy") { // HACK
continue;
}
if (in->generated_op_) {
in->generated_op_->RecordWaitEventOnCtx(dev_ctxes_[p]);
}
}
auto &tmp_scope = local_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>();
// FIXME(wuyi): cannot use RunAndRecordEvent here, as it would cause a
// deadlock.
op_->Run(*tmp_scope, place_);
}
std::string SendOpHandle::Name() const { return "send"; }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
namespace details {
struct SendOpHandle : public OpHandleBase {
SendOpHandle(const framework::OpDesc& op_desc, const Scope* local_scope,
const platform::Place& place);
std::string Name() const override;
// Delaying and buffering nccl_all_reduce together can significantly
// increase performance. Disable this feature by returning false.
bool IsMultiDeviceTransfer() override { return false; }
protected:
void RunImpl() override;
private:
std::unique_ptr<OperatorBase> op_;
const Scope* local_scope_;
const platform::Place& place_;
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -39,6 +39,7 @@ class Variable {
template <typename T>
T* GetMutable() {
// TODO(Yancey1989): need to make Variable completely thread-safe.
std::unique_lock<std::mutex> lock(mutex_);
if (!IsType<T>()) {
holder_.reset(new PlaceholderImpl<T>(new T()));
......
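For context on the line added above, a minimal sketch of the pattern it establishes: a type-erased holder whose check-then-reset sequence is serialized by a mutex, since without the lock two threads could both observe a type mismatch and race on `holder_.reset()`. The names mirror the diff but are illustrative stand-ins, not Paddle's actual `Variable`.

```cpp
#include <iostream>
#include <memory>
#include <mutex>
#include <typeinfo>

class Variable {
 public:
  template <typename T>
  T* GetMutable() {
    // Guard the whole check-then-reset sequence, not just the reset.
    std::unique_lock<std::mutex> lock(mutex_);
    if (holder_ == nullptr || holder_->Type() != typeid(T)) {
      holder_.reset(new PlaceholderImpl<T>(new T()));
    }
    return static_cast<PlaceholderImpl<T>*>(holder_.get())->ptr.get();
  }

 private:
  struct Placeholder {
    virtual ~Placeholder() = default;
    virtual const std::type_info& Type() const = 0;
  };
  template <typename T>
  struct PlaceholderImpl : Placeholder {
    explicit PlaceholderImpl(T* p) : ptr(p) {}
    const std::type_info& Type() const override { return typeid(T); }
    std::unique_ptr<T> ptr;
  };

  std::mutex mutex_;
  std::unique_ptr<Placeholder> holder_;
};

int main() {
  Variable v;
  *v.GetMutable<int>() = 42;
  std::cout << *v.GetMutable<int>() << "\n";  // prints 42
  return 0;
}
```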
......@@ -13,8 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#pragma once
namespace paddle {
namespace inference {
namespace analysis {
......
......@@ -249,6 +249,7 @@ bool RPCClient::Proceed() {
return true;
}
std::shared_ptr<grpc::Channel> RPCClient::GetChannel(const std::string& ep) {
// TODO(Yancey1989): make grpc client completely thread-safe
std::unique_lock<std::mutex> lock(mutex_);
auto it = channels_.find(ep);
if (it != channels_.end()) {
......
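The lookup above is a get-or-create cache guarded by a single mutex, so concurrent callers for the same endpoint cannot create duplicate channels. A standalone sketch of the pattern, with a stand-in `Channel` type in place of `grpc::Channel`:

```cpp
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <string>

struct Channel {  // stand-in for grpc::Channel
  explicit Channel(std::string ep) : endpoint(std::move(ep)) {}
  std::string endpoint;
};

class ChannelCache {
 public:
  std::shared_ptr<Channel> Get(const std::string& ep) {
    // The lock covers both the lookup and the insert.
    std::unique_lock<std::mutex> lock(mutex_);
    auto it = channels_.find(ep);
    if (it != channels_.end()) return it->second;
    auto ch = std::make_shared<Channel>(ep);
    channels_.emplace(ep, ch);
    return ch;
  }

 private:
  std::mutex mutex_;
  std::map<std::string, std::shared_ptr<Channel>> channels_;
};

int main() {
  ChannelCache cache;
  auto a = cache.Get("127.0.0.1:6174");
  auto b = cache.Get("127.0.0.1:6174");
  std::cout << (a == b) << "\n";  // 1: the channel is reused
  return 0;
}
```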
......@@ -38,7 +38,7 @@ class RecvOp : public framework::OperatorBase {
auto outs = Outputs("Out");
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
auto client_var_name = Output("RPCClient");
int sync_recv = Attr<int>("sync_recv");
int sync_mode = Attr<int>("sync_mode");
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
......@@ -55,7 +55,7 @@ class RecvOp : public framework::OperatorBase {
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
rpc_client->AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}
if (sync_recv) {
if (sync_mode) {
PADDLE_ENFORCE(rpc_client->Wait());
}
}
......@@ -78,7 +78,7 @@ This operator can get variables from server side.
"Server endpoints in the order of input "
"variables for mapping")
.SetDefault({});
AddAttr<int>("sync_recv",
AddAttr<int>("sync_mode",
"(int, default 0)"
"sync recv or async recv.")
.SetDefault(0);
......
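The renamed `sync_mode` attribute decides whether the op blocks until every pending `AsyncGetVariable` call completes. A rough standalone sketch of that switch, with plain threads standing in for the RPC client (in the real graph, a later `fetch_barrier` op performs the wait in async mode):

```cpp
#include <iostream>
#include <string>
#include <thread>
#include <vector>

int main() {
  const bool sync_mode = true;  // stands in for the op's sync_mode attribute
  std::vector<std::thread> workers;
  for (const std::string name : {"w1", "w2"}) {
    // Each worker stands in for one AsyncGetVariable(ep, ctx, scope, name).
    workers.emplace_back([name] { std::cout << "fetched " << name << "\n"; });
  }
  if (sync_mode) {
    for (auto& t : workers) t.join();  // mirrors rpc_client->Wait()
  } else {
    for (auto& t : workers) t.detach();  // return immediately; wait elsewhere
  }
  return 0;
}
```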
......@@ -360,19 +360,6 @@ class DistributeTranspiler:
ps_dispatcher.reset()
eplist = ps_dispatcher.dispatch(recv_vars)
#program.global_block().append_op(
# type="recv",
# inputs={},
# outputs={"Out": recv_vars,
# "RPCClient": rpc_client_var},
# attrs={"epmap": eplist})
#program.global_block().append_op(
# type="fetch_barrier",
# inputs={},
# outputs={"RPCClient": rpc_client_var},
# attrs={"endpoints": pserver_endpoints})
for i, ep in enumerate(eplist):
self.param_grad_ep_mapping[ep]["params"].append(recv_vars[i])
self.param_grad_ep_mapping[ep]["grads"].append(send_vars[i])
......
......@@ -41,7 +41,7 @@ class PSDispatcher(object):
class HashName(PSDispatcher):
"""
Hash variable names to servral endpoints
Hash variable names to several endpoints
"""
def __init__(self, pserver_endpoints):
......
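`HashName` deterministically maps each variable name onto one of several parameter-server endpoints, so a given variable always lands on the same server. A minimal sketch of the policy (in C++ for consistency with the other examples; the real dispatcher is the Python class above):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hash the variable name and pick an endpoint by modulo.
std::string Dispatch(const std::vector<std::string>& eps,
                     const std::string& var_name) {
  return eps[std::hash<std::string>{}(var_name) % eps.size()];
}

int main() {
  const std::vector<std::string> eps = {"127.0.0.1:6174", "127.0.0.1:6175"};
  std::cout << Dispatch(eps, "fc_0.w_0") << "\n";
  std::cout << Dispatch(eps, "fc_0.b_0") << "\n";  // same name -> same server
  return 0;
}
```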