提交 c70d3e1a 编写于 作者: H hedaoyuan

Some bug fix

上级 3c0aa0cc
......@@ -96,7 +96,7 @@ public:
size_t inputHeight = inputs[0].shape()[2];
size_t inputWidth = inputs[0].shape()[3];
size_t filterHeight = inputs[1].shape()[2];
size_t filterWidth = inputs[1].shape()[2];
size_t filterWidth = inputs[1].shape()[3];
size_t outputChannels = outputs[0].shape()[1];
size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3];
......@@ -148,23 +148,29 @@ public:
0.0f,
outputData + g * outputOffset,
N);
}
inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth;
}
}
}
void resizeBuffer(size_t newSize) {
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
if (Device == DEVICE_TYPE_CPU) {
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
} else {
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
}
}
}
private:
CpuMemHandlePtr memory_;
MemoryHandlePtr memory_;
};
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
#endif
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册