Commit 01bd8dd2 authored by Hong Ming, committed by Tensor Tang

enable conv_winograd, fix conv_gemmlike bug, and update the unit tests of conv op

test=develop
Parent 656e27e8
@@ -30,10 +30,15 @@ class ConvCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
   void Run() override;
 
-  virtual ~ConvCompute() = default;
+  ~ConvCompute() {
+    if (impl_ != nullptr) {
+      delete impl_;
+    }
+  }
 
  private:
-  lite::arm::math::ImplBase<TARGET(kARM), PRECISION(kFloat), param_t>* impl_;
+  lite::arm::math::ImplBase<TARGET(kARM), PRECISION(kFloat), param_t>* impl_{
+      nullptr};
 };
 
 }  // namespace arm
...
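The header change above makes ConvCompute own the dispatched implementation through a raw pointer: impl_ is now value-initialized to nullptr and released in the destructor, instead of the previous defaulted destructor, which would leak whatever implementation object was assigned to impl_. Below is a minimal standalone sketch of that ownership pattern; ConvComputeSketch, WinogradImpl and the parameter-free PrepareForRun are illustrative stand-ins, not the actual Paddle-Lite classes or signatures.

#include <cstdio>

// Stand-in for lite::arm::math::ImplBase (template parameters dropped for
// brevity; this is an illustration, not the real interface).
struct ImplBase {
  virtual void Run() = 0;
  virtual ~ImplBase() = default;  // virtual dtor so deleting via base is safe
};

struct WinogradImpl : public ImplBase {
  void Run() override { std::printf("running winograd conv\n"); }
};

// Same ownership pattern as the patched ConvCompute: the raw pointer is
// value-initialized to nullptr and released in the destructor.
class ConvComputeSketch {
 public:
  void PrepareForRun() {
    if (impl_ == nullptr) {
      impl_ = new WinogradImpl();  // in the real kernel the choice depends
    }                              // on shape/stride/padding, etc.
  }
  void Run() {
    if (impl_ != nullptr) {
      impl_->Run();
    }
  }
  ~ConvComputeSketch() {
    // delete on a null pointer is a no-op; the guard simply mirrors the patch.
    if (impl_ != nullptr) {
      delete impl_;
    }
  }

 private:
  ImplBase* impl_{nullptr};
};

int main() {
  ConvComputeSketch conv;
  conv.PrepareForRun();
  conv.Run();
  return 0;  // impl_ is freed by ~ConvComputeSketch here
}

A std::unique_ptr<ImplBase> member would express the same ownership without a hand-written destructor; the manual delete in the sketch only follows what the patch itself does.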
@@ -123,11 +123,6 @@ TEST(conv_arm, init) {
 }
 
 TEST(conv_arm, compute) {
-  lite::Tensor input;
-  lite::Tensor filter;
-  lite::Tensor bias;
-  lite::Tensor output;
-  lite::Tensor output_ref;
   DeviceInfo::Init();
   for (auto n : {1, 2}) {
     for (auto ic : {6, 32 /*, 128*/}) {
@@ -149,17 +144,26 @@ TEST(conv_arm, compute) {
           std::vector<int64_t> input_shape = {n, ic, ih, iw};
           std::vector<int64_t> filter_shape = {oc, ic / group,
                                                ks, ks};
-          std::vector<int64_t> output_shape({n, oc});
-          const int dkernel = dilation * (ks - 1) + 1;
-          output_shape.push_back(
-              (ih + 2 * padding - dkernel) / stride + 1);
-          output_shape.push_back(
-              (iw + 2 * padding - dkernel) / stride + 1);
+          const int dks = dilation * (ks - 1) + 1;
+          int oh = (ih + 2 * padding - dks) / stride + 1;
+          int ow = (iw + 2 * padding - dks) / stride + 1;
+          std::vector<int64_t> output_shape({n, oc, oh, ow});
           // resize input, filter and output
+          Tensor input;
+          Tensor filter;
+          Tensor bias;
+          Tensor output;
+          Tensor output_ref;
           input.Resize(input_shape);
           filter.Resize(filter_shape);
           output.Resize(output_shape);
           output_ref.Resize(output_shape);
+          LOG(INFO) << "input: " << input.dims();
+          LOG(INFO) << "filter: " << filter.dims()
+                    << " padding:" << padding
+                    << " stride:" << stride
+                    << " dilation:" << dilation;
+          LOG(INFO) << "output: " << output.dims();
           auto* input_data = input.mutable_data<float>();
           auto* filter_data = filter.mutable_data<float>();
           auto* output_data = output.mutable_data<float>();
...
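The updated test computes the output spatial size explicitly with the usual dilated-convolution formula, oh = (ih + 2 * padding - (dilation * (ks - 1) + 1)) / stride + 1 (and likewise for ow), and moves the tensors inside the loop so every parameter combination starts from freshly sized buffers. A small self-contained check of that size arithmetic follows; the example values are arbitrary and not taken from the test.

#include <cstdio>

// Output size of a convolution along one dimension, matching the expression
// used in the updated test: effective kernel = dilation * (ks - 1) + 1.
int ConvOutSize(int in, int ks, int stride, int padding, int dilation) {
  const int dks = dilation * (ks - 1) + 1;
  return (in + 2 * padding - dks) / stride + 1;
}

int main() {
  // 28-wide input, 3x3 kernel, stride 1, padding 1, dilation 1:
  // effective kernel 3 -> output 28 ("same" padding).
  std::printf("oh=%d\n", ConvOutSize(28, 3, 1, 1, 1));  // prints oh=28
  // With dilation 2 the effective kernel is 2*(3-1)+1 = 5 -> output 26.
  std::printf("oh=%d\n", ConvOutSize(28, 3, 1, 1, 2));  // prints oh=26
  return 0;
}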