diff --git a/src/framework/executor.cpp b/src/framework/executor.cpp
index 669ad42469fc9ca4e00a6d8ae11fe3d53b433ca9..28b8b1f652336fa53d05980f8d0458a5d5e21df0 100644
--- a/src/framework/executor.cpp
+++ b/src/framework/executor.cpp
@@ -971,6 +971,23 @@ void Executor<Device, T>::InitCombineMemory() {
           program_.scope->GetCLScpoe()->CommandQueue();
       const TensorDesc &desc = var_desc->Tensor_desc();
       DDim ddim = cl_image->dims();
+      bool shouldResize = true;
+      if (ddim.size() > 4) {
+        for (int i = 0; i < ddim.size() - 4; ++i) {
+          if (ddim[i] != 0) {
+            shouldResize = false;
+            break;
+          }
+        }
+        if (shouldResize) {
+          std::vector<int64_t> temp_input_dims;
+          temp_input_dims.reserve(static_cast<size_t>(4));
+          for (int i = ddim.size() - 4; i < ddim.size(); ++i) {
+            temp_input_dims.push_back(ddim[i]);
+          }
+          ddim = framework::make_ddim(temp_input_dims);
+        }
+      }
       // DDim ddim = make_ddim(desc.Dims());
       cl_image->InitEmptyImage(context, command_queue, ddim);
     }
diff --git a/src/framework/load_ops.h b/src/framework/load_ops.h
index 741b1402b1e6a134bed8dd3a60ab61a756602fd0..f588c33d2f2fb165f0c628b8ae703556aa37ff7c 100644
--- a/src/framework/load_ops.h
+++ b/src/framework/load_ops.h
@@ -103,7 +103,7 @@ LOAD_OP2(fusion_elementwise_add_relu, CPU, FPGA);
 LOAD_FUSION_MATCHER(fusion_elementwise_add_relu);
 #endif
 #ifdef SPLIT_OP
-LOAD_OP1(split, CPU);
+LOAD_OP2(split, CPU, GPU_CL);
 #endif
 #ifdef RESIZE_OP
 LOAD_OP1(resize, CPU);
@@ -116,13 +116,13 @@ LOAD_FUSION_MATCHER(fusion_conv_add_bn_relu);
 LOAD_OP2(reshape, CPU, GPU_CL);
 #endif
 #ifdef RESHAPE2_OP
-LOAD_OP1(reshape2, CPU);
+LOAD_OP2(reshape2, CPU, GPU_CL);
 #endif
 #ifdef TRANSPOSE_OP
 LOAD_OP2(transpose, CPU, GPU_CL);
 #endif
 #ifdef TRANSPOSE2_OP
-LOAD_OP1(transpose2, CPU);
+LOAD_OP2(transpose2, CPU, GPU_CL);
 #endif
 #ifdef PRIORBOX_OP
 LOAD_OP2(prior_box, CPU, GPU_CL);
diff --git a/src/operators/kernel/cl/cl_kernel/scale_kernel.cl b/src/operators/kernel/cl/cl_kernel/scale_kernel.cl
index f5976be7e73b5fdbfd84e83a818c45b2d9e8b285..741eb6dcbc731f943fbcb95f5719c102ba847644 100644
--- a/src/operators/kernel/cl/cl_kernel/scale_kernel.cl
+++ b/src/operators/kernel/cl/cl_kernel/scale_kernel.cl
@@ -18,7 +18,7 @@ __kernel void scale(__read_only image2d_t input,
                     __write_only image2d_t output,
                     __private float scale,
                     __private float bias,
-                    __private float out_width){
+                    __private int out_width){
   const int out_c = get_global_id(0);
   const int out_w = get_global_id(1);
diff --git a/src/operators/kernel/cl/reshape2_kernel.cpp b/src/operators/kernel/cl/reshape2_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..7dbea06a5167cefb3958081a35ffcc3791fb1663
--- /dev/null
+++ b/src/operators/kernel/cl/reshape2_kernel.cpp
@@ -0,0 +1,150 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#ifdef RESHAPE2_OP
+
+#include "operators/kernel/reshape2_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool Reshape2Kernel<GPU_CL, float>::Init(Reshape2Param<GPU_CL> *param) {
+  this->cl_helper_.AddKernel("reshape", "reshape.cl");
+  return true;
+}
+
+inline framework::DDim ValidateShape(const std::vector<int> shape,
+                                     const framework::DDim &in_dims) {
+  const int64_t in_size = framework::product(in_dims);
+  // Only one dimension can be set to -1; its size will be inferred
+  // automatically.
+  const int64_t unk_dim_val = -1;
+  const int64_t copy_dim_val = 0;
+
+  std::vector<int64_t> output_shape(shape.size(), 0);
+  int64_t capacity = 1;
+  int unk_dim_idx = -1;
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if (shape[i] == unk_dim_val) {
+      PADDLE_MOBILE_ENFORCE(
+          unk_dim_idx == -1,
+          "Only one input dimension of Attr(shape) can be unknown.");
+      unk_dim_idx = i;
+    } else if (shape[i] == copy_dim_val) {
+      PADDLE_MOBILE_ENFORCE(
+          static_cast<int>(i) < in_dims.size(),
+          "The index of dimension to copy from input shape must be less "
+          "than the size of input shape.");
+    } else {
+      PADDLE_MOBILE_ENFORCE(
+          shape[i] > 0,
+          "Each input dimension of Attr(shape) must not be negative except "
+          "one unknown dimension.");
+    }
+
+    capacity *= (shape[i] ? shape[i] : in_dims[i]);
+    output_shape[i] = (shape[i] ? static_cast<int64_t>(shape[i]) : in_dims[i]);
+  }
+
+  if (unk_dim_idx != -1) {
+    output_shape[unk_dim_idx] = -in_size / capacity;
+    PADDLE_MOBILE_ENFORCE(output_shape[unk_dim_idx] * capacity == -in_size,
+                          "Invalid shape is given.");
+  } else {
+    PADDLE_MOBILE_ENFORCE(capacity == in_size, "Invalid shape is given.");
+  }
+  return framework::make_ddim(output_shape);
+}
+
+template <>
+void Reshape2Kernel<GPU_CL, float>::Compute(
+    const Reshape2Param<GPU_CL> &param) {
+  auto kernel = this->cl_helper_.KernelAt(0);
+  auto default_work_size = this->cl_helper_.DefaultWorkSize(*param.Out());
+  const auto *input = param.InputX();
+  auto *output = param.Out();
+  auto input_image = input->GetCLImage();
+  auto output_image = output->GetCLImage();
+  const auto &inputDim = input->dims();
+  const auto &outputDim = output->dims();
+  int input_dims[4] = {1, 1, 1, 1};
+  int output_dims[4] = {1, 1, 1, 1};
+  // right-align the input dims into 4-D, e.g. 1 1000 1 1
+  for (int i = 0; i < inputDim.size(); i++) {
+    input_dims[4 - inputDim.size() + i] = inputDim[i];
+  }
+
+  // right-align the output dims into 4-D, e.g. 1 1 1 1000
+  for (int i = 0; i < outputDim.size(); i++) {
+    output_dims[4 - outputDim.size() + i] = outputDim[i];
+  }
+
+  int out_C = output_dims[1];
+  int out_H = output_dims[2];
+  int out_W = output_dims[3];
+  int in_W = input_dims[3];
+  int in_H = input_dims[2];
+  int in_Stride0 = in_W;
+  int in_Stride1 = input_dims[2] * input_dims[3];
+  int in_Stride2 = input_dims[1] * input_dims[2] * input_dims[3];
+  int out_Stride0 = out_W;
+  int out_Stride1 = out_H * out_W;
+  int out_Stride2 = out_C * out_H * out_W;
+  DLOG << "out_C=" << out_C;
+  DLOG << "out_H=" << out_H;
+  DLOG << "out_W=" << out_W;
+  DLOG << "in_W=" << in_W;
+  DLOG << "default_work_size=" << default_work_size;
+  DLOG << "in_Stride0=" << in_Stride0;
+  DLOG << "in_Stride1=" << in_Stride1;
+  DLOG << "out_Stride0=" << out_Stride0;
+  DLOG << "out_Stride1=" << out_Stride1;
+  cl_int status;
+  status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &input_image);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &output_image);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 2, sizeof(int), &out_C);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 3, sizeof(int), &out_H);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 4, sizeof(int), &out_W);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 5, sizeof(int), &in_W);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 6, sizeof(int), &in_H);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 7, sizeof(int), &in_Stride0);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 8, sizeof(int), &in_Stride1);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 9, sizeof(int), &in_Stride2);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 10, sizeof(int), &out_Stride0);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 11, sizeof(int), &out_Stride1);
+  CL_CHECK_ERRORS(status);
+  status = clSetKernelArg(kernel, 12, sizeof(int), &out_Stride2);
+  CL_CHECK_ERRORS(status);
+  status = clEnqueueNDRangeKernel(
+      this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(),
+      NULL, default_work_size.data(), NULL, 0, NULL, NULL);
+  CL_CHECK_ERRORS(status);
+}
+
+template class Reshape2Kernel<GPU_CL, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+#endif
diff --git a/src/operators/kernel/cl/split_kernel.cpp b/src/operators/kernel/cl/split_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..58c7361bc5ff5430ce54a8d8bca323fbbe7f9f2a
--- /dev/null
+++ b/src/operators/kernel/cl/split_kernel.cpp
@@ -0,0 +1,116 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifdef SPLIT_OP
+
+#include "operators/kernel/split_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool SplitKernel<GPU_CL, float>::Init(SplitParam<GPU_CL>* param) {
+  this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
+  this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
+  return true;
+}
+
+// Strided numel memory copy from src to dst along the specified axis.
+//
+// For example, for a tensor with dims [4, 20, 100], the strided numel is
+// [8000, 2000, 100].
+//
+// NOTE: The src and dst tensors should have the same number of elements,
+// except along the specified axis.
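+//
+// A minimal worked example (numbers assumed for illustration, extending the
+// [4, 20, 100] case above): splitting that tensor into two [4, 10, 100]
+// halves along axis 1 gives a dst strided numel of [4000, 1000, 100]. The
+// copy below then moves before = 4000 / 1000 = 4 blocks of size = 1000
+// elements each, advancing src by src_after = 2000 and dst by
+// dst_after = 1000 per block; the caller offsets src by 1000 for the
+// second output.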
+template <typename T>
+void StridedNumelCopyWithAxis(int64_t axis, T* dst,
+                              const framework::DDim& dst_stride_numel,
+                              const T* src,
+                              const framework::DDim& src_stride_numel,
+                              int64_t size) {
+  int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
+  int64_t src_after = src_stride_numel[axis];
+  int64_t dst_after = dst_stride_numel[axis];
+
+  PADDLE_MOBILE_ENFORCE(src_stride_numel.size() == dst_stride_numel.size(),
+                        "src and dst tensor should have the same dims size.");
+
+  for (int64_t i = 0; i < src_stride_numel.size(); ++i) {
+    if (i < axis) {
+      PADDLE_MOBILE_ENFORCE(src_stride_numel[i] / src_stride_numel[axis] ==
+                                dst_stride_numel[i] / dst_stride_numel[axis],
+                            "src and dst should have the same elements "
+                            "except the specified axis.");
+    } else if (i == axis) {
+      continue;
+    } else {
+      PADDLE_MOBILE_ENFORCE(src_stride_numel[i] == dst_stride_numel[i],
+                            "src and dst should have the same elements "
+                            "except the specified axis.");
+    }
+  }
+
+  for (int64_t i = 0; i < before; ++i) {
+    memory::Copy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
+  }
+}
+
+template <>
+void SplitKernel<GPU_CL, float>::Compute(const SplitParam<GPU_CL>& param) {
+  auto kernel0 = this->cl_helper_.KernelAt(0);
+  auto kernel1 = this->cl_helper_.KernelAt(1);
+  auto* input_image = param.InputX();
+  auto in_stride = framework::stride_numel(input_image->dims());
+  auto input_dims = input_image->dims();
+  auto outs_images = param.Outs();
+  int64_t axis = param.Axis();
+
+  Tensor* input_tensor = new Tensor();
+  input_tensor->Resize(input_image->dims());
+  input_tensor->mutable_data<float>();
+
+  framework::CLImageToTensor(input_image, input_tensor,
+                             this->cl_helper_.CLContext(),
+                             this->cl_helper_.CLCommandQueue(), kernel0);
+
+  size_t input_offset = 0;
+  for (auto out : outs_images) {
+    auto out_stride = framework::stride_numel(out->dims());
+
+    Tensor* temp_out = new Tensor();
+    temp_out->Resize(out->dims());
+    temp_out->mutable_data<float>();
+    framework::CLImageToTensor(out, temp_out, this->cl_helper_.CLContext(),
+                               this->cl_helper_.CLCommandQueue(), kernel0);
+    StridedNumelCopyWithAxis(axis, temp_out->data<float>(), out_stride,
+                             input_tensor->data<float>() + input_offset,
+                             in_stride, out_stride[axis]);
+    input_offset += out_stride[axis];
+    out->InitEmptyImage(this->cl_helper_.CLContext(),
+                        this->cl_helper_.CLCommandQueue(), temp_out->dims());
+    framework::TensorToCLImage(temp_out, out, this->cl_helper_.CLContext(),
+                               this->cl_helper_.CLCommandQueue(), kernel1);
+
+    delete (temp_out);
+  }
+  delete (input_tensor);
+}
+
+template class SplitKernel<GPU_CL, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/kernel/cl/transpose2_kernel.cpp b/src/operators/kernel/cl/transpose2_kernel.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d6d92d8a68cdc1237f9ff83c07110868b9230d34
--- /dev/null
+++ b/src/operators/kernel/cl/transpose2_kernel.cpp
@@ -0,0 +1,135 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#ifdef TRANSPOSE2_OP
+
+#include "operators/kernel/transpose2_kernel.h"
+
+namespace paddle_mobile {
+namespace operators {
+
+template <>
+bool Transpose2Kernel<GPU_CL, float>::Init(Transpose2Param<GPU_CL> *param) {
+  this->cl_helper_.AddKernel("fetch", "fetch_kernel.cl");
+  this->cl_helper_.AddKernel("feed", "feed_kernel.cl");
+  return true;
+}
+
+inline bool IsShuffleChannel(const std::vector<int> &axis) {
+  bool is_shuffle_channel = true;
+  if (axis.size() > 2 && axis[0] == 0 && axis[1] == 2 && axis[2] == 1) {
+    for (int i = 3; i < axis.size(); ++i) {
+      if (axis[i] != i) {
+        is_shuffle_channel = false;
+        break;
+      }
+    }
+  } else {
+    return false;
+  }
+  return is_shuffle_channel;
+}
+
+template <typename Dtype>
+void ShuffleChannelCompute(const Transpose2Param<GPU_CL> &param,
+                           cl_context context, cl_command_queue commandQueue,
+                           cl_kernel kernel0, cl_kernel kernel1) {
+  auto axis = param.Axis();
+  int axis_size = axis.size();
+
+  // If the leading (rank - 4) axes are the identity, drop them so the
+  // transpose only ever works on the trailing 4-D shape.
+  bool shouldResize = true;
+  int diff_dim = 0;
+  if (axis_size > 4) {
+    for (int i = 0; i < axis_size - 4; ++i) {
+      if (axis[i] != i) {
+        shouldResize = false;
+        break;
+      } else {
+        diff_dim++;
+      }
+    }
+    if (shouldResize) {
+      std::vector<int> temp_axis_dims;
+      temp_axis_dims.reserve(static_cast<size_t>(4));
+      for (int i = axis_size - 4; i < axis_size; ++i) {
+        temp_axis_dims.push_back(axis[i] - diff_dim);
+      }
+      axis.assign(temp_axis_dims.begin(), temp_axis_dims.end());
+    }
+  }
+
+  auto input = param.InputX();
+  Tensor *input_tensor = new Tensor();
+  input_tensor->Resize(input->dims());
+  input_tensor->mutable_data<Dtype>();
+
+  framework::CLImageToTensor(input, input_tensor, context, commandQueue,
+                             kernel0);
+  const Dtype *input_ptr = input_tensor->data<Dtype>();
+
+  auto output = param.Out();
+  Tensor *output_tensor = new Tensor();
+  output_tensor->Resize(output->dims());
+  output_tensor->mutable_data<Dtype>();
+  Dtype *output_ptr = output_tensor->mutable_data<Dtype>();
+  // The input and output shape dimensions must be >= 2 and <= 6.
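+  //
+  // A minimal sketch of the copy below (shapes assumed for illustration):
+  // for in_dim [4, 2, 8, 8] with the trimmed axis {1, 0, 2, 3}, out_dim is
+  // [2, 4, 8, 8] and offset = 8 * 8 = 64, so block (c1, c2) of the output
+  // is read from block (c2, c1) of the input, i.e. a transpose of the first
+  // two dimensions moving 64-element rows per memcpy.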
+  const framework::DDim &in_dim = input->dims();
+  const framework::DDim &out_dim = output->dims();
+  size_t offset = 1;
+  for (int i = 2; i < axis.size(); ++i) {
+    offset *= in_dim[i];
+  }
+
+#pragma omp parallel for collapse(2)
+  for (int c1 = 0; c1 < out_dim[0]; ++c1) {
+    for (int c2 = 0; c2 < out_dim[1]; ++c2) {
+      size_t out_offset = (c1 * out_dim[1] + c2) * offset;
+      size_t in_offset = (c2 * in_dim[1] + c1) * offset;
+      memcpy(output_ptr + out_offset, input_ptr + in_offset,
+             offset * sizeof(Dtype));
+    }
+  }
+
+  output->InitEmptyImage(context, commandQueue, output_tensor->dims());
+  framework::TensorToCLImage(output_tensor, output, context, commandQueue,
+                             kernel1);
+
+  delete (input_tensor);
+  delete (output_tensor);
+}
+
+template <>
+void Transpose2Kernel<GPU_CL, float>::Compute(
+    const Transpose2Param<GPU_CL> &param) {
+  auto kernel0 = this->cl_helper_.KernelAt(0);
+  auto kernel1 = this->cl_helper_.KernelAt(1);
+
+  const std::vector<int> &axis = param.Axis();
+  bool shuffle_channel = IsShuffleChannel(axis);
+  if (shuffle_channel) {
+    ShuffleChannelCompute<float>(param, this->cl_helper_.CLContext(),
+                                 this->cl_helper_.CLCommandQueue(), kernel0,
+                                 kernel1);
+  } else {
+    PADDLE_MOBILE_THROW_EXCEPTION("axis not supported");
+  }
+}
+
+template class Transpose2Kernel<GPU_CL, float>;
+
+}  // namespace operators
+}  // namespace paddle_mobile
+
+#endif
diff --git a/src/operators/op_param.h b/src/operators/op_param.h
index 3558fda919607642920cbff35a1cf9c752bba798..20b9ce4ddbb90944b0a2d0a5aed1d3c4ef759772 100644
--- a/src/operators/op_param.h
+++ b/src/operators/op_param.h
@@ -1362,7 +1362,7 @@ class Transpose2Param : public OpParam {
     axis_ = GetAttr<vector<int>>("axis", attrs);
   }
 
-  const GType *InputX() const { return input_x_; }
+  GType *InputX() const { return input_x_; }
 
   GType *Out() const { return out_; }
 
@@ -1510,7 +1510,7 @@ class Reshape2Param : public OpParam {
     }
   }
 
-  const GType *InputX() const { return input_x_; }
+  GType *InputX() const { return input_x_; }
 
   const GType *InputShape() const { return input_shape_; }
 
@@ -2807,7 +2807,7 @@ class SplitParam : public OpParam {
     //    out_ts_.push_back(*scope.FindVar(outs_[i])->GetMutable<LoDTensor>());
     //  }
   }
-  const GType *InputX() const { return input_x_; }
+  GType *InputX() const { return input_x_; }
   std::vector<GType *> Outs() const { return outs_; }
   int Axis() const { return axis; }
   int Num() const { return num; }
diff --git a/src/operators/reshape2_op.cpp b/src/operators/reshape2_op.cpp
index b43f2996623f31160827054802195152d8d2d873..4ac8f3458efd6fc19f885f3c55533c039bcf4b35 100644
--- a/src/operators/reshape2_op.cpp
+++ b/src/operators/reshape2_op.cpp
@@ -24,8 +24,52 @@ template <typename Dtype, typename T>
 void Reshape2Op<Dtype, T>::InferShape() const {
   auto &shape = this->param_.Shape();
   auto input_x_dims = this->param_.InputX()->dims();
+#ifdef PADDLE_MOBILE_CL
+  auto input_dim_size = input_x_dims.size();
+  bool shouldResize = true;
+  if (input_dim_size > 4) {
+    for (int i = 0; i < input_dim_size - 4; ++i) {
+      if (input_x_dims[i] != 0 && input_x_dims[i] != 1) {
+        shouldResize = false;
+        break;
+      }
+    }
+    if (shouldResize) {
+      std::vector<int64_t> temp_input_dims;
+      temp_input_dims.reserve(static_cast<size_t>(4));
+      for (int i = input_dim_size - 4; i < input_dim_size; ++i) {
+        temp_input_dims.push_back(input_x_dims[i]);
+      }
+      framework::DDim temp_ddim = framework::make_ddim(temp_input_dims);
+      this->param_.InputX()->Resize(temp_ddim);
+      input_x_dims = this->param_.InputX()->dims();
+    }
+  }
+#endif
+
   auto out_dims = ValidateShape(shape, input_x_dims);
   this->param_.Out()->Resize(out_dims);
+#ifdef PADDLE_MOBILE_CL
+  input_x_dims = this->param_.InputX()->dims();
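+  // As with the input above, fold the inferred output dims down to at most
+  // four dimensions so the result can be laid out as a CL image.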
+  shouldResize = true;
+  if (out_dims.size() > 4) {
+    for (int i = 0; i < out_dims.size() - 4; ++i) {
+      if (out_dims[i] != 0 && out_dims[i] != 1) {
+        shouldResize = false;
+        break;
+      }
+    }
+    if (shouldResize) {
+      std::vector<int64_t> temp_output_dims;
+      temp_output_dims.reserve(static_cast<size_t>(4));
+      for (int i = out_dims.size() - 4; i < out_dims.size(); ++i) {
+        temp_output_dims.push_back(out_dims[i]);
+      }
+      framework::DDim temp_ddim = framework::make_ddim(temp_output_dims);
+      this->param_.Out()->Resize(temp_ddim);
+    }
+  }
+#endif
   std::vector<int64_t> xshape_dims(input_x_dims.size() + 1, 0);
   for (int i = 0; i < input_x_dims.size(); ++i) {
     xshape_dims[i + 1] = input_x_dims[i];
@@ -40,6 +84,9 @@ namespace ops = paddle_mobile::operators;
 #ifdef PADDLE_MOBILE_CPU
 REGISTER_OPERATOR_CPU(reshape2, ops::Reshape2Op);
 #endif
+#ifdef PADDLE_MOBILE_CL
+REGISTER_OPERATOR_CL(reshape2, ops::Reshape2Op);
+#endif
 #ifdef PADDLE_MOBILE_FPGA
 REGISTER_OPERATOR_FPGA(reshape2, ops::Reshape2Op);
 #endif
diff --git a/src/operators/split_op.cpp b/src/operators/split_op.cpp
index b440c0be436f333cc46e320c026f67b6020b8aab..ec82214a48551731eed1e51ef5455c39bb5f8e1e 100644
--- a/src/operators/split_op.cpp
+++ b/src/operators/split_op.cpp
@@ -86,5 +86,8 @@ REGISTER_OPERATOR_CPU(split, ops::SplitOp);
 #ifdef PADDLE_MOBILE_FPGA
 REGISTER_OPERATOR_FPGA(split, ops::SplitOp);
 #endif
+#ifdef PADDLE_MOBILE_CL
+REGISTER_OPERATOR_CL(split, ops::SplitOp);
+#endif
 
 #endif  // SPLIT_OP
diff --git a/src/operators/transpose2_op.cpp b/src/operators/transpose2_op.cpp
index 03db27a9a2f8fc8974a1b1c97b1d71782388103e..945e019f672cd47a009bd1ad1b4083798b97366d 100644
--- a/src/operators/transpose2_op.cpp
+++ b/src/operators/transpose2_op.cpp
@@ -29,6 +29,55 @@ void Transpose2Op<Dtype, T>::InferShape() const {
   size_t x_dims_size = input_x_dims.size();
   size_t axis_size = axis.size();
 
+#ifdef PADDLE_MOBILE_CL
+  // Trim identity leading axes and dims so the OpenCL kernels only ever see
+  // a shape of rank <= 4.
+  bool shouldResize = true;
+  int diff_dim = 0;
+  if (axis_size > 4) {
+    for (int i = 0; i < axis_size - 4; ++i) {
+      if (axis[i] != i) {
+        shouldResize = false;
+        break;
+      } else {
+        diff_dim++;
+      }
+    }
+    if (shouldResize) {
+      std::vector<int> temp_axis_dims;
+      temp_axis_dims.reserve(static_cast<size_t>(4));
+      for (int i = axis_size - 4; i < axis_size; ++i) {
+        temp_axis_dims.push_back(axis[i] - diff_dim);
+      }
+      axis.assign(temp_axis_dims.begin(), temp_axis_dims.end());
+    }
+  }
+
+  auto input_dim_size = input_x_dims.size();
+  shouldResize = true;
+  if (input_dim_size > 4) {
+    for (int i = 0; i < input_dim_size - 4; ++i) {
+      if (input_x_dims[i] != 0 && input_x_dims[i] != 1) {
+        shouldResize = false;
+        break;
+      }
+    }
+    if (shouldResize) {
+      std::vector<int64_t> temp_input_dims;
+      temp_input_dims.reserve(static_cast<size_t>(4));
+      for (int i = input_dim_size - 4; i < input_dim_size; ++i) {
+        temp_input_dims.push_back(input_x_dims[i]);
+      }
+      framework::DDim temp_ddim = framework::make_ddim(temp_input_dims);
+      this->param_.InputX()->Resize(temp_ddim);
+    }
+  }
+
+  axis_size = axis.size();
+  input_x_dims = this->param_.InputX()->dims();
+  x_dims_size = input_x_dims.size();
+#endif
+
   PADDLE_MOBILE_ENFORCE((x_dims_size == axis_size),
                         "input_dims must "
                         "be equal to the axis_size. ")
@@ -63,5 +112,7 @@ REGISTER_OPERATOR_CPU(transpose2, ops::Transpose2Op);
 #ifdef PADDLE_MOBILE_FPGA
 REGISTER_OPERATOR_FPGA(transpose2, ops::Transpose2Op);
 #endif
-
+#ifdef PADDLE_MOBILE_CL
+REGISTER_OPERATOR_CL(transpose2, ops::Transpose2Op);
+#endif
 #endif  // TRANSPOSE_OP