Unverified commit 7d26dd81, authored by chengduo, committed by GitHub

Enhance ParallelExecutor stability (#11634)

* Fix ParallelExecutor (VarHandle's version)

* Fix broadcast

* Enhance ParallelExecutor stability
Parent b5c5a3aa
@@ -73,6 +73,9 @@ void BroadcastOpHandle::RunImpl() {
     int root_id = boost::get<platform::CUDAPlace>(in_tensor.place()).device;
     std::vector<std::function<void()>> broadcast_calls;
 
+    int type = platform::ToNCCLDataType(in_tensor.type());
+    size_t numel = static_cast<size_t>(in_tensor.numel());
+
     for (auto out_var_handle : out_var_handles) {
       Variable *out_var = var_scopes.at(out_var_handle->scope_idx_)
                               ->FindVar(out_var_handle->name_);
@@ -87,13 +90,11 @@ void BroadcastOpHandle::RunImpl() {
         send_recv_buffer = const_cast<void *>(in_tensor.data<void>());
         out_handle = out_var_handle;
       } else {
-        send_recv_buffer =
-            VariableVisitor::GetMutableTensor(out_var).mutable_data(
-                out_var_handle->place_);
+        send_recv_buffer = VariableVisitor::GetMutableTensor(out_var)
+                               .Resize(in_tensor.dims())
+                               .mutable_data(out_var_handle->place_);
       }
 
-      int type = platform::ToNCCLDataType(in_tensor.type());
-      size_t numel = static_cast<size_t>(in_tensor.numel());
-
       broadcast_calls.emplace_back(
           [send_recv_buffer, numel, type, root_id, &nccl_ctx] {
             PADDLE_ENFORCE(platform::dynload::ncclBcast(
...
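Two things change in this file. First, `type` and `numel` are hoisted out of the per-device loop: both are derived from the single input tensor, so every queued `ncclBcast` call now provably uses the same element type and count. Second, on non-root devices the destination tensor is `Resize`d to the input's dims before `mutable_data` is called, so the allocation matches the number of bytes NCCL will write into it. The standalone toy below is not Paddle's real Tensor class; `ToyTensor` and its members are invented for illustration, assuming (as in Paddle) that `mutable_data` sizes its buffer from the tensor's current dims.

```cpp
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

// Toy stand-in for a framework tensor: mutable_data() allocates
// exactly numel() elements based on the *current* dims.
struct ToyTensor {
  std::vector<int64_t> dims;
  std::vector<float> buf;

  size_t numel() const {
    size_t n = 1;
    for (auto d : dims) n *= static_cast<size_t>(d);
    return dims.empty() ? 0 : n;
  }
  ToyTensor &Resize(std::vector<int64_t> d) {
    dims = std::move(d);
    return *this;
  }
  float *mutable_data() {
    buf.resize(numel());  // allocation size depends on current dims
    return buf.data();
  }
};

int main() {
  ToyTensor in;
  in.Resize({4, 8});
  in.mutable_data();  // root tensor: 32 elements

  // Buggy order: allocating first yields a 0-element buffer, yet the
  // broadcast would still write in.numel() elements into it.
  ToyTensor out_bad;
  out_bad.mutable_data();
  std::cout << "without Resize: " << out_bad.buf.size() << " elements\n";

  // Fixed order (as in the hunk above): Resize to the input's dims
  // *before* mutable_data so the destination matches the write size.
  ToyTensor out_ok;
  out_ok.Resize(in.dims).mutable_data();
  assert(out_ok.buf.size() == in.numel());
  std::cout << "with Resize: " << out_ok.buf.size() << " elements\n";
  return 0;
}
```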
@@ -351,7 +351,7 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
     auto &prev_grad = vars.back();
     op_handle->AddInput(prev_grad.get());
 
-    auto var = new VarHandle(vars.size() - 1, i, og, p);
+    auto var = new VarHandle(vars.size(), i, og, p);
     vars.emplace_back(var);
     op_handle->AddOutput(var);
   }
@@ -447,8 +447,7 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(SSAGraph *result,
     op_handle->AddInput(prev_grad.get());
   }
   auto &vars = result->vars_[dst_dev_id][og];
-  auto var =
-      new VarHandle(vars.size() - 1, dst_dev_id, og, places_[dst_dev_id]);
+  auto var = new VarHandle(vars.size(), dst_dev_id, og, places_[dst_dev_id]);
   vars.emplace_back(var);
   op_handle->AddOutput(var);
   return var;
...
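Both hunks fix the version assigned to a freshly created `VarHandle`. The handle is appended to `vars` immediately after construction, so its index in the vector is `vars.size()` as evaluated before the `emplace_back`; the old `vars.size() - 1` duplicated the version of the handle already at the back. A minimal sketch of the invariant, with a hypothetical `ToyVarHandle` standing in for Paddle's `VarHandle`:

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Minimal stand-in for the builder's versioned variable handles;
// the real VarHandle also carries a device index, name, and place.
struct ToyVarHandle {
  size_t version;
  std::string name;
};

int main() {
  std::vector<std::unique_ptr<ToyVarHandle>> vars;
  vars.emplace_back(new ToyVarHandle{0, "grad"});  // existing handle

  // Old code: new VarHandle(vars.size() - 1, ...) would reuse version 0,
  // colliding with the handle already at vars.back().
  // Fixed code: the new handle's version is its index after appending,
  // i.e. vars.size() evaluated *before* the emplace_back.
  auto *var = new ToyVarHandle{vars.size(), "grad"};
  vars.emplace_back(var);

  for (size_t i = 0; i < vars.size(); ++i) {
    assert(vars[i]->version == i);  // version now matches position
  }
  return 0;
}
```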
@@ -11,8 +11,8 @@
 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 // See the License for the specific language governing permissions and
 // limitations under the License.
-
 #include "paddle/fluid/framework/details/op_handle_base.h"
+#include <map>
 
 namespace paddle {
 namespace framework {
@@ -122,11 +122,16 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
 #ifdef PADDLE_WITH_CUDA
   if (!events_.empty()) {  // Use event
     std::function<void()> method = callback;
+    // NOTE(zcd): the device contexts must be ordered here because RecordEvent
+    // uses a mutex to ensure thread safety.
+    std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
     for (auto &p : dev_ctxes_) {
+      ordered_ctxes.emplace(p.second, p.first);
+    }
+    for (auto &p : ordered_ctxes) {
       method = [method, p, this]() {
-        static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
-            events_.at(boost::get<platform::CUDAPlace>(p.first).device),
+        static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
+            events_.at(boost::get<platform::CUDAPlace>(p.second).device),
             method);
       };
     }
...
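`dev_ctxes_` is an unordered container, so chaining the `RecordEvent` wrappers by iterating it directly nests them in an unspecified order that can differ between runs and between op handles; routing the contexts through a `std::map` first pins one deterministic order, which is what the NOTE about the mutex in `RecordEvent` is guarding. A self-contained sketch of the same wrapping pattern, with strings standing in for `DeviceContext *` and `Place`:

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <unordered_map>

int main() {
  // Stand-in for dev_ctxes_: Place -> DeviceContext, unordered.
  std::unordered_map<std::string, std::string> dev_ctxes = {
      {"place0", "ctx0"}, {"place1", "ctx1"}, {"place2", "ctx2"}};

  std::function<void()> method = [] { std::cout << "callback\n"; };

  // Invert into a std::map keyed on the context, as the hunk does,
  // so the wrapping loop below always runs in the same order.
  std::map<std::string, std::string> ordered_ctxes;
  for (auto &p : dev_ctxes) {
    ordered_ctxes.emplace(p.second, p.first);
  }
  for (auto &p : ordered_ctxes) {
    // Each iteration wraps the previous chain; the last context
    // wrapped becomes the outermost call.
    method = [method, p] {
      std::cout << "record event on " << p.first << "\n";
      method();
    };
  }
  method();  // prints ctx2, ctx1, ctx0, then "callback"
  return 0;
}
```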