diff --git a/paddle/function/ConvOp.h b/paddle/function/ConvOp.h index 4d678cfe2735c07d37c92c7c4dd2d6a3151b0955..465db57ae7d82049d30973e643a12c27c39ec304 100644 --- a/paddle/function/ConvOp.h +++ b/paddle/function/ConvOp.h @@ -36,8 +36,8 @@ class ConvFunctionBase : public FunctionBase { public: void init(const FuncConfig& config) override { // function arguments - stride_ = config.get("stride"); - padding_ = config.get("padding"); + strides_ = config.get>("strides"); + paddings_ = config.get>("paddings"); // number of inputs and outputs numInputs_ = 2; @@ -60,8 +60,15 @@ public: } protected: - size_t padding_; - size_t stride_; + std::vector strides_; + std::vector paddings_; + inline int strideH() const { return strides_[0]; } + + inline int strideW() const { return strides_[1]; } + + inline int paddingH() const { return paddings_[0]; } + + inline int paddingW() const { return paddings_[1]; } }; } // namespace paddle diff --git a/paddle/function/ConvOpTest.cpp b/paddle/function/ConvOpTest.cpp index 715fa58b5597c96e550d2e20e82e642ebf9ede8b..db8d9fa9da4609248078598257346245d8b92be9 100644 --- a/paddle/function/ConvOpTest.cpp +++ b/paddle/function/ConvOpTest.cpp @@ -48,11 +48,13 @@ public: << " outputWidth=" << outputSize << " stride=" << stride << " padding=" << padding; + std::vector paddings = {padding, padding}; + std::vector strides = {stride, stride}; Compare2CpuFunction test(conv1, conv2, FuncConfig() - .set("padding", padding) - .set("stride", stride) + .set("paddings", paddings) + .set("strides", strides) .set("algo", algo)); TensorShape shape0{ diff --git a/paddle/function/GemmConvOp.cpp b/paddle/function/GemmConvOp.cpp index 5e6ee24486213d58905f02a1a508ddd2f192a613..42786e44e0e97a315bb5f71b9d3d389d9f743f85 100644 --- a/paddle/function/GemmConvOp.cpp +++ b/paddle/function/GemmConvOp.cpp @@ -118,10 +118,10 @@ public: inputWidth, filterHeight, filterWidth, - stride_, - stride_, - padding_, - padding_, + strideH(), + strideW(), + paddingH(), + paddingW(), outputHeight, outputWidth, colData); diff --git a/paddle/function/NaiveConvOp.cpp b/paddle/function/NaiveConvOp.cpp index f13aa880a1e88b30d271554c6d53db333779a017..f5d2aa16ab9b8fdedf6320df52bdeae24ca73eea 100644 --- a/paddle/function/NaiveConvOp.cpp +++ b/paddle/function/NaiveConvOp.cpp @@ -37,14 +37,16 @@ public: size_t outputChannels, size_t outputHeight, size_t outputWidth, - size_t padding, - size_t stride) { + size_t paddingH, + size_t paddingW, + size_t strideH, + size_t strideW) { for (size_t batch = 0; batch < batchSize; batch++) { for (size_t outC = 0; outC < outputChannels; outC++) { for (size_t outH = 0; outH < outputHeight; outH++) { for (size_t outW = 0; outW < outputWidth; outW++) { - const int inStartH = (outH * stride) - padding; - const int inStartW = (outW * stride) - padding; + const int inStartH = (outH * strideH) - paddingH; + const int inStartW = (outW * strideW) - paddingW; T outValue = (T)0; for (size_t inC = 0; inC < inputChannels; inC++) { for (size_t fH = 0; fH < filterHeight; fH++) { @@ -118,8 +120,10 @@ public: outputChannels, outputHeight, outputWidth, - padding_, - stride_); + paddingH(), + paddingW(), + strideH(), + strideW()); } };