提交 8747d60d 编写于 作者: Q qingqing01 提交者: GitHub

Merge pull request #3449 from lcy-seso/enable_self_defined_ids

enable self-defined index data in testLayerGrad.
...@@ -388,14 +388,23 @@ void initDataLayer(TestConfig testConf, ...@@ -388,14 +388,23 @@ void initDataLayer(TestConfig testConf,
data.grad->zeroMem(); data.grad->zeroMem();
break; break;
case INPUT_SELF_DEFINE_DATA: { case INPUT_SELF_DEFINE_DATA: {
size_t height = testConf.inputDefs[i].selfDefinedData->getHeight(); if (testConf.inputDefs[i].ids.size()) {
size_t width = testConf.inputDefs[i].selfDefinedData->getWidth(); data.ids = IVector::create(testConf.inputDefs[i].ids.size(), useGpu);
CHECK_GT(static_cast<int>(height), 0); data.ids->copyFrom(testConf.inputDefs[i].ids.data(),
CHECK_GT(static_cast<int>(width), 0); testConf.inputDefs[i].ids.size());
data.value = Matrix::create(height, width, false, useGpu); } else if (testConf.inputDefs[i].selfDefinedData) {
data.grad = Matrix::create(height, width, false, useGpu); size_t height = testConf.inputDefs[i].selfDefinedData->getHeight();
data.value->copyFrom(*testConf.inputDefs[i].selfDefinedData); size_t width = testConf.inputDefs[i].selfDefinedData->getWidth();
data.grad->zeroMem(); CHECK_GT(static_cast<int>(height), 0);
CHECK_GT(static_cast<int>(width), 0);
data.value = Matrix::create(height, width, false, useGpu);
data.grad = Matrix::create(height, width, false, useGpu);
data.value->copyFrom(*testConf.inputDefs[i].selfDefinedData);
data.grad->zeroMem();
} else {
LOG(FATAL) << "No self-defined data are given.";
return;
}
const std::vector<int>& labelSeqStartPositions = const std::vector<int>& labelSeqStartPositions =
testConf.inputDefs[i].labelSeqStartPositions; testConf.inputDefs[i].labelSeqStartPositions;
......
...@@ -68,6 +68,7 @@ struct InputDef { ...@@ -68,6 +68,7 @@ struct InputDef {
std::vector<int> labelInitValue; std::vector<int> labelInitValue;
std::vector<int> labelSeqStartPositions; std::vector<int> labelSeqStartPositions;
std::vector<int> labelSubSeqStartPositions; std::vector<int> labelSubSeqStartPositions;
std::vector<int> ids;
MatrixPtr selfDefinedData; MatrixPtr selfDefinedData;
InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) {
...@@ -95,6 +96,23 @@ struct InputDef { ...@@ -95,6 +96,23 @@ struct InputDef {
isStatic = false; isStatic = false;
} }
InputDef(InputType type,
string nameIn,
const std::vector<int>& ids,
const std::vector<int>& selfDefinedSeqStartPos = {},
const std::vector<int>& selfDefinedSubSeqStartPos = {})
: labelSeqStartPositions(selfDefinedSeqStartPos),
labelSubSeqStartPositions(selfDefinedSubSeqStartPos),
ids(ids) {
selfDefinedData = nullptr;
inputType = type;
name = nameIn;
dim = 0;
sparse = {""};
paraSize = 0;
isStatic = false;
}
InputDef(InputType type, InputDef(InputType type,
string nameIn, string nameIn,
size_t dimIn, size_t dimIn,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册