未验证 提交 17b2a7d8 编写于 作者: J Jiaying Zhao 提交者: GitHub

Merge pull request #1577 from smilejames/develop

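Fix GPU conv_kernel: initialize the filter image per kernel branch (InitNImage for 1x1, InitDWImage for depthwise 3x3, InitCLImage for 3x3) and detect depthwise convolution via filter dims()[1] == 1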
......@@ -26,18 +26,6 @@ bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
param->Paddings()[0] == param->Paddings()[1],
"need equal");
auto filter_ddim = param->Filter()->dims();
std::vector<int64_t> filter_shape(
{filter_ddim[1], filter_ddim[0], filter_ddim[2], filter_ddim[3]});
framework::DDim ddim = framework::make_ddim(filter_shape);
if (filter_ddim[1] == 1) {
param->Filter()->Resize(ddim);
}
param->Filter()->InitCLImage(cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
int offset = static_cast<int>(param->Filter()->dims()[2]) / 2 -
static_cast<int>(param->Paddings()[1]);
param->SetOffset(offset);
......@@ -49,19 +37,25 @@ bool ConvKernel<GPU_CL, float>::Init(ConvParam<GPU_CL> *param) {
DLOG << " filter dims: " << param->Filter()->dims();
if (param->Filter()->dims()[2] == 1 && param->Filter()->dims()[3] == 1) {
DLOG << " here1 ";
param->Filter()->InitNImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_1x1", "conv_kernel.cl");
DLOG << "conv 1x1";
} else if (param->Filter()->dims()[0] == 1 &&
} else if (param->Filter()->dims()[1] == 1 &&
param->Input()->dims()[1] == param->Output()->dims()[1] &&
param->Filter()->dims()[2] == 3) {
DLOG << " here2 ";
param->Filter()->InitDWImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("depth_conv_3x3", "depthwise_conv_kernel.cl");
DLOG << "depth_conv 3x3";
} else if (param->Filter()->dims()[2] == 3 &&
param->Filter()->dims()[3] == 3) {
DLOG << " here3 ";
param->Filter()->InitCLImage(cl_helper_.CLContext(),
cl_helper_.CLCommandQueue());
this->cl_helper_.AddKernel("conv_3x3", "conv_kernel.cl");
DLOG << "conv 3x3";
} else {
PADDLE_MOBILE_THROW_EXCEPTION(" not support ");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册