From 4b13c80eeb19082cf83e76469594b802c8008070 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 12 Oct 2017 20:28:34 -0700 Subject: [PATCH] add selected rows --- paddle/framework/CMakeLists.txt | 3 ++ paddle/framework/selected_rows.cc | 16 ++++++++ paddle/framework/selected_rows.h | 52 ++++++++++++++++++++++++++ paddle/framework/selected_rows_test.cc | 47 +++++++++++++++++++++++ 4 files changed, 118 insertions(+) create mode 100644 paddle/framework/selected_rows.cc create mode 100644 paddle/framework/selected_rows.h create mode 100644 paddle/framework/selected_rows_test.cc diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 14947b6f2..312df0fd7 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -53,3 +53,6 @@ endif() cc_library(tensor_array SRCS tensor_array.cc DEPS lod_tensor) cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) + +cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) +cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc new file mode 100644 index 000000000..c74459c9d --- /dev/null +++ b/paddle/framework/selected_rows.cc @@ -0,0 +1,16 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/selected_rows.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h new file mode 100644 index 000000000..2bd6b160b --- /dev/null +++ b/paddle/framework/selected_rows.h @@ -0,0 +1,52 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace framework { + +class SelectedRows { + public: + SelectedRows(const std::vector& rows, const int64_t& height) + : rows_(rows), height_(height) { + value_.reset(new Tensor()); + } + + SelectedRows() {} + + platform::Place place() const { return value_->place(); } + + Tensor& value() const { return *value_; } + + int64_t height() const { return height_; } + + void set_height(int64_t height) { height_ = height; } + + const std::vector& rows() const { return rows_; } + + void set_rows(const std::vector& rows) { rows_ = rows; } + + DDim GetCompleteDims() const { + std::vector dims = vectorize(value_->dims()); + dims[0] = height_; + return make_ddim(dims); + } + + private: + std::vector rows_; + std::unique_ptr value_{nullptr}; + int64_t height_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc new file mode 100644 index 000000000..0c7f4600b --- /dev/null +++ b/paddle/framework/selected_rows_test.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/selected_rows.h" +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { + +class SelectedRowsTester : public ::testing::Test { + public: + virtual void SetUp() override { + std::vector rows{0, 4, 7}; + int64_t height = 10; + int64_t row_numel = 100; + selected_rows_.reset(new SelectedRows(rows, height)); + + Tensor& value = selected_rows_->value(); + value.mutable_data( + make_ddim({static_cast(rows.size()), row_numel}), place_); + } + + protected: + platform::CPUPlace place_; + std::unique_ptr selected_rows_{nullptr}; +}; + +TEST_F(SelectedRowsTester, height) { ASSERT_EQ(selected_rows_->height(), 10); } + +TEST_F(SelectedRowsTester, dims) { + ASSERT_EQ(selected_rows_->value().dims(), make_ddim({3, 100})); +} + +TEST_F(SelectedRowsTester, complete_dims) { + ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100})); +} + +} // namespace framework +} // namespace paddle -- GitLab