提交 2357aee3 编写于 作者: Y Yi Wang 提交者: GitHub

Merge pull request #3250 from lcy-seso/seq_softmax_in_sub_seq_mode

add gradient check for sequence softmax activation.
......@@ -186,7 +186,10 @@ Error __must_check forward(Argument& act) {
useGpu(act.deviceId));
}
auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
auto starts =
act.hasSubseq()
? act.subSequenceStartPositions->getVector(useGpu(act.deviceId))
: act.sequenceStartPositions->getVector(useGpu(act.deviceId));
act.value->sequenceSoftmax(*act.value, *starts);
return Error();
}
......@@ -197,8 +200,9 @@ Error __must_check backward(Argument& act) {
"Input width for each timestep of sequence softmax should be 1");
}
size_t numSequences = act.getNumSequences();
const int* starts = act.sequenceStartPositions->getData(false);
size_t numSequences =
act.hasSubseq() ? act.getNumSubSequences() : act.getNumSequences();
const int* starts = act.getCpuStartPositions();
for (size_t i = 0; i < numSequences; ++i) {
// TODO(Dangqingqing) optimization for GPU
......
......@@ -57,6 +57,39 @@ TEST(Activation, activation) {
}
}
void testSequenceSoftmaxAct(bool hasSubseq) {
LOG(INFO) << "test activation: sequence softmax";
const size_t size = 1;
TestConfig config;
config.biasSize = 0;
config.layerConfig.set_type("addto");
config.layerConfig.set_size(size);
config.layerConfig.set_active_type("sequence_softmax");
config.inputDefs.push_back(
{hasSubseq ? INPUT_HASSUB_SEQUENCE_DATA : INPUT_SEQUENCE_DATA,
"layer_0",
1,
0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config,
"sequence_softmax",
100,
/* trans= */ false,
useGpu,
/* useWeight */ true);
}
}
TEST(SequenceSoftmaxActivation, activation) {
for (auto hasSubseq : {false, true}) {
LOG(INFO) << "hasSubseq = " << hasSubseq;
testSequenceSoftmaxAct(hasSubseq);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册