提交 d5384e64 编写于 作者: G guosheng

refine layer gradient test of ROIPoolLayer

上级 0f4c7332
...@@ -1842,17 +1842,20 @@ TEST(Layer, roi_pool) { ...@@ -1842,17 +1842,20 @@ TEST(Layer, roi_pool) {
roiPoolConf->set_width(14); roiPoolConf->set_width(14);
roiPoolConf->set_height(14); roiPoolConf->set_height(14);
MatrixPtr roiValue = Matrix::create(10, 10, false, false); const size_t roiNum = 10;
const size_t roiDim = 10;
const size_t batchSize = 5;
MatrixPtr roiValue = Matrix::create(roiNum, roiDim, false, false);
roiValue->zeroMem(); roiValue->zeroMem();
real* roiData = roiValue->getData(); real* roiData = roiValue->getData();
for (size_t i = 0; i < roiValue->getElementCnt() / 5; ++i) { for (size_t i = 0; i < roiNum; ++i) {
*roiData++ = std::rand() % 2; roiData[i * roiDim + 0] = std::rand() % batchSize;
*roiData++ = std::rand() % 224; roiData[i * roiDim + 1] = std::rand() % 224; // xMin
*roiData++ = std::rand() % 224; roiData[i * roiDim + 2] = std::rand() % 224; // yMin
size_t xMin = static_cast<size_t>(*(roiData - 2)); size_t xMin = static_cast<size_t>(roiData[i * roiDim + 1]);
size_t yMin = static_cast<size_t>(*(roiData - 1)); size_t yMin = static_cast<size_t>(roiData[i * roiDim + 2]);
*roiData++ = xMin + std::rand() % (224 - xMin); roiData[i * roiDim + 3] = xMin + std::rand() % (224 - xMin); // xMax
*roiData++ = yMin + std::rand() % (224 - yMin); roiData[i * roiDim + 4] = yMin + std::rand() % (224 - yMin); // yMax
} }
config.inputDefs.push_back({INPUT_DATA, "input", 3 * 14 * 14, {}}); config.inputDefs.push_back({INPUT_DATA, "input", 3 * 14 * 14, {}});
...@@ -1860,7 +1863,7 @@ TEST(Layer, roi_pool) { ...@@ -1860,7 +1863,7 @@ TEST(Layer, roi_pool) {
config.layerConfig.add_inputs(); config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
testLayerGrad(config, "roi_pool", 5, false, useGpu, false); testLayerGrad(config, "roi_pool", batchSize, false, useGpu, false);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册