Commit 3133c09f authored by Tao Luo, committed by GitHub

Merge pull request #1768 from luotao1/pad

fix PadOp bug on Gpu
@@ -38,7 +38,7 @@ public:
     if (err) {
       *err = Error(e.what());
     } else {
-      LOG(FATAL) << "Cannot get key " << key << "with error " << e.what();
+      LOG(FATAL) << "Cannot get key " << key << " with error " << e.what();
     }
     return T();
   }
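Besides the message-spacing fix, the hunk above shows the dual error path used here: if the caller passed an `Error*`, the exception message is stored in it; otherwise the failure is fatal. A minimal standalone sketch of that pattern — the `Error` struct and `getValue` helper below are illustrative stand-ins, not PaddlePaddle's actual types:

```cpp
#include <cstdio>
#include <cstdlib>
#include <stdexcept>
#include <string>

// Illustrative stand-in for the out-parameter error type.
struct Error {
  std::string msg;
};

template <typename T>
T getValue(const std::string& key, Error* err) {
  try {
    throw std::runtime_error("key not found");  // simulate a failed lookup
  } catch (const std::exception& e) {
    if (err) {
      err->msg = e.what();  // report to the caller and return a default T
    } else {
      // No out-parameter to fill: log and abort, as LOG(FATAL) does above.
      std::fprintf(stderr, "Cannot get key %s with error %s\n", key.c_str(),
                   e.what());
      std::abort();
    }
    return T();
  }
}

int main() {
  Error err;
  int v = getValue<int>("cstart", &err);
  std::printf("value=%d, err=%s\n", v, err.msg.c_str());
  return 0;
}
```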
@@ -44,9 +44,9 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs,
   size_t nth = num * inC * inH * inW;
   int blockSize = 1024;
   int gridSize = (nth + 1024 - 1) / 1024;
-  int cstart = pad.channelStart, cend = pad.channelEnd;
-  int hstart = pad.heightStart, hend = pad.heightEnd;
-  int wstart = pad.widthStart, wend = pad.widthEnd;
+  int cstart = pad.channel[0], cend = pad.channel[1];
+  int hstart = pad.height[0], hend = pad.height[1];
+  int wstart = pad.width[0], wend = pad.width[1];
   int outC = inC + cstart + cend;
   int outH = inH + hstart + hend;
   int outW = inW + wstart + wend;
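Two bits of arithmetic in this hunk are worth spelling out: `gridSize = (nth + 1024 - 1) / 1024` is the standard ceiling division that rounds `nth / blockSize` up so every element gets a CUDA thread, and each output extent grows by the pad applied before and after that dimension. A quick worked check in plain C++; the sample sizes are made up for illustration:

```cpp
#include <cstdio>

int main() {
  // Ceiling division for the launch configuration: with nth = 125 and
  // blockSize = 1024, (125 + 1023) / 1024 = 1 block covers all elements.
  int num = 5, inC = 1, inH = 5, inW = 5;
  int nth = num * inC * inH * inW;  // 125
  int blockSize = 1024;
  int gridSize = (nth + blockSize - 1) / blockSize;

  // Output extents: pad before + extent + pad after, per dimension,
  // e.g. outC = inC + pad.channel[0] + pad.channel[1] = 1 + 2 + 3 = 6.
  int cstart = 2, cend = 3;
  int outC = inC + cstart + cend;

  std::printf("nth=%d gridSize=%d outC=%d\n", nth, gridSize, outC);
  return 0;
}
```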
@@ -83,9 +83,9 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
   int nth = num * inC * inH * inW;
   int blockSize = 1024;
   int gridSize = (nth + 1024 - 1) / 1024;
-  int cstart = pad.channelStart, cend = pad.channelEnd;
-  int hstart = pad.heightStart, hend = pad.heightEnd;
-  int wstart = pad.widthStart, wend = pad.widthEnd;
+  int cstart = pad.channel[0], cend = pad.channel[1];
+  int hstart = pad.height[0], hend = pad.height[1];
+  int wstart = pad.width[0], wend = pad.width[1];
   int outC = inC + cstart + cend;
   int outH = inH + hstart + hend;
   int outW = inW + wstart + wend;
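Both kernels now read the pad extents as `{before, after}` pairs (`pad.channel[0]` / `pad.channel[1]`, and likewise for height and width) instead of the former `channelStart`/`channelEnd` style scalar fields. A sketch of what the updated descriptor could look like — the struct layout below is an assumption for illustration; only the accessor pattern comes from the diff:

```cpp
#include <cstdint>
#include <cstdio>

// Assumed shape of the pad descriptor: each dimension carries a two-element
// {before, after} array, matching the pad.channel[0]/pad.channel[1] accesses.
struct PadConf {
  uint32_t channel[2];
  uint32_t height[2];
  uint32_t width[2];
};

int main() {
  // Values mirror the test config below: channel {2,3}, height {1,2},
  // width {3,2}.
  PadConf pad = {{2, 3}, {1, 2}, {3, 2}};
  int cstart = pad.channel[0], cend = pad.channel[1];
  std::printf("cstart=%d cend=%d\n", cstart, cend);
  return 0;
}
```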
@@ -24,51 +24,25 @@ TEST(Pad, real) {
         for (size_t imgSizeW : {5, 32, 96}) {
           VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
                   << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
-          FunctionCompare compare("Pad",
+          for (bool test_grad : {false, true}) {
+            FunctionCompare compare(
+                test_grad ? "PadGrad" : "Pad",
                 FuncConfig()
-                    .set("cstart", 2)
-                    .set("cend", 3)
-                    .set("hstart", 1)
-                    .set("hend", 2)
-                    .set("wstart", 3)
-                    .set("wend", 2));
+                    .set<std::vector<uint32_t>>("channel", {2, 3})
+                    .set<std::vector<uint32_t>>("height", {1, 2})
+                    .set<std::vector<uint32_t>>("width", {3, 2}));
             TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
             TensorShape outDims{
                 numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
-            compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, inDims));
-            compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, outDims, ASSIGN_TO));
+            compare.addInputs(
+                BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
+            compare.addOutputs(BufferArg(
+                VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
             compare.run();
+          }
         }
       }
     }
   }
 }
-TEST(PadGrad, real) {
-  for (size_t numSamples : {5, 32}) {
-    for (size_t channels : {1, 5, 32}) {
-      for (size_t imgSizeH : {5, 33, 100}) {
-        for (size_t imgSizeW : {5, 32, 96}) {
-          VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
-                  << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
-          FunctionCompare compare("PadGrad",
-                                  FuncConfig()
-                                      .set("cstart", 2)
-                                      .set("cend", 3)
-                                      .set("hstart", 1)
-                                      .set("hend", 2)
-                                      .set("wstart", 3)
-                                      .set("wend", 2));
-          TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
-          TensorShape outDims{
-              numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
-          compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, outDims));
-          compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, inDims, ASSIGN_TO));
-          compare.run();
-        }
-      }
-    }
-  }
-}
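The rewritten test folds the old `TEST(PadGrad, real)` into a `test_grad` loop: the forward `Pad` maps the small tensor to the padded one, while `PadGrad` runs the same comparison with the shapes swapped. A small sketch of that shape swap, using the dimensions from the test (`std::array` stands in for `TensorShape` here):

```cpp
#include <array>
#include <cstdio>

using Shape = std::array<int, 4>;  // {num, channels, height, width}

int main() {
  // Pads from the FuncConfig above: channel {2,3}, height {1,2}, width {3,2},
  // so the padded tensor gains +5 channels, +3 rows, and +5 columns.
  Shape inDims = {5, 1, 5, 5};
  Shape outDims = {inDims[0], inDims[1] + 5, inDims[2] + 3, inDims[3] + 5};

  for (bool test_grad : {false, true}) {
    // Pad: small -> padded. PadGrad: padded gradient -> small gradient.
    const Shape& in = test_grad ? outDims : inDims;
    const Shape& out = test_grad ? inDims : outDims;
    std::printf("%s: [%d,%d,%d,%d] -> [%d,%d,%d,%d]\n",
                test_grad ? "PadGrad" : "Pad", in[0], in[1], in[2], in[3],
                out[0], out[1], out[2], out[3]);
  }
  return 0;
}
```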