From 6c401ff6f257d0bf9280af9372ff226b937c43c1 Mon Sep 17 00:00:00 2001 From: pengyongrong Date: Fri, 7 Aug 2020 17:08:25 +0800 Subject: [PATCH] submitted concat-imgae2d to mindspore --- .../runtime/kernel/opencl/cl/fp32/concat.cl | 82 +++++++------- .../runtime/kernel/opencl/kernel/concat.cc | 102 +++++++++--------- .../src/runtime/kernel/opencl/kernel/concat.h | 9 +- .../src/runtime/opencl/opencl_allocator.cc | 31 +++--- .../src/runtime/kernel/opencl/concat_tests.cc | 69 ++++++++---- 5 files changed, 153 insertions(+), 140 deletions(-) diff --git a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl index d457a3e4f..fc576b31c 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl +++ b/mindspore/lite/src/runtime/kernel/opencl/cl/fp32/concat.cl @@ -1,54 +1,44 @@ #pragma OPENCL EXTENSION cl_khr_fp16 : enable -__kernel void Concat(__global float *input0, __global float *input1, __global float *output, const int4 input_shape0, - const int4 input_shape1, const int4 output_shape, const int axis) { - uint oh = get_global_id(0); - uint ow = get_global_id(1); - uint oc = get_global_id(2); - uint index_output; - uint input_idx; - if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) { - return; +#define FLT4 float4 +__constant sampler_t smp_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST; +__kernel void Concat(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d, + __read_only image2d_t input1_image2d, int2 shared_int0, int4 shared_out) { + int X = get_global_id(0); // H + int Y = get_global_id(1); // W + int S = 0; + if (X >= shared_out.y || Y >= shared_out.z) return; + for (int i = 0; i < shared_int0.x; i++) { + FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X))); + write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0); + S++; } - if (axis == 3) { - index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc; - if (oc < input_shape0.w) { - input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc; - output[index_output] = input0[input_idx]; - } else if ((input_shape0.w <= oc) && oc < (input_shape0.w + input_shape1.w)) { - input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w); - output[index_output] = input1[input_idx]; - } else { - output[index_output] = 0; - } + for (int i = 0; i < shared_int0.y; i++) { + FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X))); + write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1); + S++; } } -__kernel void Concat3input(__global float *input0, __global float *input1, __global float *input2, - __global float *output, const int4 input_shape0, const int4 input_shape1, - const int4 input_shape2, const int4 output_shape, const int axis) { - uint oh = get_global_id(0); - uint ow = get_global_id(1); - uint oc = get_global_id(2); - uint index_output; - uint input_idx; - if ((oh >= output_shape.y || oh < 0) || (ow >= output_shape.z || ow < 0) || (oc >= output_shape.w || oc < 0)) { - return; +__kernel void Concat3input(__write_only image2d_t output_image2d, __read_only image2d_t input0_image2d, + __read_only image2d_t input1_image2d, __read_only image2d_t input2_image2d, int3 shared_int0, + int4 shared_out) { + int X = get_global_id(0); // H + int Y = get_global_id(1); // W + int S = 0; + if (X >= shared_out.y || Y >= shared_out.z) return; + for (int i = 0; i < shared_int0.x; i++) { + FLT4 result0 = read_imagef(input0_image2d, smp_none, (int2)((Y)*shared_int0.x + (i), (X))); + write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result0); + S++; } - index_output = oh * output_shape.z * output_shape.w + ow * output_shape.w + oc; - if (oc < (input_shape0.w + input_shape1.w)) { - if (oc < input_shape0.w) { - input_idx = (input_shape0.z * oh + ow) * input_shape0.w + oc; - output[index_output] = input0[input_idx]; - } else { - input_idx = (input_shape1.z * oh + ow) * input_shape1.w + (oc - input_shape0.w); - output[index_output] = input1[input_idx]; - } - } else { - if ((input_shape0.w + input_shape1.w + input_shape2.w) <= oc) { - output[index_output] = 0; - } else { - input_idx = (input_shape2.z * oh + ow) * input_shape2.w + (oc - input_shape0.w - input_shape1.w); - output[index_output] = input2[input_idx]; - } + for (int i = 0; i < shared_int0.y; i++) { + FLT4 result1 = read_imagef(input1_image2d, smp_none, (int2)((Y)*shared_int0.y + (i), (X))); + write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result1); + S++; + } + for (int i = 0; i < shared_int0.z; i++) { + FLT4 result2 = read_imagef(input2_image2d, smp_none, (int2)((Y)*shared_int0.z + (i), (X))); + write_imagef(output_image2d, (int2)((Y)*shared_out.w + (S), (X)), result2); + S++; } } diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc index 4823f1eca..82540b49f 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.cc @@ -13,6 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include #include #include #include @@ -27,6 +28,26 @@ using mindspore::schema::PrimitiveType_Concat; namespace mindspore::kernel { +int ConcatOpenCLKernel::GetImageSize(size_t idx, std::vector *img_size) { + size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); + size_t im_dst_x, im_dst_y; + if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { + im_dst_x = outputs_[0]->Width() * CO4; + im_dst_y = outputs_[0]->Height(); + } else { + im_dst_y = outputs_[0]->Height() * CO4; + im_dst_x = outputs_[0]->Width(); + } +#ifdef ENABLE_FP16 + size_t img_dtype = CL_HALF_FLOAT; +#else + size_t img_dtype = CL_FLOAT; +#endif + img_size->clear(); + std::vector vec{im_dst_x, im_dst_y, img_dtype}; + *img_size = vec; + return 1; +} int ConcatOpenCLKernel::Init() { if (inputs_[0]->shape().size() != 4) { MS_LOG(ERROR) << "only support dim=4"; @@ -132,72 +153,45 @@ int ConcatOpenCLKernel::Run() { } auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); - std::vector local; - std::vector global; + MS_LOG(INFO) << " judge the numbers of input vector"; + auto input0_shape = inputs_[0]->shape(); + auto input1_shape = inputs_[1]->shape(); + auto input2_shape = inputs_[2]->shape(); + auto output_shape = outputs_[0]->shape(); + + cl_int2 input0_shape2_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4)}; // change + cl_int3 input0_shape3_ = {DivideRoundUp(input0_shape[3], 4), DivideRoundUp(input1_shape[3], 4), + DivideRoundUp(input2_shape[3], 4)}; + cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], DivideRoundUp(output_shape[3], 4)}; + + uint32_t OH = output_shape[0] * output_shape[1]; // N*H + uint32_t OW = output_shape[2]; + + std::vector local = {1, 1}; + std::vector global = {OH, OW}; + // ConcatGetWorkGroup(global, &local, 512); + + int arg_cn = 0; if (inputs_.size() == 2) { - auto input0_shape = inputs_[0]->shape(); - auto input1_shape = inputs_[1]->shape(); - auto output_shape = outputs_[0]->shape(); - - cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; - cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; - cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; - - uint32_t OH = output_shape[0] * output_shape[1]; // N*H - uint32_t OW = output_shape[2]; - uint32_t OC = output_shape[3]; - global = {OH, OW, OC}; // HWC - ConcatGetWorkGroup(global, &local, 384); - std::cout << "local size=:" << std::endl; - for (int i = 0; i < local.size(); i++) { - std::cout << local[i] << " "; - } - std::cout << std::endl; - int arg_cn = 0; + MS_LOG(INFO) << " SetKernelArg"; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape2_); ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); - } - if (inputs_.size() == 3) { - auto input0_shape = inputs_[0]->shape(); - auto input1_shape = inputs_[1]->shape(); - auto input2_shape = inputs_[2]->shape(); - auto output_shape = outputs_[0]->shape(); - - cl_int4 input0_shape_ = {input0_shape[0], input0_shape[1], input0_shape[2], input0_shape[3]}; - cl_int4 input1_shape_ = {input1_shape[0], input1_shape[1], input1_shape[2], input1_shape[3]}; - cl_int4 input2_shape_ = {input2_shape[0], input2_shape[1], input2_shape[2], input2_shape[3]}; - cl_int4 output_shape_ = {output_shape[0], output_shape[1], output_shape[2], output_shape[3]}; - - uint32_t OH = output_shape[0] * output_shape[1]; // N*H - uint32_t OW = output_shape[2]; - uint32_t OC = output_shape[3]; - global = {OH, OW, OC}; // HWC - ConcatGetWorkGroup(global, &local, 384); - std::cout << "local size=:" << std::endl; - for (int i = 0; i < local.size(); i++) { - std::cout << local[i] << " "; - } - std::cout << std::endl; - int arg_cn = 0; + } else if (inputs_.size() == 3) { + MS_LOG(INFO) << " SetKernelArg"; + ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[0]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[1]->Data()); ocl_runtime->SetKernelArg(kernel_, arg_cn++, inputs_[2]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, outputs_[0]->Data()); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape_); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, input1_shape_); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, input2_shape_); + ocl_runtime->SetKernelArg(kernel_, arg_cn++, input0_shape3_); ocl_runtime->SetKernelArg(kernel_, arg_cn++, output_shape_); - ocl_runtime->SetKernelArg(kernel_, arg_cn++, param->axis_); } ocl_runtime->RunKernel(kernel_, global, local, nullptr); return 0; -} +} // namespace mindspore::kernel kernel::LiteKernel *OpenCLConcatKernelCreator(const std::vector &inputs, const std::vector &outputs, diff --git a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h index 1f2c115f8..0f1d1ab13 100644 --- a/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h +++ b/mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h @@ -14,8 +14,8 @@ * limitations under the License. */ -#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ -#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_Concat_H_ +#ifndef MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONCAT_H_ +#define MINDSPORE_LITE_SRC_BACKEND_OPENCL_CONCAT_H_ #include #include "ir/anf.h" @@ -25,11 +25,11 @@ namespace mindspore::kernel { -class ConcatOpenCLKernel : public LiteKernel { +class ConcatOpenCLKernel : public OpenCLKernel { public: explicit ConcatOpenCLKernel(OpParameter *parameter, const std::vector &inputs, const std::vector &outputs) - : LiteKernel(parameter, inputs, outputs) {} + : OpenCLKernel(parameter, inputs, outputs) {} ~ConcatOpenCLKernel() override{}; @@ -40,6 +40,7 @@ class ConcatOpenCLKernel : public LiteKernel { int Run_axis0(); int Run() override; + int GetImageSize(size_t idx, std::vector *img_size) override; private: cl::Kernel kernel_; diff --git a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc index ed579540b..588a34746 100644 --- a/mindspore/lite/src/runtime/opencl/opencl_allocator.cc +++ b/mindspore/lite/src/runtime/opencl/opencl_allocator.cc @@ -69,8 +69,8 @@ void *OpenCLAllocator::Malloc(size_t size) { host_ptr = clSVMAlloc((*ocl_runtime->Context())(), flags, size, 0); } else { cl_int ret = CL_SUCCESS; - cl::Buffer *buffer = - new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, size, NULL, &ret); + cl::Buffer *buffer = new cl::Buffer(*ocl_runtime->Context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, + size, NULL, &ret); if (ret != CL_SUCCESS) { MS_LOG(ERROR) << "Create OpenCL buffer failed! (ERROR CODE: " << ret << ")"; UnLock(); @@ -125,8 +125,8 @@ void *OpenCLAllocator::Malloc(size_t size, const std::vector& img_size) cl_int ret = CL_SUCCESS; // CL_HALF_FLOAT, CL_FLOAT cl::ImageFormat image_format(CL_RGBA, img_size[2]); - cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, - image_format, img_size[0], img_size[1], 0, nullptr, &ret); + cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_WRITE, image_format, + img_size[0], img_size[1], 0, nullptr, &ret); if (ret != CL_SUCCESS) { MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; UnLock(); @@ -164,20 +164,26 @@ void *OpenCLAllocator::CreateImageFromHost(void *data, size_t size, const std::v auto iter = free_list_.lower_bound(size); if (iter != free_list_.end() && (iter->second->size_ >= size) && (iter->second->size_ < (size << shift_factor_))) { auto mem_buf = iter->second; - free_list_.erase(iter); - allocated_list_[mem_buf->host_ptr_] = mem_buf; - UnLock(); - MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ << ", host addr: " << mem_buf->host_ptr_ - << ", device addr: " << mem_buf->device_ptr_; - return mem_buf->host_ptr_; + bool is_match{mem_buf->img_size.size() == img_size.size()}; + for (int i = 0; i < img_size.size() && is_match; ++i) { + is_match = img_size[i] == mem_buf->img_size[i]; + } + if (is_match) { + free_list_.erase(iter); + allocated_list_[mem_buf->host_ptr_] = mem_buf; + UnLock(); + MS_LOG(DEBUG) << "Malloc Image2D from free list. size: " << mem_buf->size_ + << ", host addr: " << mem_buf->host_ptr_ << ", device addr: " << mem_buf->device_ptr_; + return mem_buf->host_ptr_; + } } void *host_ptr = nullptr; void *device_ptr = nullptr; cl_int ret = CL_SUCCESS; // CL_HALF_FLOAT, CL_FLOAT cl::ImageFormat image_format(CL_RGBA, img_size[2]); - cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, image_format, - img_size[0], img_size[1], 0, data, &ret); + cl::Image2D *buffer = new cl::Image2D(*ocl_runtime->Context(), CL_MEM_READ_ONLY | CL_MEM_COPY_HOST_PTR, + image_format, img_size[0], img_size[1], 0, data, &ret); if (ret != CL_SUCCESS) { MS_LOG(ERROR) << "Create OpenCL Image2D failed! (ERROR CODE: " << ret << ")"; UnLock(); @@ -372,4 +378,3 @@ int OpenCLAllocator::GetImageSize(void *host_ptr, std::vector* img_size) } } // namespace mindspore::lite::opencl - diff --git a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc index 7c039d47c..0cb6156ae 100644 --- a/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc +++ b/mindspore/lite/test/ut/src/runtime/kernel/opencl/concat_tests.cc @@ -21,7 +21,6 @@ #include "mindspore/lite/src/runtime/kernel/opencl/subgraph_opencl_kernel.h" #include "mindspore/lite/src/runtime/kernel/opencl/kernel/concat.h" - int DivideRoundUp(int n, int div) { int q = n / div; return n % div == 0 ? q : q + 1; @@ -77,15 +76,26 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i postion = i * output_shape[1] * output_shape[2] * output_shape[3] + j * output_shape[2] * output_shape[3] + k * output_shape[3]; for (int w = 0; w < output_shape[3]; w++) { - if (w < input_shape0[3] + input_shape1[3]) { - output[postion++] = (w < input_shape0[3]) ? input0[index0++] : input1[index1++]; + if (w < input_shape0[3]) { + int align = DivideRoundUp(input_shape0[3], 4) * 4; + index0 = i * input_shape0[1] * input_shape0[2] * align + j * input_shape0[2] * align + k * align + w; + output[postion++] = input0[index0]; + } else if (w >= input_shape0[3] && w < (input_shape0[3] + input_shape1[3])) { + int align = DivideRoundUp(input_shape1[3], 4) * 4; + index1 = i * input_shape1[1] * input_shape1[2] * align + j * input_shape1[2] * align + k * align + w - + input_shape0[3]; + output[postion++] = input1[index1]; } else if ((input_shape0[3] + input_shape1[3]) <= w && w < (input_shape0[3] + input_shape1[3] + input_shape2[3])) { - output[postion++] = input2[index2++]; + int align = DivideRoundUp(input_shape2[3], 4) * 4; + index2 = i * input_shape2[1] * input_shape2[2] * align + j * input_shape2[2] * align + k * align + w - + input_shape0[3] - input_shape1[3]; + output[postion++] = input2[index2]; } else { - for (int ind = input_shape0[3] + input_shape1[3]; ind < output_shape[3]; ind++) { + for (int ind = input_shape0[3] + input_shape1[3] + input_shape2[3]; ind < output_shape[3]; ind++) { output[postion++] = 0; } + break; } } } @@ -96,18 +106,31 @@ void ConcatComputeByCPU_3input_dim4_axis3(float *input0, float *input1, float *i namespace mindspore { class TestConcatOpenCL : public mindspore::Common { public: - TestConcatOpenCL(){} + TestConcatOpenCL() {} }; + +template +void CompareOutputData1(T *output_data, T *correct_data, int size, float err_bound) { + for (size_t i = 0; i < size; i++) { + T abs = fabs(output_data[i] - correct_data[i]); + // printf("i=%d %.3f %.3f\n", i, output_data[i], correct_data[i]); + ASSERT_LE(abs, err_bound); + } +} + TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { MS_LOG(INFO) << "begin test"; auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); ocl_runtime->Init(); + auto allocator = ocl_runtime->GetAllocator(); MS_LOG(INFO) << "init tensors"; - constexpr int INPUT_NUM = 3; - std::array, INPUT_NUM> input_shapes = { - std::vector{1, 240, 240, 16}, std::vector{1, 240, 240, 16}, std::vector{1, 240, 240, 64}}; - std::vector output_shape = {1, 240, 240, 96}; + constexpr int INPUT_NUM = 2; + // std::array, INPUT_NUM> input_shapes = { + // std::vector{1, 120, 120, 16}, std::vector{1, 120, 120, 16},std::vector{1, 120, 120, 96}}; + std::array, INPUT_NUM> input_shapes = {std::vector{1, 32, 512, 48}, + std::vector{1, 32, 512, 48}}; + std::vector output_shape = {1, 32, 512, 96}; output_shape[3] = DivideRoundUp(output_shape[3], 4) * 4; auto data_type = kNumberTypeFloat32; auto tensor_type = schema::NodeType_ValueNode; @@ -118,32 +141,30 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { auto *output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC, tensor_type); std::vector outputs{output_tensor}; std::cout << "input_shapes size=: " << input_shapes.size() << std::endl; - MS_LOG(INFO) << "initialize tensors"; + + std::cout << "initialize tensors"; auto param = new ConcatParameter(); param->axis_ = 3; auto *concat_kernel = new kernel::ConcatOpenCLKernel(reinterpret_cast(param), inputs, outputs); concat_kernel->Init(); - MS_LOG(INFO) << "initialize sub_graph"; std::vector kernels{concat_kernel}; auto *sub_graph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); + // to do allocate memory for inputs and outputs + for (auto &input_tensor : inputs) { + input_tensor->MallocData(allocator); + } sub_graph->Init(); - + unsigned int seed = 123; MS_LOG(INFO) << "initialize input data"; - srand(time(NULL)); for (auto &input_tensor : inputs) { auto input_data = reinterpret_cast(input_tensor->Data()); - static unsigned int seed = 123; for (int i = 0; i < input_tensor->ElementsNum(); ++i) { input_data[i] = static_cast(rand_r(&seed) % 10 + 1); } - printf("\n"); } - MS_LOG(INFO) << "==================output data================"; - sub_graph->Run(); - auto *output_data_gpu = reinterpret_cast(output_tensor->Data()); - printf("\n"); + // compute the result for CPU auto *input_data0 = reinterpret_cast(inputs[0]->Data()); auto *input_data1 = reinterpret_cast(inputs[1]->Data()); std::vector output_data_cpu(output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3]); @@ -156,8 +177,10 @@ TEST_F(TestConcatOpenCL, ConcatFp32_2input_dim4_axis3) { ConcatComputeByCPU_3input_dim4_axis3(input_data0, input_data1, input_data2, output_data_cpu.data(), input_shapes[0], input_shapes[1], input_shapes[2], output_shape, param->axis_); } - printf("\n"); - CompareOutputData(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); - MS_LOG(INFO) << "Testconcat passed"; + + std::cout << "==================output data================" << std::endl; + sub_graph->Run(); + auto *output_data_gpu = reinterpret_cast(output_tensor->Data()); + CompareOutputData1(output_data_gpu, output_data_cpu.data(), output_tensor->ElementsNum(), 0.00001); } } // namespace mindspore -- GitLab