提交 048b14a9 编写于 作者: H hedaoyuan

Change stride to strides, and change padding to paddings.

上级 455888c5
......@@ -36,8 +36,8 @@ class ConvFunctionBase : public FunctionBase {
public:
void init(const FuncConfig& config) override {
// function arguments
stride_ = config.get<size_t>("stride");
padding_ = config.get<size_t>("padding");
strides_ = config.get<std::vector<size_t>>("strides");
paddings_ = config.get<std::vector<size_t>>("paddings");
// number of inputs and outputs
numInputs_ = 2;
......@@ -60,8 +60,15 @@ public:
}
protected:
size_t padding_;
size_t stride_;
std::vector<size_t> strides_;
std::vector<size_t> paddings_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; }
};
} // namespace paddle
......@@ -48,11 +48,13 @@ public:
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2CpuFunction test(conv1,
conv2,
FuncConfig()
.set("padding", padding)
.set("stride", stride)
.set("paddings", paddings)
.set("strides", strides)
.set("algo", algo));
TensorShape shape0{
......
......@@ -118,10 +118,10 @@ public:
inputWidth,
filterHeight,
filterWidth,
stride_,
stride_,
padding_,
padding_,
strideH(),
strideW(),
paddingH(),
paddingW(),
outputHeight,
outputWidth,
colData);
......
......@@ -37,14 +37,16 @@ public:
size_t outputChannels,
size_t outputHeight,
size_t outputWidth,
size_t padding,
size_t stride) {
size_t paddingH,
size_t paddingW,
size_t strideH,
size_t strideW) {
for (size_t batch = 0; batch < batchSize; batch++) {
for (size_t outC = 0; outC < outputChannels; outC++) {
for (size_t outH = 0; outH < outputHeight; outH++) {
for (size_t outW = 0; outW < outputWidth; outW++) {
const int inStartH = (outH * stride) - padding;
const int inStartW = (outW * stride) - padding;
const int inStartH = (outH * strideH) - paddingH;
const int inStartW = (outW * strideW) - paddingW;
T outValue = (T)0;
for (size_t inC = 0; inC < inputChannels; inC++) {
for (size_t fH = 0; fH < filterHeight; fH++) {
......@@ -118,8 +120,10 @@ public:
outputChannels,
outputHeight,
outputWidth,
padding_,
stride_);
paddingH(),
paddingW(),
strideH(),
strideW());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册