提交 2c2947f8 编写于 作者: D dangqingqing

Fix a bug for TransLayer and add unit testing.

上级 fb05a731
......@@ -56,7 +56,14 @@ void TransLayer::backward(const UpdateCallback& callback) {
return;
}
MatrixPtr preGrad = getInputGrad(0);
outputGrad->transpose(preGrad, false);
if (preGrad) {
MatrixPtr transGrad = Matrix::create(preGrad->getHeight(),
preGrad->getWidth(),
/* trans= */ false,
preGrad->useGpu());
outputGrad->transpose(transGrad, false);
preGrad->add(*transGrad);
}
}
} // namespace paddle
......@@ -1689,6 +1689,22 @@ TEST(Layer, smooth_l1) {
}
}
TEST(Layer, TransLayer) {
TestConfig config;
const int height = 128;
const int width = 1028;
config.layerConfig.set_type("trans");
config.layerConfig.set_size(width);
config.inputDefs.push_back(
{INPUT_DATA, "layer_0", /* dim= */ height * width, /* paraSize= */ 0});
config.layerConfig.add_inputs();
for (auto useGpu : {false, true}) {
testLayerGrad(config, "trans", height, /* trans= */ false, useGpu);
}
}
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册