提交 919d72a8 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Extends functionality of Tensor::flat_inner_dims() and...

Extends functionality of Tensor::flat_inner_dims() and Tensor::flat_outer_dims() by adding a template param to specify the desired rank of the output. This is useful, e.g., in order to turn a tensor into a (rank-3) batch of matrices.
Change: 119685357
上级 353bdea6
......@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
......@@ -713,4 +714,36 @@ void Tensor::FillDescription(TensorDescription* description) const {
}
}
gtl::InlinedVector<int64, 5> Tensor::ComputeFlatInnerDims(
int64 num_out_dims) const {
gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
const int64 num_elements = NumElements();
if (num_elements != 0) {
int64 prod_out_dims = 1;
for (int64 out_dim = num_out_dims - 1; out_dim > 0; --out_dim) {
const int64 in_dim = out_dim + (dims() - num_out_dims);
out_dims[out_dim] =
(in_dim >= dims() || in_dim < 0) ? 1 : dim_size(in_dim);
prod_out_dims *= out_dims[out_dim];
}
out_dims[0] = num_elements / prod_out_dims;
}
return out_dims;
}
gtl::InlinedVector<int64, 5> Tensor::ComputeFlatOuterDims(
int64 num_out_dims) const {
gtl::InlinedVector<int64, 5> out_dims(num_out_dims, 0);
const int64 num_elements = NumElements();
if (num_elements != 0) {
int64 prod_out_dims = 1;
for (int64 out_dim = 0; out_dim < num_out_dims - 1; ++out_dim) {
out_dims[out_dim] = out_dim >= dims() ? 1 : dim_size(out_dim);
prod_out_dims *= out_dims[out_dim];
}
out_dims[num_out_dims - 1] = num_elements / prod_out_dims;
}
return out_dims;
}
} // namespace tensorflow
......@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
......@@ -243,40 +244,28 @@ class Tensor {
///
/// ```
template <typename T>
typename TTypes<T>::Flat flat();
typename TTypes<T>::Flat flat() {
return shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::UnalignedFlat unaligned_flat() {
return unaligned_shaped<T, 1>({NumElements()});
}
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
/// Tensor dimensions but the last one into the first dimension of the result.
template <typename T>
typename TTypes<T>::Matrix flat_inner_dims() {
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
if (last_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({NumElements() / last_size, last_size});
}
}
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
/// Tensor dimensions but the last NDIMS-1 into the first dimension of the
/// result. If NDIMS > dims() then leading dimensions of size 1 will be
/// added to make the output rank NDIMS.
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_inner_dims();
/// Returns the data as an Eigen::Tensor with 2 dimensions, collapsing all
/// Tensor dimensions but the first one into the last dimension of the result.
template <typename T>
typename TTypes<T>::Matrix flat_outer_dims() {
int64 first_size = dims() > 0 ? dim_size(0) : 1;
if (first_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({first_size, NumElements() / first_size});
}
}
/// Returns the data as an Eigen::Tensor with NDIMS dimensions, collapsing all
/// Tensor dimensions but the first NDIMS-1 into the last dimension of the
/// result. If NDIMS > dims() then trailing dimensions of size 1 will be
/// added to make the output rank NDIMS.
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::Tensor flat_outer_dims();
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor shaped(gtl::ArraySlice<int64> new_sizes);
......@@ -308,31 +297,19 @@ class Tensor {
typename TTypes<T, NDIMS>::ConstTensor tensor() const;
template <typename T>
typename TTypes<T>::ConstFlat flat() const;
typename TTypes<T>::ConstFlat flat() const {
return shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::UnalignedConstFlat unaligned_flat() const {
return unaligned_shaped<T, 1>({NumElements()});
}
template <typename T>
typename TTypes<T>::ConstMatrix flat_inner_dims() const {
int64 last_size = dims() > 0 ? dim_size(dims() - 1) : 1;
if (last_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({NumElements() / last_size, last_size});
}
}
template <typename T>
typename TTypes<T>::ConstMatrix flat_outer_dims() const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor shaped(
gtl::ArraySlice<int64> new_sizes) const;
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::UnalignedConstTensor unaligned_shaped(
gtl::ArraySlice<int64> new_sizes) const;
......@@ -340,6 +317,12 @@ class Tensor {
template <typename T>
typename TTypes<T>::ConstScalar scalar() const;
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_inner_dims() const;
template <typename T, size_t NDIMS = 2>
typename TTypes<T, NDIMS>::ConstTensor flat_outer_dims() const;
/// Render the first `max_entries` values in `*this` into a string.
string SummarizeValue(int64 max_entries) const;
......@@ -378,6 +361,8 @@ class Tensor {
void FillDimsAndValidateCompatibleShape(
gtl::ArraySlice<int64> new_sizes,
Eigen::array<Eigen::DenseIndex, NDIMS>* dims) const;
gtl::InlinedVector<int64, 5> ComputeFlatInnerDims(int64 num_out_dims) const;
gtl::InlinedVector<int64, 5> ComputeFlatOuterDims(int64 num_out_dims) const;
TensorShape shape_;
TensorBuffer* buf_;
......@@ -534,26 +519,24 @@ typename TTypes<T>::ConstScalar Tensor::scalar() const {
return typename TTypes<T>::ConstScalar(base<T>());
}
template <typename T>
typename TTypes<T>::Flat Tensor::flat() {
return shaped<T, 1>({NumElements()});
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_inner_dims() {
return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
}
template <typename T>
typename TTypes<T>::ConstFlat Tensor::flat() const {
return shaped<T, 1>({NumElements()});
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::Tensor Tensor::flat_outer_dims() {
return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
}
template <typename T>
typename TTypes<T>::ConstMatrix Tensor::flat_outer_dims() const {
int64 first_size = dims() > 0 ? dim_size(0) : 1;
if (first_size == 0) {
DCHECK_EQ(NumElements(), 0);
// Return something empty, avoiding divide by 0
return shaped<T, 2>({0, 0});
} else {
return shaped<T, 2>({first_size, NumElements() / first_size});
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_inner_dims() const {
return shaped<T, NDIMS>(ComputeFlatInnerDims(NDIMS));
}
template <typename T, size_t NDIMS>
typename TTypes<T, NDIMS>::ConstTensor Tensor::flat_outer_dims() const {
return shaped<T, NDIMS>(ComputeFlatOuterDims(NDIMS));
}
} // namespace tensorflow
......
......@@ -224,6 +224,49 @@ TEST(Tensor_Float, Reshape) {
EXPECT_EQ(flat_inner_dims(0, 0), 0.01f);
EXPECT_EQ(flat_inner_dims(23, 4), 0.02f);
}
{
auto flat_outer_dims = t.flat_outer_dims<float>();
EXPECT_EQ(2, flat_outer_dims.dimension(0));
EXPECT_EQ(60, flat_outer_dims.dimension(1));
EXPECT_EQ(flat_outer_dims(0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 59), 0.02f);
}
{
auto flat_inner_dims = t.flat_inner_dims<float, 3>();
EXPECT_EQ(6, flat_inner_dims.dimension(0));
EXPECT_EQ(4, flat_inner_dims.dimension(1));
EXPECT_EQ(5, flat_inner_dims.dimension(2));
EXPECT_EQ(flat_inner_dims(0, 0, 0), 0.01f);
EXPECT_EQ(flat_inner_dims(5, 3, 4), 0.02f);
}
{
auto flat_outer_dims = t.flat_outer_dims<float, 3>();
EXPECT_EQ(2, flat_outer_dims.dimension(0));
EXPECT_EQ(3, flat_outer_dims.dimension(1));
EXPECT_EQ(20, flat_outer_dims.dimension(2));
EXPECT_EQ(flat_outer_dims(0, 0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 2, 19), 0.02f);
}
{
auto flat_inner_dims = t.flat_inner_dims<float, 5>();
EXPECT_EQ(1, flat_inner_dims.dimension(0));
EXPECT_EQ(2, flat_inner_dims.dimension(1));
EXPECT_EQ(3, flat_inner_dims.dimension(2));
EXPECT_EQ(4, flat_inner_dims.dimension(3));
EXPECT_EQ(5, flat_inner_dims.dimension(4));
EXPECT_EQ(flat_inner_dims(0, 0, 0, 0, 0), 0.01f);
EXPECT_EQ(flat_inner_dims(0, 1, 2, 3, 4), 0.02f);
}
{
auto flat_outer_dims = t.flat_outer_dims<float, 5>();
EXPECT_EQ(2, flat_outer_dims.dimension(0));
EXPECT_EQ(3, flat_outer_dims.dimension(1));
EXPECT_EQ(4, flat_outer_dims.dimension(2));
EXPECT_EQ(5, flat_outer_dims.dimension(3));
EXPECT_EQ(1, flat_outer_dims.dimension(4));
EXPECT_EQ(flat_outer_dims(0, 0, 0, 0, 0), 0.01f);
EXPECT_EQ(flat_outer_dims(1, 2, 3, 4, 0), 0.02f);
}
}
TEST(Tensor_Scalar, Basics) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册