未验证 提交 9bc1f34b 编写于 作者: Z zhaoyang-star 提交者: GitHub

[BugFix][OpenCL] Fix concat image impl when axis is not 1. test=develop (#4241)

* [BugFix][OpenCL] Fix concat image impl when concat axis is not 1

* fix code when axis == 1. test=develop

* fix illegal access when printing debug info. test=develop

* fix typo
上级 4539964b
......@@ -38,15 +38,13 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
concat_param_ = param_.get_mutable<param_t>();
axis_ = concat_param_->axis;
if (-1 == axis_) {
axis_ = concat_param_->x[0]->dims().size() - 1;
}
auto inputs = concat_param_->x;
auto axis_ = concat_param_->axis;
auto output_tensor_dims = concat_param_->output->dims();
auto* axis_tensor = concat_param_->axis_tensor;
if (axis_tensor != nullptr) {
// auto* axis_tensor_data = axis_tensor->data<int>(TARGET(kARM));
// axis = axis_tensor_data[0];
}
if (inputs.size() == 2) {
kernel_func_name_ = "concat2";
......@@ -100,8 +98,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
width_ = output_tensor_dims[0]; // n
flag_ = 2;
break;
case 3:
case -1: // width
case 3: // width
width_ = output_tensor_dims[1]; // c
flag_ = 3;
break;
......@@ -113,17 +110,12 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
auto input0_tensor_dims = inputs[0]->dims();
for (int i = 1; i < inputs.size(); i++) {
auto dims = inputs[i]->dims();
// auto flag = CHECK_EQ_OR_FALSE(input0_tensor_dims.size(), dims.size());
if (input0_tensor_dims.size() != dims.size()) {
printf("input shape must be same \n");
return;
}
CHECK(input0_tensor_dims.size() == dims.size())
<< "All inputs must have the same axes!";
for (int i = 0; i < dims.size(); i++) {
if (i != axis_) {
if (input0_tensor_dims[i] != dims[i]) {
printf("input shape must be same \n");
return;
}
CHECK(input0_tensor_dims[i] == dims[i])
<< "All inputs must have the same shape, except at concat axis!";
}
}
}
......@@ -151,29 +143,18 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
VLOG(4) << "concat input shape: ";
for (size_t i = 0; i < inputs.size(); i++) {
VLOG(4) << "inputs [" << i << "]"
<< "[" << inputs[i]->dims().size() << "D]:"
<< " dims:" << inputs[i]->dims()[0] << " "
<< inputs[i]->dims()[1] << " " << inputs[i]->dims()[2] << " "
<< inputs[i]->dims()[3];
<< " dims:" << inputs[i]->dims();
}
VLOG(4) << "concat output shape: ";
VLOG(4) << " out dims: "
<< "[" << output_tensor_dims.size()
<< "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1]
<< " " << output_tensor_dims[2] << " " << output_tensor_dims[3];
VLOG(4) << " out dims: " << output_tensor_dims;
VLOG(4) << "axis_: " << axis_;
VLOG(4) << "flag_: " << flag_;
VLOG(4) << TargetToStr(concat_param_->output->target());
VLOG(4) << "output_image_shape(w,h):" << output_image_shape["width"] << " "
VLOG(4) << "output_image_shape(w,h): " << output_image_shape["width"] << " "
<< output_image_shape["height"];
VLOG(4) << "output_tensor_dims[" << output_tensor_dims.size()
<< "D]:" << output_tensor_dims[0] << " " << output_tensor_dims[1]
<< " " << output_tensor_dims[2] << " " << output_tensor_dims[3]
<< "output_tensor_dims[output_tensor_dims.size() - 1]"
<< output_tensor_dims[output_tensor_dims.size() - 1];
VLOG(4) << "output_tensor_w: " << output_tensor_w << ", flag_: " << flag_;
VLOG(4) << "output_tensor_w: " << output_tensor_w;
VLOG(4) << "width_:" << width_;
VLOG(4) << "global_work_size: "
<< output_tensor_dims[output_tensor_dims.size() - 1] << " "
......@@ -433,6 +414,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
}
#endif
private:
int axis_ = 1;
int flag_ = 1;
int width_ = 1;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册