提交 6bb970b5 编写于 作者: C caoying03

enable defining sub-sequence data in test layer gradients.

上级 c46aed57
...@@ -400,7 +400,6 @@ void initDataLayer(TestConfig testConf, ...@@ -400,7 +400,6 @@ void initDataLayer(TestConfig testConf,
const std::vector<int>& labelSeqStartPositions = const std::vector<int>& labelSeqStartPositions =
testConf.inputDefs[i].labelSeqStartPositions; testConf.inputDefs[i].labelSeqStartPositions;
if (labelSeqStartPositions.size() != 0) { if (labelSeqStartPositions.size() != 0) {
CHECK(!sequenceStartPositions);
CHECK_GE(static_cast<int>(labelSeqStartPositions.size()), 2); CHECK_GE(static_cast<int>(labelSeqStartPositions.size()), 2);
sequenceStartPositions = sequenceStartPositions =
...@@ -410,6 +409,19 @@ void initDataLayer(TestConfig testConf, ...@@ -410,6 +409,19 @@ void initDataLayer(TestConfig testConf,
useGpu); useGpu);
data.sequenceStartPositions = sequenceStartPositions; data.sequenceStartPositions = sequenceStartPositions;
} }
const std::vector<int>& labelSubSeqStartPositions =
testConf.inputDefs[i].labelSubSeqStartPositions;
if (labelSubSeqStartPositions.size() != 0) {
CHECK_GE(static_cast<int>(labelSubSeqStartPositions.size()), 2);
subSequenceStartPositions =
ICpuGpuVector::create(labelSubSeqStartPositions.size(), useGpu);
subSequenceStartPositions->copyFrom(labelSubSeqStartPositions.data(),
labelSubSeqStartPositions.size(),
useGpu);
data.subSequenceStartPositions = subSequenceStartPositions;
}
break; break;
} }
default: default:
......
...@@ -67,6 +67,7 @@ struct InputDef { ...@@ -67,6 +67,7 @@ struct InputDef {
bool isStatic; bool isStatic;
std::vector<int> labelInitValue; std::vector<int> labelInitValue;
std::vector<int> labelSeqStartPositions; std::vector<int> labelSeqStartPositions;
std::vector<int> labelSubSeqStartPositions;
MatrixPtr selfDefinedData; MatrixPtr selfDefinedData;
InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) { InputDef(InputType type, string nameIn, size_t dimIn, size_t sizeIn) {
...@@ -81,8 +82,10 @@ struct InputDef { ...@@ -81,8 +82,10 @@ struct InputDef {
InputDef(InputType type, InputDef(InputType type,
string nameIn, string nameIn,
MatrixPtr selfDefinedData, MatrixPtr selfDefinedData,
std::vector<int> selfDefinedSeqStartPos = {}) std::vector<int> selfDefinedSeqStartPos = {},
std::vector<int> selfDefinedSubSeqStartPos = {})
: labelSeqStartPositions(selfDefinedSeqStartPos), : labelSeqStartPositions(selfDefinedSeqStartPos),
labelSubSeqStartPositions(selfDefinedSubSeqStartPos),
selfDefinedData(selfDefinedData) { selfDefinedData(selfDefinedData) {
inputType = type; inputType = type;
name = nameIn; name = nameIn;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册