提交 048b14a9 编写于 作者: H hedaoyuan

Change stride to strides, and change padding to paddings.

上级 455888c5
...@@ -36,8 +36,8 @@ class ConvFunctionBase : public FunctionBase { ...@@ -36,8 +36,8 @@ class ConvFunctionBase : public FunctionBase {
public: public:
void init(const FuncConfig& config) override { void init(const FuncConfig& config) override {
// function arguments // function arguments
stride_ = config.get<size_t>("stride"); strides_ = config.get<std::vector<size_t>>("strides");
padding_ = config.get<size_t>("padding"); paddings_ = config.get<std::vector<size_t>>("paddings");
// number of inputs and outputs // number of inputs and outputs
numInputs_ = 2; numInputs_ = 2;
...@@ -60,8 +60,15 @@ public: ...@@ -60,8 +60,15 @@ public:
} }
protected: protected:
size_t padding_; std::vector<size_t> strides_;
size_t stride_; std::vector<size_t> paddings_;
inline int strideH() const { return strides_[0]; }
inline int strideW() const { return strides_[1]; }
inline int paddingH() const { return paddings_[0]; }
inline int paddingW() const { return paddings_[1]; }
}; };
} // namespace paddle } // namespace paddle
...@@ -48,11 +48,13 @@ public: ...@@ -48,11 +48,13 @@ public:
<< " outputWidth=" << outputSize << " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding; << " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
Compare2CpuFunction test(conv1, Compare2CpuFunction test(conv1,
conv2, conv2,
FuncConfig() FuncConfig()
.set("padding", padding) .set("paddings", paddings)
.set("stride", stride) .set("strides", strides)
.set("algo", algo)); .set("algo", algo));
TensorShape shape0{ TensorShape shape0{
......
...@@ -118,10 +118,10 @@ public: ...@@ -118,10 +118,10 @@ public:
inputWidth, inputWidth,
filterHeight, filterHeight,
filterWidth, filterWidth,
stride_, strideH(),
stride_, strideW(),
padding_, paddingH(),
padding_, paddingW(),
outputHeight, outputHeight,
outputWidth, outputWidth,
colData); colData);
......
...@@ -37,14 +37,16 @@ public: ...@@ -37,14 +37,16 @@ public:
size_t outputChannels, size_t outputChannels,
size_t outputHeight, size_t outputHeight,
size_t outputWidth, size_t outputWidth,
size_t padding, size_t paddingH,
size_t stride) { size_t paddingW,
size_t strideH,
size_t strideW) {
for (size_t batch = 0; batch < batchSize; batch++) { for (size_t batch = 0; batch < batchSize; batch++) {
for (size_t outC = 0; outC < outputChannels; outC++) { for (size_t outC = 0; outC < outputChannels; outC++) {
for (size_t outH = 0; outH < outputHeight; outH++) { for (size_t outH = 0; outH < outputHeight; outH++) {
for (size_t outW = 0; outW < outputWidth; outW++) { for (size_t outW = 0; outW < outputWidth; outW++) {
const int inStartH = (outH * stride) - padding; const int inStartH = (outH * strideH) - paddingH;
const int inStartW = (outW * stride) - padding; const int inStartW = (outW * strideW) - paddingW;
T outValue = (T)0; T outValue = (T)0;
for (size_t inC = 0; inC < inputChannels; inC++) { for (size_t inC = 0; inC < inputChannels; inC++) {
for (size_t fH = 0; fH < filterHeight; fH++) { for (size_t fH = 0; fH < filterHeight; fH++) {
...@@ -118,8 +120,10 @@ public: ...@@ -118,8 +120,10 @@ public:
outputChannels, outputChannels,
outputHeight, outputHeight,
outputWidth, outputWidth,
padding_, paddingH(),
stride_); paddingW(),
strideH(),
strideW());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册