提交 46bd5f53 编写于 作者: B backyes 提交者: Yu Yang

add input sparse data check for sparse layer at runtime (#247)

* add input sparse data check for sparse layer at runtime,
to avoid invalid data access at pserver end while doing prefetch

* remote sparse design support binary sparse and float saprse both
上级 d1d52bb7
...@@ -227,12 +227,18 @@ void CacheRowCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB, ...@@ -227,12 +227,18 @@ void CacheRowCpuMatrix::mul(CpuSparseMatrix* a, CpuMatrix* b, real scaleAB,
void SparsePrefetchRowCpuMatrix::addRows(const unsigned int* ids, size_t len) { void SparsePrefetchRowCpuMatrix::addRows(const unsigned int* ids, size_t len) {
std::vector<unsigned int>& localIndices = indexDictHandle_->localIndices; std::vector<unsigned int>& localIndices = indexDictHandle_->localIndices;
for (size_t i = 0; i < len; i ++) {
CHECK_LT(*(ids + i), this->getHeight())
<< "id:" << *(ids + i) << "Height:" << this->getHeight()
<< "sparse id value exceeds the max input dimension, "
<< "it could be caused invalid input data samples";
}
localIndices.insert(localIndices.end(), ids, ids + len); localIndices.insert(localIndices.end(), ids, ids + len);
} }
void SparsePrefetchRowCpuMatrix::addRows(MatrixPtr input) { void SparsePrefetchRowCpuMatrix::addRows(MatrixPtr input) {
CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(input.get()); CpuSparseMatrix* mat = dynamic_cast<CpuSparseMatrix*>(input.get());
CHECK(mat) << "only support non value sparse matrix"; CHECK(mat) << "only support sparse matrix";
addRows(reinterpret_cast<const unsigned int*>(mat->getCols()), addRows(reinterpret_cast<const unsigned int*>(mat->getCols()),
mat->getElementCnt()); mat->getElementCnt());
} }
...@@ -243,7 +249,13 @@ void SparsePrefetchRowCpuMatrix::addRows(IVectorPtr ids) { ...@@ -243,7 +249,13 @@ void SparsePrefetchRowCpuMatrix::addRows(IVectorPtr ids) {
int* index = ids->getData(); int* index = ids->getData();
for (size_t i = 0; i < numSamples; ++i) { for (size_t i = 0; i < numSamples; ++i) {
if (index[i] == -1) continue; if (index[i] == -1) continue;
localIndices.push_back((unsigned int)index[i]);
unsigned int id = (unsigned int)index[i];
CHECK_LT(id, this->getHeight())
<< "id:" << id << "Height:" << this->getHeight()
<< "sparse id value exceeds the max input dimension, "
<< "it could be caused invalid input data samples";
localIndices.push_back(id);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册