diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index 417904ea84ac625476a3021a93050098221a792a..f9f563051e264ae7ed7cf3c07c0065522b2bbe2e 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -26,7 +26,9 @@ class SelectedRows { platform::Place place() const { return value_->place(); } - Tensor& value() const { return *value_; } + const Tensor& value() const { return *value_; } + + Tensor* mutable_value() { return value_.get(); } int64_t height() const { return height_; } diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc index 0c7f4600b9fbe6f46e23b9f2af68b359ec9e911e..4ee13a65d72e44693573397bb686b355effb2227 100644 --- a/paddle/framework/selected_rows_test.cc +++ b/paddle/framework/selected_rows_test.cc @@ -23,8 +23,8 @@ class SelectedRowsTester : public ::testing::Test { int64_t row_numel = 100; selected_rows_.reset(new SelectedRows(rows, height)); - Tensor& value = selected_rows_->value(); - value.mutable_data( + Tensor* value = selected_rows_->mutable_value(); + value->mutable_data( make_ddim({static_cast(rows.size()), row_numel}), place_); }