From e3c041d319b0effb32b240814c507a656dd6fb32 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Tue, 29 May 2018 19:57:57 +0800 Subject: [PATCH] add auto_grown_mutex for selected rows --- paddle/fluid/framework/selected_rows.cc | 7 ++++--- paddle/fluid/framework/selected_rows.h | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 56cf6693caf..c9d2388aa43 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -18,8 +18,8 @@ namespace paddle { namespace framework { struct ReAllocateVisitor { - ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims) - : tensor_(tensor), dims_(dims) {} + ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor) + : dims_(dims), tensor_(tensor) {} template void operator()() const { @@ -153,6 +153,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { } PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), "The first dim of value should be 1."); + std::lock_guard lock(auto_grown_mutex_); auto index = Index(key); bool is_new_key = false; if (index == -1) { @@ -164,7 +165,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { auto dims = value_->dims(); dims[0] = (dims[0] + 1) << 1; framework::VisitDataType(framework::ToDataType(value.type()), - ReAllocateVisitor(value_.get(), dims)); + ReAllocateVisitor(dims, value_.get())); } } diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index c27c927ee75..487c7390875 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -125,6 +125,7 @@ class SelectedRows { Vector rows_; std::unique_ptr value_{nullptr}; int64_t height_; + std::mutex auto_grown_mutex_; }; /* -- GitLab