diff --git a/paddle/fluid/framework/selected_rows.cc b/paddle/fluid/framework/selected_rows.cc index b4168f38949c7fcb057ec8c5c562d0529a6d9e48..06ed87e7e8a2d5324b48a466b05207042ec1b7fa 100644 --- a/paddle/fluid/framework/selected_rows.cc +++ b/paddle/fluid/framework/selected_rows.cc @@ -18,8 +18,8 @@ namespace paddle { namespace framework { struct ReAllocateVisitor { - ReAllocateVisitor(framework::Tensor* tensor, const framework::DDim& dims) - : tensor_(tensor), dims_(dims) {} + ReAllocateVisitor(const framework::DDim& dims, framework::Tensor* tensor) + : dims_(dims), tensor_(tensor) {} template void operator()() const { @@ -34,8 +34,8 @@ struct ReAllocateVisitor { tensor_->ShareDataWith(cpu_tensor); } - framework::Tensor* tensor_; framework::DDim dims_; + framework::Tensor* tensor_; }; struct TensorCopyVisitor { @@ -158,6 +158,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { } PADDLE_ENFORCE_EQ(value.dims()[0], static_cast(1), "The first dim of value should be 1."); + std::lock_guard lock(*auto_grown_mutex_.get()); auto index = Index(key); bool is_new_key = false; if (index == -1) { @@ -169,7 +170,7 @@ bool SelectedRows::Set(int64_t key, const framework::Tensor& value) { auto dims = value_->dims(); dims[0] = (dims[0] + 1) << 1; framework::VisitDataType(framework::ToDataType(value.type()), - ReAllocateVisitor(value_.get(), dims)); + ReAllocateVisitor(dims, value_.get())); } } diff --git a/paddle/fluid/framework/selected_rows.h b/paddle/fluid/framework/selected_rows.h index c80b05eed9b1c50325316057a8afc26d5d52e82c..7160670ddd204c20021ea87cdd67ee4721d03451 100644 --- a/paddle/fluid/framework/selected_rows.h +++ b/paddle/fluid/framework/selected_rows.h @@ -15,6 +15,8 @@ limitations under the License. */ #pragma once #include +#include +#include // NOLINT #include #include @@ -46,11 +48,13 @@ class SelectedRows { SelectedRows(const std::vector& rows, const int64_t& height) : rows_(rows), height_(height) { value_.reset(new Tensor()); + auto_grown_mutex_.reset(new std::mutex); } SelectedRows() { height_ = 0; value_.reset(new Tensor()); + auto_grown_mutex_.reset(new std::mutex); } platform::Place place() const { return value_->place(); } @@ -125,6 +129,7 @@ class SelectedRows { Vector rows_; std::unique_ptr value_{nullptr}; int64_t height_; + std::unique_ptr auto_grown_mutex_{nullptr}; }; /*