未验证 提交 da556ed6 编写于 作者: C chengduo 提交者: GitHub

Enhance ParallelExecutor stability (#11637)

上级 073af623
...@@ -103,13 +103,6 @@ void BroadcastOpHandle::RunImpl() { ...@@ -103,13 +103,6 @@ void BroadcastOpHandle::RunImpl() {
}); });
} }
// FIXME(zcd): a temporary fix for some language models that have sparse
// parameters.
bool use_mutex = true;
if (in_var->IsType<paddle::framework::SelectedRows>()) {
use_mutex = false;
}
if (use_mutex) {
this->RunAndRecordEvent([&] { this->RunAndRecordEvent([&] {
{ {
platform::NCCLGroupGuard guard; platform::NCCLGroupGuard guard;
...@@ -127,26 +120,6 @@ void BroadcastOpHandle::RunImpl() { ...@@ -127,26 +120,6 @@ void BroadcastOpHandle::RunImpl() {
&VariableVisitor::GetMutableTensor(out_var)); &VariableVisitor::GetMutableTensor(out_var));
} }
}); });
} else {
this->RunAndRecordEventNoMutex([&] {
{
platform::NCCLGroupGuard guard;
for (auto &call : broadcast_calls) {
call();
}
}
if (!out_handle->IsTheSameVar(*in_var_handle)) {
auto out_var = var_scopes.at(in_var_handle->scope_idx_)
->FindVar(out_var_handles[0]->name_);
paddle::framework::TensorCopy(
in_tensor, in_var_handle->place_,
*(dev_ctxes_.at(in_var_handle->place_)),
&VariableVisitor::GetMutableTensor(out_var));
}
});
}
#else #else
PADDLE_THROW("CUDA is not enabled."); PADDLE_THROW("CUDA is not enabled.");
#endif #endif
......
...@@ -11,8 +11,8 @@ ...@@ -11,8 +11,8 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/details/op_handle_base.h" #include "paddle/fluid/framework/details/op_handle_base.h"
#include <map>
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -122,34 +122,16 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) { ...@@ -122,34 +122,16 @@ void OpHandleBase::RunAndRecordEvent(const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event if (!events_.empty()) { // Use event
std::function<void()> method = callback; std::function<void()> method = callback;
// NOTE(zcd): device contexts must be ordered here because RecordEvent
// uses a mutex to ensure thread safety.
std::map<platform::DeviceContext *, platform::Place> ordered_ctxes;
for (auto &p : dev_ctxes_) { for (auto &p : dev_ctxes_) {
method = [method, p, this]() { ordered_ctxes.emplace(p.second, p.first);
static_cast<platform::CUDADeviceContext *>(p.second)->RecordEvent(
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method);
};
}
method();
} else {
#endif
callback();
#ifdef PADDLE_WITH_CUDA
} }
#endif for (auto &p : ordered_ctxes) {
}
void OpHandleBase::RunAndRecordEventNoMutex(
const std::function<void()> &callback) {
#ifdef PADDLE_WITH_CUDA
if (!events_.empty()) { // Use event
std::function<void()> method = callback;
for (auto &p : dev_ctxes_) {
method = [method, p, this]() { method = [method, p, this]() {
static_cast<platform::CUDADeviceContext *>(p.second) static_cast<platform::CUDADeviceContext *>(p.first)->RecordEvent(
->RecordEventNoMutex( events_.at(boost::get<platform::CUDAPlace>(p.second).device),
events_.at(boost::get<platform::CUDAPlace>(p.first).device),
method); method);
}; };
} }
......
...@@ -85,10 +85,6 @@ class OpHandleBase { ...@@ -85,10 +85,6 @@ class OpHandleBase {
protected: protected:
void RunAndRecordEvent(const std::function<void()> &callback); void RunAndRecordEvent(const std::function<void()> &callback);
// FIXME(zcd): A temporary fix for some language models that have sparse
// parameters.
void RunAndRecordEventNoMutex(const std::function<void()> &callback);
void RunAndRecordEvent(platform::Place p, void RunAndRecordEvent(platform::Place p,
const std::function<void()> &callback); const std::function<void()> &callback);
......
...@@ -80,9 +80,7 @@ void ReduceOpHandle::RunImpl() { ...@@ -80,9 +80,7 @@ void ReduceOpHandle::RunImpl() {
} }
if (pre_in_var->IsType<framework::SelectedRows>()) { if (pre_in_var->IsType<framework::SelectedRows>()) {
// FIXME(zcd): A temporary fix for some language model that has sparse this->RunAndRecordEvent([&] {
// parameter.
this->RunAndRecordEventNoMutex([&] {
std::vector<const SelectedRows *> in_selected_rows = std::vector<const SelectedRows *> in_selected_rows =
GetInputValues<SelectedRows>(in_var_handles, var_scopes); GetInputValues<SelectedRows>(in_var_handles, var_scopes);
GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p, GatherSelectedRows(in_selected_rows, in_places, dev_ctxes_, t_out_p,
......
...@@ -106,14 +106,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -106,14 +106,6 @@ class CUDADeviceContext : public DeviceContext {
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
// FIXME(zcd): A temporary fix for some language models that have sparse
// parameters.
// Invokes `callback`, then records `ev` on this context's stream_ via
// cudaEventRecord (the call is wrapped in PADDLE_ENFORCE).
// No lock is acquired anywhere in this body.
template <typename Callback>
void RecordEventNoMutex(cudaEvent_t ev, Callback callback) {
// Run the caller's work first so the event is recorded after it.
callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
}
private: private:
CUDAPlace place_; CUDAPlace place_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册