提交 bc2b521c 编写于 作者: L liaogang

Follow comments

上级 7f6e9aca
...@@ -43,22 +43,14 @@ TEST(Layer, BilinearInterpLayer) { ...@@ -43,22 +43,14 @@ TEST(Layer, BilinearInterpLayer) {
bilinear->set_img_size_x(32); bilinear->set_img_size_x(32);
bilinear->set_img_size_y(32); bilinear->set_img_size_y(32);
bilinear->set_out_size_x(64);
bilinear->set_out_size_y(64);
bilinear->set_num_channels(4); bilinear->set_num_channels(4);
for (auto useGpu : {false, true}) { for (auto useGpu : {false, true}) {
testLayerGrad(config, "bilinear_interp", 10, false, useGpu); for (auto out_size : {32, 64, 128}) {
} bilinear->set_out_size_x(out_size);
bilinear->set_out_size_y(out_size);
bilinear->set_img_size_x(32); testLayerGrad(config, "bilinear_interp", 10, false, useGpu);
bilinear->set_img_size_y(32); }
bilinear->set_out_size_x(32);
bilinear->set_out_size_y(32);
bilinear->set_num_channels(4);
for (auto useGpu : {false, true}) {
testLayerGrad(config, "bilinear_interp", 10, false, useGpu);
} }
} }
......
...@@ -3902,6 +3902,8 @@ void CpuMatrix::bilinearForward(const Matrix& in, ...@@ -3902,6 +3902,8 @@ void CpuMatrix::bilinearForward(const Matrix& in,
size_t batchSize = getHeight(); size_t batchSize = getHeight();
size_t inputW = in.getWidth(); size_t inputW = in.getWidth();
size_t inputH = in.getHeight(); size_t inputH = in.getHeight();
size_t inPosOffset = inImgH * inImgW;
size_t outPosOffset = outImgH * outImgW;
(void)(inputH); (void)(inputH);
real* outData = getData(); real* outData = getData();
...@@ -3931,8 +3933,8 @@ void CpuMatrix::bilinearForward(const Matrix& in, ...@@ -3931,8 +3933,8 @@ void CpuMatrix::bilinearForward(const Matrix& in,
h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wid]) + h2lambda * (w2lambda * inPos[0] + w1lambda * inPos[wid]) +
h1lambda * (w2lambda * inPos[hid * inImgW] + h1lambda * (w2lambda * inPos[hid * inImgW] +
w1lambda * inPos[hid * inImgW + wid]); w1lambda * inPos[hid * inImgW + wid]);
inPos += inImgH * inImgW; inPos += inPosOffset;
outPos += outImgH * outImgW; outPos += outPosOffset;
} }
} }
} }
...@@ -3954,6 +3956,8 @@ void CpuMatrix::bilinearBackward(const Matrix& out, ...@@ -3954,6 +3956,8 @@ void CpuMatrix::bilinearBackward(const Matrix& out,
size_t inputH = getHeight(); size_t inputH = getHeight();
size_t outputW = out.getWidth(); size_t outputW = out.getWidth();
size_t batchSize = out.getHeight(); size_t batchSize = out.getHeight();
size_t inPosOffset = inImgH * inImgW;
size_t outPosOffset = outImgH * outImgW;
(void)(inputH); (void)(inputH);
real* inGrad = getData(); real* inGrad = getData();
...@@ -3981,8 +3985,8 @@ void CpuMatrix::bilinearBackward(const Matrix& out, ...@@ -3981,8 +3985,8 @@ void CpuMatrix::bilinearBackward(const Matrix& out,
inPos[wid] += h2lambda * w1lambda * outPos[0]; inPos[wid] += h2lambda * w1lambda * outPos[0];
inPos[hid * inImgW] += h1lambda * w2lambda * outPos[0]; inPos[hid * inImgW] += h1lambda * w2lambda * outPos[0];
inPos[hid * inImgW + wid] += h1lambda * w1lambda * outPos[0]; inPos[hid * inImgW + wid] += h1lambda * w1lambda * outPos[0];
inPos += inImgH * inImgW; inPos += inPosOffset;
outPos += outImgH * outImgW; outPos += outPosOffset;
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册