diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 2c61ae40a59bafb86beb3e8677252733e51080f1..ff90d9c5d2075ee9fe4f9dd3a42fc00fa98e18c5 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -56,3 +56,5 @@ cc_test(tensor_array_test SRCS tensor_array_test.cc DEPS tensor_array place) cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry proto_desc) +cc_library(selected_rows SRCS selected_rows.cc DEPS tensor) +cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows) diff --git a/paddle/framework/selected_rows.cc b/paddle/framework/selected_rows.cc new file mode 100644 index 0000000000000000000000000000000000000000..c74459c9dd7006a24615b1d6df041583088fb25c --- /dev/null +++ b/paddle/framework/selected_rows.cc @@ -0,0 +1,16 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/selected_rows.h" + +namespace paddle { +namespace framework {} // namespace framework +} // namespace paddle diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h new file mode 100644 index 0000000000000000000000000000000000000000..f9f563051e264ae7ed7cf3c07c0065522b2bbe2e --- /dev/null +++ b/paddle/framework/selected_rows.h @@ -0,0 +1,54 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/tensor.h" + +namespace paddle { +namespace framework { + +class SelectedRows { + public: + SelectedRows(const std::vector& rows, const int64_t& height) + : rows_(rows), height_(height) { + value_.reset(new Tensor()); + } + + SelectedRows() { value_.reset(new Tensor()); } + + platform::Place place() const { return value_->place(); } + + const Tensor& value() const { return *value_; } + + Tensor* mutable_value() { return value_.get(); } + + int64_t height() const { return height_; } + + void set_height(int64_t height) { height_ = height; } + + const std::vector& rows() const { return rows_; } + + void set_rows(const std::vector& rows) { rows_ = rows; } + + DDim GetCompleteDims() const { + std::vector dims = vectorize(value_->dims()); + dims[0] = height_; + return make_ddim(dims); + } + + private: + std::vector rows_; + std::unique_ptr value_{nullptr}; + int64_t height_; +}; + +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4ee13a65d72e44693573397bb686b355effb2227 --- /dev/null +++ b/paddle/framework/selected_rows_test.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/selected_rows.h" +#include "gtest/gtest.h" + +namespace paddle { +namespace framework { + +class SelectedRowsTester : public ::testing::Test { + public: + virtual void SetUp() override { + std::vector rows{0, 4, 7}; + int64_t height = 10; + int64_t row_numel = 100; + selected_rows_.reset(new SelectedRows(rows, height)); + + Tensor* value = selected_rows_->mutable_value(); + value->mutable_data( + make_ddim({static_cast(rows.size()), row_numel}), place_); + } + + protected: + platform::CPUPlace place_; + std::unique_ptr selected_rows_{nullptr}; +}; + +TEST_F(SelectedRowsTester, height) { ASSERT_EQ(selected_rows_->height(), 10); } + +TEST_F(SelectedRowsTester, dims) { + ASSERT_EQ(selected_rows_->value().dims(), make_ddim({3, 100})); +} + +TEST_F(SelectedRowsTester, complete_dims) { + ASSERT_EQ(selected_rows_->GetCompleteDims(), make_ddim({10, 100})); +} + +} // namespace framework +} // namespace paddle