From 7ad76355af156f046008cd7d2687cfd7be84b965 Mon Sep 17 00:00:00 2001 From: zhaojiaying01 Date: Wed, 24 Oct 2018 19:08:57 +0800 Subject: [PATCH] update conv3x3 and fix CLImageConverterNWBlock --- src/framework/cl/cl_image_converter.cpp | 8 ++++---- src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp | 4 ++-- src/operators/kernel/cl/conv_add_kernel.cpp | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/framework/cl/cl_image_converter.cpp b/src/framework/cl/cl_image_converter.cpp index ad3aec5b86..ebcfd0d675 100644 --- a/src/framework/cl/cl_image_converter.cpp +++ b/src/framework/cl/cl_image_converter.cpp @@ -248,8 +248,8 @@ void CLImageConverterNWBlock::NCHWToImage(float *tensor, half_t *image, for (int c = 0; c < C; c++) { for (int h = 0; h < H; ++h) { for (int w = 0; w < W; ++w) { - int index = 4 * c * (width * H) + 4 * (n / 4) * H * W + h * W * 4 + - w * 4 + (n % 4); + int index = 4 * c * (width * H) + 4 * h * width + 4 * W * (n / 4) + + w * 4 + n % 4; if (n < N) { image[index] = Float2Half(*p); p++; @@ -283,8 +283,8 @@ void CLImageConverterNWBlock::ImageToNCHW(half_t *image, float *tensor, for (int c = 0; c < C; c++) { for (int h = 0; h < H; ++h) { for (int w = 0; w < W; ++w) { - int index = 4 * c * (width * H) + 4 * (n / 4) * H * W + h * W * 4 + - w * 4 + (n % 4); + int index = 4 * c * (width * H) + 4 * h * width + 4 * W * (n / 4) + + w * 4 + n % 4; *p = Half2Float(image[index]); p++; if (index >= (width * height * 4)) { diff --git a/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp b/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp index 29f1b15199..c990d64f67 100644 --- a/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_bn_relu_kernel.cpp @@ -150,8 +150,8 @@ bool ConvAddBNReluKernel::Init( } else if (param->Filter()->dims()[2] == 3 && param->Filter()->dims()[3] == 3) { - param->Filter()->InitCLImage(cl_helper_.CLContext(), - cl_helper_.CLCommandQueue()); + param->Filter()->InitNImage(cl_helper_.CLContext(), + cl_helper_.CLCommandQueue()); this->cl_helper_.AddKernel("conv_3x3", "conv_add_bn_relu_kernel.cl"); DLOG << " conv add bn relu conv_3x3"; diff --git a/src/operators/kernel/cl/conv_add_kernel.cpp b/src/operators/kernel/cl/conv_add_kernel.cpp index d8064fd50e..2a21391c5e 100644 --- a/src/operators/kernel/cl/conv_add_kernel.cpp +++ b/src/operators/kernel/cl/conv_add_kernel.cpp @@ -45,7 +45,7 @@ bool ConvAddKernel::Init(FusionConvAddParam *param) { } else if (param->Filter()->dims()[2] == 3 && param->Filter()->dims()[3] == 3) { - param->Filter()->InitCLImage(cl_helper_.CLContext(), + param->Filter()->InitNImage(cl_helper_.CLContext(), cl_helper_.CLCommandQueue()); this->cl_helper_.AddKernel("conv_3x3", "conv_add_kernel.cl"); -- GitLab