提交 cdf8d990 编写于 作者: H hedaoyuan

Bug fix.

上级 2d9113da
......@@ -74,9 +74,9 @@ public:
virtual void check(const BufferArgs& inputs,
const BufferArgs& outputs) override {
const TensorShape& output = inputs[0].shape();
const TensorShape& input = inputs[0].shape();
const TensorShape& filter = inputs[1].shape();
const TensorShape& input = outputs[0].shape();
const TensorShape& output = outputs[0].shape();
checkShape(input, filter, output);
}
......
......@@ -60,12 +60,15 @@ public:
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2Function<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
conv1,
conv2,
FuncConfig()
.set("padding", padding)
.set("stride", stride)
.set("paddings", paddings)
.set("strides", strides)
.set("groups", (size_t)1)
.set("algo", algo));
TensorShape shape0{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册