diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index 56cf6693caf4529d6e157e6e9a0d5c27d05ee0c3..c9d2388aa4375d797d9eb964934b1ad379caa539 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -18,8 +18,8 @@ namespace paddle { namespace framework { struct ReAllocateVisitor { - ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims) - : tensor_(tensor), dims_(dims) {} + ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor) + : dims_(dims), tensor_(tensor) {} template void operator()() const { @@ -153,6 +153,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { } PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), "The first dim of value should be 1."); + std::lock_guard lock(auto_grown_mutex_); auto index = Index(key); bool is_new_key = false; if (index == -1) { @@ -164,7 +165,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { auto dims = value_->dims(); dims[0] = (dims[0] + 1) << 1; framework::VisitDataType(framework::ToDataType(value.type()), - ReAllocateVisitor(value_.get(), dims)); + ReAllocateVisitor(dims, value_.get())); } } diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index c27c927ee751c4392840bfb71f4814991b23a8c9..487c73908750065e5f62b53fe9fe17156d347cb3 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -125,6 +125,7 @@ class SelectedRows { Vector rows_; std::unique_ptr value_{nullptr}; int64_t height_; + std::mutex auto_grown_mutex_; }; /*