提交 0c4be7e6 编写于 作者: H hedaoyuan

add TensorType.h

上级 904eefaf
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <glog/logging.h>
namespace paddle {
/**
* TensorShape used to represent shape of normal tensor.
*/
class TensorShape {
public:
TensorShape() : ndims_(0), nelements_(0) { initDims(0); }
TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); };
TensorShape(std::initializer_list<size_t> dims) {
ndims_ = dims.size();
initDims(ndims_);
std::copy(dims.begin(), dims.end(), dims_.begin());
numElements();
};
TensorShape(const TensorShape& t)
: ndims_(t.ndims_), nelements_(t.nelements_) {
initDims(ndims_);
std::copy(t.dims_.begin(), t.dims_.end(), dims_.begin());
};
// get the size of specified dimension
size_t operator[](size_t dim) const {
CHECK_GE(dim, 0);
CHECK_LT(dim, ndims_);
return dims_[dim];
}
// set the size of specified dimension
void setDim(size_t dim, size_t size) {
CHECK_GE(dim, 0);
CHECK_LT(dim, ndims_);
dims_[dim] = size;
numElements();
}
// number of dimensions of the tensor
size_t ndims() const { return ndims_; }
size_t getElements() const { return nelements_; }
bool operator==(const TensorShape& t) const {
if (ndims() != t.ndims()) return false;
for (size_t i = 0; i < ndims(); i++) {
if (dims_[i] != t.dims_[i]) return false;
}
return true;
}
bool operator!=(const TensorShape& t) const { return !(*this == t); }
private:
// compute number of elements
void numElements() {
nelements_ = 1;
for (size_t n = 0; n < ndims_; n++) {
nelements_ *= dims_[n];
}
}
// init dims_
void initDims(size_t ndims) {
size_t count = ndims < 4 ? 4 : ndims;
dims_.assign(count, 1);
}
// number of dimensions
// ndims_ may be not equeal dims_.size()
size_t ndims_;
// number of elements
size_t nelements_;
std::vector<size_t> dims_;
};
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "TensorShape.h"
#include <gtest/gtest.h>
namespace paddle {
TEST(TensorShape, Constructor) {
TensorShape t1;
EXPECT_EQ(t1.ndims(), 0);
EXPECT_EQ(t1.getElements(), 0);
TensorShape t2(3);
EXPECT_EQ(t2.ndims(), 3);
EXPECT_EQ(t2.getElements(), 1);
TensorShape t3({8, 10});
EXPECT_EQ(t3.ndims(), 2);
EXPECT_EQ(t3.getElements(), 80);
TensorShape t4(t3);
EXPECT_EQ(t4.ndims(), t3.ndims());
EXPECT_EQ(t4.getElements(), t3.getElements());
TensorShape t5({1, 2, 3, 4, 5});
EXPECT_EQ(t5.ndims(), 5);
EXPECT_EQ(t5.getElements(), 120);
}
TEST(TensorShape, GetAndSet) {
TensorShape t({1, 2, 3});
EXPECT_EQ(t.ndims(), 3);
EXPECT_EQ(t.getElements(), 6);
EXPECT_EQ(t[1], 2);
t.setDim(1, 100);
EXPECT_EQ(t.getElements(), 300);
EXPECT_EQ(t[1], 100);
}
} // namespace paddle
...@@ -14,7 +14,7 @@ limitations under the License. */ ...@@ -14,7 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <glog/logging.h> #include "paddle/math/Matrix.h"
namespace paddle { namespace paddle {
...@@ -57,69 +57,60 @@ struct DataType<double> { ...@@ -57,69 +57,60 @@ struct DataType<double> {
static const ValueType value = VALUE_TYPE_DOUBLE; static const ValueType value = VALUE_TYPE_DOUBLE;
}; };
/** namespace detail {
* TensorShape used to represent shape of normal tensor.
*/
class TensorShape {
public:
TensorShape() : ndims_(0), nelements_(0) { initDims(0); }
TensorShape(size_t ndims) : ndims_(ndims), nelements_(1) { initDims(ndims); };
TensorShape(std::initializer_list<size_t> dims) {
ndims_ = dims.size();
initDims(ndims_);
std::copy(dims.begin(), dims.end(), dims_.begin());
numElements();
};
TensorShape(const TensorShape& t)
: ndims_(t.ndims_), nelements_(t.nelements_) {
initDims(ndims_);
std::copy(t.dims_.begin(), t.dims_.end(), dims_.begin());
};
// get the size of specified dimension
size_t operator[](size_t dim) const {
CHECK_GE(dim, 0);
CHECK_LT(dim, ndims_);
return dims_[dim];
}
// set the size of specified dimension template <typename VType, DeviceType Device>
void setDim(size_t dim, size_t size) { struct MatrixT;
CHECK_GE(dim, 0);
CHECK_LT(dim, ndims_);
dims_[dim] = size;
numElements();
}
// number of dimensions of the tensor template <>
size_t ndims() const { return ndims_; } struct MatrixT<real, DEVICE_TYPE_CPU> {
using type = CpuMatrix;
};
size_t getElements() const { return nelements_; } template <>
struct MatrixT<real, DEVICE_TYPE_GPU> {
using type = GpuMatrix;
};
private: template <>
// compute number of elements struct MatrixT<int, DEVICE_TYPE_CPU> {
void numElements() { using type = void; // Not implemented
nelements_ = 1; };
for (size_t n = 0; n < ndims_; n++) {
nelements_ *= dims_[n];
}
}
// init dims_ template <>
void initDims(size_t ndims) { struct MatrixT<int, DEVICE_TYPE_GPU> {
size_t count = ndims < 4 ? 4 : ndims; using type = void; // Not implemented
dims_.assign(count, 1); };
}
template <typename VType, DeviceType Device>
struct VectorT;
template <>
struct VectorT<real, DEVICE_TYPE_CPU> {
using type = CpuVector;
};
template <>
struct VectorT<real, DEVICE_TYPE_GPU> {
using type = GpuVector;
};
template <>
struct VectorT<int, DEVICE_TYPE_CPU> {
using type = CpuIVector;
};
template <>
struct VectorT<int, DEVICE_TYPE_GPU> {
using type = GpuIVector;
};
} // namespace detail
// number of dimensions template <typename VType, DeviceType DType>
// ndims_ may be not equeal dims_.size() struct Tensor {
size_t ndims_; typedef typename detail::MatrixT<VType, DType>::type Matrix;
// number of elements typedef typename detail::VectorT<VType, DType>::type Vector;
size_t nelements_;
std::vector<size_t> dims_;
}; };
} // namespace paddle } // namespace paddle
...@@ -17,37 +17,31 @@ limitations under the License. */ ...@@ -17,37 +17,31 @@ limitations under the License. */
namespace paddle { namespace paddle {
TEST(TensorShape, Constructor) { TEST(TensorType, Matrix) {
TensorShape t1; Tensor<real, DEVICE_TYPE_CPU>::Matrix matrix(100, 200);
EXPECT_EQ(t1.ndims(), 0); EXPECT_EQ(matrix.getHeight(), 100);
EXPECT_EQ(t1.getElements(), 0); EXPECT_EQ(matrix.getWidth(), 200);
EXPECT_EQ(matrix.getElementCnt(), 100 * 200);
TensorShape t2(3); EXPECT_EQ(matrix.useGpu(), false);
EXPECT_EQ(t2.ndims(), 3);
EXPECT_EQ(t2.getElements(), 1); Tensor<real, DEVICE_TYPE_GPU>::Matrix testGpu(100, 200);
EXPECT_EQ(testGpu.useGpu(), true);
TensorShape t3({8, 10});
EXPECT_EQ(t3.ndims(), 2);
EXPECT_EQ(t3.getElements(), 80);
TensorShape t4(t3);
EXPECT_EQ(t4.ndims(), t3.ndims());
EXPECT_EQ(t4.getElements(), t3.getElements());
TensorShape t5({1, 2, 3, 4, 5});
EXPECT_EQ(t5.ndims(), 5);
EXPECT_EQ(t5.getElements(), 120);
} }
TEST(TensorShape, GetAndSet) { TEST(TensorType, Vector) {
TensorShape t({1, 2, 3}); Tensor<real, DEVICE_TYPE_CPU>::Vector cpuVector(100);
EXPECT_EQ(t.ndims(), 3); Tensor<real, DEVICE_TYPE_GPU>::Vector gpuVector(100);
EXPECT_EQ(t.getElements(), 6); EXPECT_EQ(cpuVector.useGpu(), false);
EXPECT_EQ(gpuVector.useGpu(), true);
EXPECT_EQ(t[1], 2); EXPECT_EQ(cpuVector.getSize(), 100);
t.setDim(1, 100); EXPECT_EQ(gpuVector.getSize(), 100);
EXPECT_EQ(t.getElements(), 300);
EXPECT_EQ(t[1], 100); Tensor<int, DEVICE_TYPE_CPU>::Vector cpuIVector(100);
Tensor<int, DEVICE_TYPE_GPU>::Vector gpuIVector(100);
EXPECT_EQ(cpuIVector.useGpu(), false);
EXPECT_EQ(gpuIVector.useGpu(), true);
EXPECT_EQ(cpuIVector.getSize(), 100);
EXPECT_EQ(gpuIVector.getSize(), 100);
} }
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册