diff --git a/lite/kernels/opencl/concat_image_compute.cc b/lite/kernels/opencl/concat_image_compute.cc index 1a7ef0828ff6067277b9bc836b153e6368d8c78c..da0c9225c906beb9d0717bdf7ac1107a70a7e9cb 100644 --- a/lite/kernels/opencl/concat_image_compute.cc +++ b/lite/kernels/opencl/concat_image_compute.cc @@ -38,15 +38,13 @@ class ConcatComputeImage : public KernelLiteAs(); concat_param_ = param_.get_mutable(); + axis_ = concat_param_->axis; + if (-1 == axis_) { + axis_ = concat_param_->x[0]->dims().size() - 1; + } auto inputs = concat_param_->x; - auto axis_ = concat_param_->axis; auto output_tensor_dims = concat_param_->output->dims(); - auto* axis_tensor = concat_param_->axis_tensor; - if (axis_tensor != nullptr) { - // auto* axis_tensor_data = axis_tensor->data(TARGET(kARM)); - // axis = axis_tensor_data[0]; - } if (inputs.size() == 2) { kernel_func_name_ = "concat2"; @@ -100,8 +98,7 @@ class ConcatComputeImage : public KernelLitedims(); for (int i = 1; i < inputs.size(); i++) { auto dims = inputs[i]->dims(); - // auto flag = CHECK_EQ_OR_FALSE(input0_tensor_dims.size(), dims.size()); - if (input0_tensor_dims.size() != dims.size()) { - printf("input shape must be same \n"); - return; - } + CHECK(input0_tensor_dims.size() == dims.size()) + << "All inputs must have the same axes!"; for (int i = 0; i < dims.size(); i++) { if (i != axis_) { - if (input0_tensor_dims[i] != dims[i]) { - printf("input shape must be same \n"); - return; - } + CHECK(input0_tensor_dims[i] == dims[i]) + << "All inputs must have the same shape, except at concat axis!"; } } } @@ -151,29 +143,18 @@ class ConcatComputeImage : public KernelLitedims().size() << "D]:" - << " dims:" << inputs[i]->dims()[0] << " " - << inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " " - << inputs[i]->dims()[3]; + << " dims:" << inputs[i]->dims(); } VLOG(4) << "concat output shape: "; - VLOG(4) << " out dims: " - << "[" << output_tensor_dims.size() - << "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1] - << " " << output_tensor_dims[2] << " " << output_tensor_dims[3]; + VLOG(4) << " out dims: " << output_tensor_dims; VLOG(4) << "axis_: " << axis_; VLOG(4) << "flag_: " << flag_; VLOG(4) << TargetToStr(concat_param_->output->target()); - VLOG(4) << "output_image_shape(w,h):" << output_image_shape["width"] << " " + VLOG(4) << "output_image_shape(w,h): " << output_image_shape["width"] << " " << output_image_shape["height"]; - VLOG(4) << "output_tensor_dims[" << output_tensor_dims.size() - << "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1] - << " " << output_tensor_dims[2] << " " << output_tensor_dims[3] - << "output_tensor_dims[output_tensor_dims.size() - 1]" - << output_tensor_dims[output_tensor_dims.size() - 1]; - VLOG(4) << "output_tensor_w: " << output_tensor_w << ", flag_: " << flag_; + VLOG(4) << "output_tensor_w: " << output_tensor_w; VLOG(4) << "width_:" << width_; VLOG(4) << "global_work_size: " << output_tensor_dims[output_tensor_dims.size() - 1] << " " @@ -433,6 +414,7 @@ class ConcatComputeImage : public KernelLite