未验证 提交 9ffd5eec 编写于 作者: W Wu Yi 提交者: GitHub

test fix fetch bar place for ce (#16406)

* test fix fetch bar place for ce

* fix ps mode dist train in develop test=develop

* fix style check test=develop

* update test=develop
上级 4f2278f0
......@@ -5,6 +5,7 @@ cc_library(scale_loss_grad_op_handle SRCS scale_loss_grad_op_handle.cc DEPS op_h
cc_library(fetch_op_handle SRCS fetch_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory)
cc_library(computation_op_handle SRCS computation_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(fetch_barrier_op_handle SRCS fetch_barrier_op_handle.cc DEPS framework_proto scope place operator op_registry)
cc_library(multi_devices_helper SRCS multi_devices_helper.cc DEPS graph graph_helper)
cc_library(multi_devices_graph_print_pass SRCS multi_devices_graph_print_pass.cc DEPS multi_devices_helper)
......@@ -72,7 +73,7 @@ cc_library(sequential_execution_pass SRCS sequential_execution_pass.cc DEPS grap
cc_library(all_reduce_deps_pass SRCS all_reduce_deps_pass.cc DEPS graph graph_helper pass)
cc_library(multi_devices_graph_pass SRCS multi_devices_graph_pass.cc DEPS multi_devices_helper computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
scale_loss_grad_op_handle rpc_op_handle fetch_barrier_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle fused_broadcast_op_handle)
cc_library(fuse_all_reduce_op_pass SRCS fuse_all_reduce_op_pass.cc DEPS graph graph_helper fused_all_reduce_op_handle)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
#include <string>
namespace paddle {
namespace framework {
namespace details {
FetchBarrierOpHandle::FetchBarrierOpHandle(
ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places)
// fetch_barrier op always run on place0, but output on all places.
: OpHandleBase(node),
op_(framework::OpRegistry::CreateOp(*node->Op())),
local_scopes_(local_scopes),
places_(places),
run_scope_(local_scopes[0]),
place_(places[0]) {
for (auto &p : places) {
this->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p));
}
}
bool FetchBarrierOpHandle::IsMultiDeviceTransfer() {
// override IsMultiDeviceTransfer to return true
return true;
}
void FetchBarrierOpHandle::RunImpl() {
WaitInputVarGenerated(place_);
auto run_func = [this]() {
op_->Run(*run_scope_->FindVar(kLocalExecScopeName)->Get<Scope *>(), place_);
};
if (is_lock_and_record_event_free_) {
run_func();
} else {
this->RunAndRecordEvent(run_func);
}
}
bool FetchBarrierOpHandle::NeedWait(VarHandleBase *in_var) {
bool need_wait =
in_var && in_var->GeneratedOp() &&
in_var->GeneratedOp()->DeviceContext(place_) != dev_ctxes_.at(place_);
return need_wait;
}
std::string FetchBarrierOpHandle::Name() const { return op_->Type(); }
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace details {
// **NOTE**: fetch_barrier op is special it outputs all recved variables on
// all places if there are multiple places, must init with
// multiple dev_ctxes_ !!!!
struct FetchBarrierOpHandle : public OpHandleBase {
public:
FetchBarrierOpHandle(ir::Node *node, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places);
bool IsMultiDeviceTransfer() override;
std::string Name() const override;
protected:
void RunImpl() override;
bool NeedWait(VarHandleBase *in_var) override;
private:
std::unique_ptr<OperatorBase> op_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
Scope *run_scope_;
platform::Place place_;
bool is_lock_and_record_event_free_{false};
};
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -23,6 +23,7 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/fetch_barrier_op_handle.h"
#include "paddle/fluid/framework/details/fused_broadcast_op_handle.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
......@@ -851,9 +852,17 @@ int DistSSAGraphBuilder::CreateRPCOp(ir::Graph *result, ir::Node *node) const {
PADDLE_ENFORCE(op_dev_id != -1, "can not find the right place for rpc op: %s",
node->Op()->Type());
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
// Create fetch_barrier op handle to enable output on all devices.
// **NOTE** fetch_barrier should output variables list same as recv op does.
if (node->Op()->Type() == "fetch_barrier") {
result->Get<GraphOps>(kGraphOps).emplace_back(new FetchBarrierOpHandle(
result->CreateOpNode(node->Op()), local_scopes_, places_));
} else {
result->Get<GraphOps>(kGraphOps).emplace_back(new RPCOpHandle(
result->CreateOpNode(node->Op()), *node->Op(), local_scopes_[op_dev_id],
node->Op()->Type(), places_[op_dev_id]));
}
if (node->Op()->Type() == "send") {
CreateOpHandleIOs(result, node, op_dev_id);
......
......@@ -55,7 +55,7 @@ void OpHandleBase::Run(bool use_cuda) {
if (out_var_handle) {
int dev_id =
boost::get<platform::CUDAPlace>(out_var_handle->place()).device;
out_var_handle->SetGenerateEvent(events_[dev_id]);
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
}
} else {
......@@ -71,7 +71,7 @@ void OpHandleBase::Run(bool use_cuda) {
"The place of input(%s) is not consistent with the "
"place of current op(%s).",
out_var_handle->Name(), Name());
out_var_handle->SetGenerateEvent(events_[dev_id]);
out_var_handle->SetGenerateEvent(events_.at(dev_id));
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册