提交 a7827593 编写于 作者: L Luo Tao

fix PadOp bug on Gpu

上级 d324ed7f
...@@ -38,7 +38,7 @@ public: ...@@ -38,7 +38,7 @@ public:
if (err) { if (err) {
*err = Error(e.what()); *err = Error(e.what());
} else { } else {
LOG(FATAL) << "Cannot get key " << key << "with error " << e.what(); LOG(FATAL) << "Cannot get key " << key << " with error " << e.what();
} }
return T(); return T();
} }
......
...@@ -44,9 +44,9 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs, ...@@ -44,9 +44,9 @@ void Pad<DEVICE_TYPE_GPU>(real* outputs,
size_t nth = num * inC * inH * inW; size_t nth = num * inC * inH * inW;
int blockSize = 1024; int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024; int gridSize = (nth + 1024 - 1) / 1024;
int cstart = pad.channelStart, cend = pad.channelEnd; int cstart = pad.channel[0], cend = pad.channel[1];
int hstart = pad.heightStart, hend = pad.heightEnd; int hstart = pad.height[0], hend = pad.height[1];
int wstart = pad.widthStart, wend = pad.widthEnd; int wstart = pad.width[0], wend = pad.width[1];
int outC = inC + cstart + cend; int outC = inC + cstart + cend;
int outH = inH + hstart + hend; int outH = inH + hstart + hend;
int outW = inW + wstart + wend; int outW = inW + wstart + wend;
...@@ -83,9 +83,9 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad, ...@@ -83,9 +83,9 @@ void PadGrad<DEVICE_TYPE_GPU>(real* inGrad,
int nth = num * inC * inH * inW; int nth = num * inC * inH * inW;
int blockSize = 1024; int blockSize = 1024;
int gridSize = (nth + 1024 - 1) / 1024; int gridSize = (nth + 1024 - 1) / 1024;
int cstart = pad.channelStart, cend = pad.channelEnd; int cstart = pad.channel[0], cend = pad.channel[1];
int hstart = pad.heightStart, hend = pad.heightEnd; int hstart = pad.height[0], hend = pad.height[1];
int wstart = pad.widthStart, wend = pad.widthEnd; int wstart = pad.width[0], wend = pad.width[1];
int outC = inC + cstart + cend; int outC = inC + cstart + cend;
int outH = inH + hstart + hend; int outH = inH + hstart + hend;
int outW = inW + wstart + wend; int outW = inW + wstart + wend;
......
...@@ -24,51 +24,25 @@ TEST(Pad, real) { ...@@ -24,51 +24,25 @@ TEST(Pad, real) {
for (size_t imgSizeW : {5, 32, 96}) { for (size_t imgSizeW : {5, 32, 96}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW; << " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
for (bool test_grad : {false, true}) {
FunctionCompare compare("Pad", FunctionCompare compare(
test_grad ? "PadGrad" : "Pad",
FuncConfig() FuncConfig()
.set("cstart", 2) .set<std::vector<uint32_t>>("channel", {2, 3})
.set("cend", 3) .set<std::vector<uint32_t>>("height", {1, 2})
.set("hstart", 1) .set<std::vector<uint32_t>>("width", {3, 2}));
.set("hend", 2)
.set("wstart", 3)
.set("wend", 2));
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW}; TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
TensorShape outDims{ TensorShape outDims{
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5}; numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, inDims)); compare.addInputs(
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, outDims, ASSIGN_TO)); BufferArg(VALUE_TYPE_FLOAT, test_grad ? outDims : inDims));
compare.addOutputs(BufferArg(
VALUE_TYPE_FLOAT, test_grad ? inDims : outDims, ASSIGN_TO));
compare.run(); compare.run();
} }
} }
} }
} }
}
TEST(PadGrad, real) {
for (size_t numSamples : {5, 32}) {
for (size_t channels : {1, 5, 32}) {
for (size_t imgSizeH : {5, 33, 100}) {
for (size_t imgSizeW : {5, 32, 96}) {
VLOG(3) << " numSamples=" << numSamples << " channels=" << channels
<< " imgSizeH=" << imgSizeH << " imgSizeW=" << imgSizeW;
FunctionCompare compare("PadGrad",
FuncConfig()
.set("cstart", 2)
.set("cend", 3)
.set("hstart", 1)
.set("hend", 2)
.set("wstart", 3)
.set("wend", 2));
TensorShape inDims{numSamples, channels, imgSizeH, imgSizeW};
TensorShape outDims{
numSamples, channels + 5, imgSizeH + 3, imgSizeW + 5};
compare.addInputs(BufferArg(VALUE_TYPE_FLOAT, outDims));
compare.addOutputs(BufferArg(VALUE_TYPE_FLOAT, inDims, ASSIGN_TO));
compare.run();
}
}
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册