提交 d35ef9de 编写于 作者: H hedaoyuan

follow commit

上级 ccf0b1bb
...@@ -56,7 +56,7 @@ public: ...@@ -56,7 +56,7 @@ public:
: buf_(buf), valueType_(valueType) {} : buf_(buf), valueType_(valueType) {}
BufferArg(const Matrix& matrix) BufferArg(const Matrix& matrix)
: buf_((void*)matrix.getData()), : buf_(reinterpret_cast<void*>(matrix.getData())),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(2) { shape_(2) {
shape_.setDim(0, matrix.getHeight()); shape_.setDim(0, matrix.getHeight());
...@@ -64,21 +64,23 @@ public: ...@@ -64,21 +64,23 @@ public:
} }
BufferArg(const Matrix& matrix, const TensorShape& shape) BufferArg(const Matrix& matrix, const TensorShape& shape)
: buf_((void*)matrix.getData()), : buf_(reinterpret_cast<void*>(matrix.getData())),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(shape) { shape_(shape) {
CHECK_EQ(matrix.getElementCnt(), shape.getElements()); CHECK_EQ(matrix.getElementCnt(), shape.getElements());
} }
BufferArg(const Vector& vector) BufferArg(const Vector& vector)
: buf_((void*)vector.getData()), : buf_(reinterpret_cast<void*>(vector.getData())),
valueType_(DataType<real>::value), valueType_(DataType<real>::value),
shape_(1) { shape_(1) {
shape_.setDim(0, vector.getSize()); shape_.setDim(0, vector.getSize());
} }
BufferArg(const IVector& vector) BufferArg(const IVector& vector)
: buf_((void*)vector.getData()), valueType_(VALUE_TYPE_INT32), shape_(1) { : buf_(reinterpret_cast<void*>(vector.getData())),
valueType_(VALUE_TYPE_INT32),
shape_(1) {
shape_.setDim(0, vector.getSize()); shape_.setDim(0, vector.getSize());
} }
...@@ -129,7 +131,7 @@ protected: ...@@ -129,7 +131,7 @@ protected:
// sequence start positions in a mini-batch of sequences // sequence start positions in a mini-batch of sequences
// shape_.ndims() == 1 // shape_.ndims() == 1
// valueType_ = int32 // valueType_ = int32
// if a < b than value_.buf_[a] < value_.buf_[b] // if a < b then value_.buf_[a] < value_.buf_[b]
class SequenceIdArg : public BufferArg { class SequenceIdArg : public BufferArg {
public: public:
SequenceIdArg(void* buf, const TensorShape& shape) SequenceIdArg(void* buf, const TensorShape& shape)
...@@ -203,13 +205,13 @@ public: ...@@ -203,13 +205,13 @@ public:
SparseMatrixArg(const CpuSparseMatrix& sparse) SparseMatrixArg(const CpuSparseMatrix& sparse)
: BufferArg(sparse), : BufferArg(sparse),
row_((void*)sparse.getRows(), VALUE_TYPE_INT32), row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {} col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
SparseMatrixArg(const GpuSparseMatrix& sparse) SparseMatrixArg(const GpuSparseMatrix& sparse)
: BufferArg(sparse), : BufferArg(sparse),
row_((void*)sparse.getRows(), VALUE_TYPE_INT32), row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {} col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
~SparseMatrixArg() {} ~SparseMatrixArg() {}
......
...@@ -30,14 +30,14 @@ public: ...@@ -30,14 +30,14 @@ public:
TensorShape(std::initializer_list<size_t> dims) { TensorShape(std::initializer_list<size_t> dims) {
ndims_ = dims.size(); ndims_ = dims.size();
initDims(ndims_); initDims(ndims_);
std::copy(dims.begin(), dims.end(), dims_.begin()); dims_.assign(dims);
numElements(); numElements();
}; };
TensorShape(const TensorShape& t) TensorShape(const TensorShape& t)
: ndims_(t.ndims_), nelements_(t.nelements_) { : ndims_(t.ndims_), nelements_(t.nelements_) {
initDims(ndims_); initDims(ndims_);
std::copy(t.dims_.begin(), t.dims_.end(), dims_.begin()); dims_.assign(t.dims_.begin(), t.dims_.end());
}; };
// get the size of specified dimension // get the size of specified dimension
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册