You need to sign in or sign up before continuing.
Enrich ConvShift to support sequence data input
Created by: pkuyym
The ConvShift layer implements the circular convolution operation, which is a necessary step for location-based addressing in the Neural Turing Machine (NTM). NTM was originally designed to process sequence data; however, the current ConvShift layer only supports non-sequence input.
The original implementation calls the circularConv and circularConvDerivative functions implemented in class Matrix.
/**
 * Forward pass of the (non-sequence) ConvShift layer.
 *
 * Reads the two input value matrices, resets the output to the height and
 * width of input 0, and then lets the output matrix compute
 * circularConv(input0, input1).
 */
void ConvShiftLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr firstIn = getInputValue(0);
  MatrixPtr secondIn = getInputValue(1);

  // Both inputs must have the same number of rows, and the width of input 0
  // must match the configured layer size.
  const size_t numRows = firstIn->getHeight();
  const size_t numCols = firstIn->getWidth();
  CHECK_EQ(numRows, secondIn->getHeight());
  CHECK_EQ(numCols, getSize());

  {
    // Braces limit the scope of the reset timer to the resetOutput call.
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(numRows, numCols);
  }

  MatrixPtr result = getOutputValue();
  REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());
  result->circularConv(*firstIn, *secondIn);
}
/**
 * Backward pass of the (non-sequence) ConvShift layer.
 *
 * When both input gradients are present, delegates to
 * Matrix::circularConvDerivative, passing the output gradient, both input
 * values and both input gradient matrices. When neither input gradient is
 * present there is nothing to do. Having exactly one input gradient is not
 * supported and aborts with a CHECK failure.
 *
 * @param callback unused by this layer.
 */
void ConvShiftLayer::backward(const UpdateCallback& callback) {
  MatrixPtr inV0 = getInputValue(0);
  MatrixPtr inV1 = getInputValue(1);
  MatrixPtr outG = getOutputGrad();
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);

  REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str());

  if (inG0 && inG1) {
    outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1);
  } else {
    // This branch is only reached when at least one gradient is missing, so
    // checking `!inG0 || !inG1` here would always pass. Require that BOTH are
    // missing; a single missing gradient is the unsupported case.
    CHECK(!inG0 && !inG1) << "Not supported";
  }
}
First, I will check the type of the input data. If the input is of non-sequence type, the function runs the original logic; otherwise it calls circularConvSeq and circularConvSeqDerivative.
/**
 * Forward pass of the ConvShift layer with sequence support.
 *
 * Resets the output to the height and width of input 0, then either calls
 * circularConvSeq() (when isSeqType() is true) or runs the plain
 * Matrix::circularConv on the two input value matrices.
 */
void ConvShiftLayer::forward(PassType passType) {
  Layer::forward(passType);

  MatrixPtr firstIn = getInputValue(0);
  const size_t numRows = firstIn->getHeight();
  const size_t numCols = firstIn->getWidth();
  // The width of input 0 must match the configured layer size.
  CHECK_EQ(numCols, getSize());

  {
    // Braces limit the scope of the reset timer to the resetOutput call.
    REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
    resetOutput(numRows, numCols);
  }

  REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());

  if (isSeqType()) {
    // Sequence input is handled by the dedicated helper.
    circularConvSeq();
    return;
  }

  // Non-sequence input: both inputs must have the same number of rows.
  MatrixPtr secondIn = getInputValue(1);
  CHECK_EQ(numRows, secondIn->getHeight());
  MatrixPtr result = getOutputValue();
  result->circularConv(*firstIn, *secondIn);
}
/**
 * Backward pass of the ConvShift layer with sequence support.
 *
 * - Neither input gradient present: nothing to propagate, returns at once.
 * - Exactly one input gradient present: not supported, aborts via CHECK.
 * - Both present: calls Matrix::circularConvDerivative for non-sequence
 *   input, or circularConvSeqDerivative() when isSeqType() is true.
 *
 * @param callback unused by this layer.
 */
void ConvShiftLayer::backward(const UpdateCallback& callback) {
  MatrixPtr inG0 = getInputGrad(0);
  MatrixPtr inG1 = getInputGrad(1);

  REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str());

  // No input needs a gradient: skip the computation instead of dereferencing
  // null gradient matrices below.
  if (!inG0 && !inG1) {
    return;
  }

  // The previous `CHECK(!inG0 || !inG1)` sat inside `if (!(inG0 && inG1))`
  // and therefore could never fail. From here on both gradients are
  // required, so a single missing gradient is rejected explicitly.
  CHECK(inG0 && inG1) << "Not supported";

  if (!isSeqType()) {
    MatrixPtr inV0 = getInputValue(0);
    MatrixPtr inV1 = getInputValue(1);
    MatrixPtr outG = getOutputGrad();
    outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1);
  } else {
    circularConvSeqDerivative();
  }
}
Please go to related PR #2133 to check implementation details.