提交 1b8e151f 编写于 作者: P Peng Li

Support user specified label input in tests

上级 37f75959
...@@ -303,13 +303,31 @@ void initDataLayer(TestConfig testConf, ...@@ -303,13 +303,31 @@ void initDataLayer(TestConfig testConf,
ICpuGpuVectorPtr sequenceStartPositions; ICpuGpuVectorPtr sequenceStartPositions;
ICpuGpuVectorPtr subSequenceStartPositions; ICpuGpuVectorPtr subSequenceStartPositions;
IVectorPtr cpuSequenceDims; IVectorPtr cpuSequenceDims;
for (size_t i = 0; i < testConf.inputDefs.size(); i++) { for (size_t i = 0; i < testConf.inputDefs.size(); ++i) {
if (testConf.inputDefs[i].inputType != INPUT_SEQUENCE_LABEL) continue;
const std::vector<int>& labelSeqStartPositions =
testConf.inputDefs[i].labelSeqStartPositions;
if (labelSeqStartPositions.size() != 0) {
CHECK(!sequenceStartPositions);
CHECK_GE(labelSeqStartPositions.size(), 2);
sequenceStartPositions =
ICpuGpuVector::create(labelSeqStartPositions.size(), useGpu);
sequenceStartPositions->copyFrom(
labelSeqStartPositions.data(), labelSeqStartPositions.size(), useGpu);
}
}
for (size_t i = 0; i < testConf.inputDefs.size(); ++i) {
LayerConfig config; LayerConfig config;
config.set_name(testConf.inputDefs[i].name); config.set_name(testConf.inputDefs[i].name);
config.set_type("data"); config.set_type("data");
config.set_size(testConf.inputDefs[i].dim); config.set_size(testConf.inputDefs[i].dim);
LayerPtr layer = LayerPtr(new DataLayer(config)); LayerPtr layer = LayerPtr(new DataLayer(config));
size_t numSequence = batchSize / 10 + 1; size_t numSequence = sequenceStartPositions
? sequenceStartPositions->getSize() - 1
: batchSize / 10 + 1;
Argument data; Argument data;
auto fillData = [&](bool trans, int height, int width) { auto fillData = [&](bool trans, int height, int width) {
...@@ -336,9 +354,17 @@ void initDataLayer(TestConfig testConf, ...@@ -336,9 +354,17 @@ void initDataLayer(TestConfig testConf,
break; break;
case INPUT_LABEL: case INPUT_LABEL:
case INPUT_SEQUENCE_LABEL: case INPUT_SEQUENCE_LABEL:
data.ids = VectorT<int>::create(batchSize, useGpu); if (testConf.inputDefs[i].labelInitValue.size() != 0) {
// now rand number can be 0 to inputDefs[i].dim const std::vector<int>& labelInitValue =
data.ids->rand(testConf.inputDefs[i].dim); testConf.inputDefs[i].labelInitValue;
CHECK_EQ(labelInitValue.size(), batchSize);
data.ids = VectorT<int>::create(batchSize, useGpu);
data.ids->copyFrom(labelInitValue.data(), batchSize);
} else {
data.ids = VectorT<int>::create(batchSize, useGpu);
// now rand number can be 0 to inputDefs[i].dim
data.ids->rand(testConf.inputDefs[i].dim);
}
break; break;
case INPUT_SPARSE_NON_VALUE_DATA: case INPUT_SPARSE_NON_VALUE_DATA:
data.value = makeRandomSparseMatrix( data.value = makeRandomSparseMatrix(
......
...@@ -64,6 +64,8 @@ struct InputDef { ...@@ -64,6 +64,8 @@ struct InputDef {
size_t paraSize; size_t paraSize;
ParaSparse sparse; ParaSparse sparse;
bool isStatic; bool isStatic;
std::vector<int> labelInitValue;
std::vector<int> labelSeqStartPositions;
InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) {
inputType = type; inputType = type;
name = nameIn; name = nameIn;
...@@ -72,6 +74,23 @@ struct InputDef { ...@@ -72,6 +74,23 @@ struct InputDef {
sparse = {""}; sparse = {""};
isStatic = false; isStatic = false;
} }
InputDef(InputType type,
string nameIn,
size_t dimIn,
size_t sizeIn,
std::vector<int> labelInitValue,
std::vector<int> labelSeqStartPositions)
: labelInitValue(labelInitValue),
labelSeqStartPositions(labelSeqStartPositions) {
inputType = type;
name = nameIn;
dim = dimIn;
paraSize = sizeIn;
sparse = {""};
isStatic = false;
}
InputDef(InputType type, InputDef(InputType type,
string nameIn, string nameIn,
size_t dimIn, size_t dimIn,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册