提交 c70d3e1a 编写于 作者: H hedaoyuan

Some bug fix

上级 3c0aa0cc
...@@ -96,7 +96,7 @@ public: ...@@ -96,7 +96,7 @@ public:
size_t inputHeight = inputs[0].shape()[2]; size_t inputHeight = inputs[0].shape()[2];
size_t inputWidth = inputs[0].shape()[3]; size_t inputWidth = inputs[0].shape()[3];
size_t filterHeight = inputs[1].shape()[2]; size_t filterHeight = inputs[1].shape()[2];
size_t filterWidth = inputs[1].shape()[2]; size_t filterWidth = inputs[1].shape()[3];
size_t outputChannels = outputs[0].shape()[1]; size_t outputChannels = outputs[0].shape()[1];
size_t outputHeight = outputs[0].shape()[2]; size_t outputHeight = outputs[0].shape()[2];
size_t outputWidth = outputs[0].shape()[3]; size_t outputWidth = outputs[0].shape()[3];
...@@ -148,23 +148,29 @@ public: ...@@ -148,23 +148,29 @@ public:
0.0f, 0.0f,
outputData + g * outputOffset, outputData + g * outputOffset,
N); N);
}
inputData += inputChannels * inputHeight * inputWidth; inputData += inputChannels * inputHeight * inputWidth;
outputData += outputChannels * outputHeight * outputWidth; outputData += outputChannels * outputHeight * outputWidth;
} }
} }
}
void resizeBuffer(size_t newSize) { void resizeBuffer(size_t newSize) {
if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) { if (!memory_ || newSize * sizeof(real) > memory_->getAllocSize()) {
if (Device == DEVICE_TYPE_CPU) {
memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real)); memory_ = std::make_shared<CpuMemoryHandle>(newSize * sizeof(real));
} else {
memory_ = std::make_shared<GpuMemoryHandle>(newSize * sizeof(real));
}
} }
} }
private: private:
CpuMemHandlePtr memory_; MemoryHandlePtr memory_;
}; };
REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, CPU, GemmConvFunction);
#ifndef PADDLE_ONLY_CPU
REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction); REGISTER_TYPED_FUNC(GemmConv, GPU, GemmConvFunction);
#endif
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册