提交 d35ef9de 编写于 作者: H hedaoyuan

follow commit

上级 ccf0b1bb
......@@ -56,7 +56,7 @@ public:
: buf_(buf), valueType_(valueType) {}
BufferArg(const Matrix& matrix)
: buf_((void*)matrix.getData()),
: buf_(reinterpret_cast<void*>(matrix.getData())),
valueType_(DataType<real>::value),
shape_(2) {
shape_.setDim(0, matrix.getHeight());
......@@ -64,21 +64,23 @@ public:
}
BufferArg(const Matrix& matrix, const TensorShape& shape)
: buf_((void*)matrix.getData()),
: buf_(reinterpret_cast<void*>(matrix.getData())),
valueType_(DataType<real>::value),
shape_(shape) {
CHECK_EQ(matrix.getElementCnt(), shape.getElements());
}
BufferArg(const Vector& vector)
: buf_((void*)vector.getData()),
: buf_(reinterpret_cast<void*>(vector.getData())),
valueType_(DataType<real>::value),
shape_(1) {
shape_.setDim(0, vector.getSize());
}
BufferArg(const IVector& vector)
: buf_((void*)vector.getData()), valueType_(VALUE_TYPE_INT32), shape_(1) {
: buf_(reinterpret_cast<void*>(vector.getData())),
valueType_(VALUE_TYPE_INT32),
shape_(1) {
shape_.setDim(0, vector.getSize());
}
......@@ -129,7 +131,7 @@ protected:
// sequence start positions in a mini-batch of sequences
// shape_.ndims() == 1
// valueType_ = int32
// if a < b than value_.buf_[a] < value_.buf_[b]
// if a < b then value_.buf_[a] < value_.buf_[b]
class SequenceIdArg : public BufferArg {
public:
SequenceIdArg(void* buf, const TensorShape& shape)
......@@ -203,13 +205,13 @@ public:
SparseMatrixArg(const CpuSparseMatrix& sparse)
: BufferArg(sparse),
row_((void*)sparse.getRows(), VALUE_TYPE_INT32),
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {}
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
SparseMatrixArg(const GpuSparseMatrix& sparse)
: BufferArg(sparse),
row_((void*)sparse.getRows(), VALUE_TYPE_INT32),
col_((void*)sparse.getCols(), VALUE_TYPE_INT32) {}
row_(reinterpret_cast<void*>(sparse.getRows()), VALUE_TYPE_INT32),
col_(reinterpret_cast<void*>(sparse.getCols()), VALUE_TYPE_INT32) {}
~SparseMatrixArg() {}
......
......@@ -30,14 +30,14 @@ public:
TensorShape(std::initializer_list<size_t> dims) {
ndims_ = dims.size();
initDims(ndims_);
std::copy(dims.begin(), dims.end(), dims_.begin());
dims_.assign(dims);
numElements();
};
TensorShape(const TensorShape& t)
: ndims_(t.ndims_), nelements_(t.nelements_) {
initDims(ndims_);
std::copy(t.dims_.begin(), t.dims_.end(), dims_.begin());
dims_.assign(t.dims_.begin(), t.dims_.end());
};
// get the size of specified dimension
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册