提交 8ab8d030 编写于 作者: C chenzomi

opencl: image2d for pooling

上级 2b5b35ea
__kernel void MaxPooling2d(__global float4 *input, __global float4 *output, const int4 input_shape, __kernel void MaxPooling2d_BUF(__global float4 *input, __global float4 *output, const int4 input_shape,
const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) { const int4 output_shape, const int2 stride, const int2 kernel_size, const int2 padding) {
// axis to dst tensor coordinate // axis to dst tensor coordinate
int X = get_global_id(0); int X = get_global_id(0);
int Y = get_global_id(1); int Y = get_global_id(1);
...@@ -31,38 +31,37 @@ __kernel void MaxPooling2d(__global float4 *input, __global float4 *output, cons ...@@ -31,38 +31,37 @@ __kernel void MaxPooling2d(__global float4 *input, __global float4 *output, cons
output[(output_shape.y * X + Y) * output_shape.w + Z] = maximum; output[(output_shape.y * X + Y) * output_shape.w + Z] = maximum;
} }
// Sampler used for every image read in this file: unnormalized integer coordinates,
// nearest-texel filtering, and no address clamping. With CLK_ADDRESS_NONE an
// out-of-range read is undefined, which is why the kernel bounds-checks each coordinate itself.
__constant sampler_t sample_none = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_NONE | CLK_FILTER_NEAREST;

// 2-D max pooling on image2d storage; one work-item produces one float4 output texel.
//
// Texel addressing used below: a tensor element at (x, y, slice) lives at image
// coordinate (x, y * shape.w + slice), where shape.w is the number of 4-channel slices.
//
//   input / output : source and destination images
//   input_shape    : .x and .y bound the valid source coordinates, .w is the slice count
//   output_shape   : .x and .y bound the produced coordinates, .w is the slice count
//   stride         : step of the pooling window along x / y
//   kernel_size    : extent of the pooling window along x / y
//   padding        : offset ADDED to the window origin
//                    NOTE(review): presumably the host passes negated padding - confirm against the caller
__kernel void MaxPooling2d_IMG(__read_only image2d_t input, __write_only image2d_t output, const int4 input_shape,
                               const int4 output_shape, const int2 stride, const int2 kernel_size,
                               const int2 padding) {
  // axis to dst tensor coordinate
  int X = get_global_id(0);
  int Y = get_global_id(1);
  int Z = get_global_id(2);

  // boundary check: the global size may be rounded up past the output extent
  if (X >= output_shape.x || Y >= output_shape.y || Z >= output_shape.w) {
    return;
  }

  // Seed with the lowest finite float so that any real input value wins the comparison.
  // (A fixed seed such as -10000.0f yields a wrong maximum when every value in the window is smaller.)
  float4 maximum = (float4)(-FLT_MAX);
  int xs = X * stride.x + padding.x;
  int ys = Y * stride.y + padding.y;

  for (int kx = 0; kx < kernel_size.x; ++kx) {
    int x_c = xs + kx;
    // skip window columns that fall outside the source image
    if (x_c < 0 || x_c >= input_shape.x) {
      continue;
    }
    for (int ky = 0; ky < kernel_size.y; ++ky) {
      int y_c = ys + ky;
      // skip window rows that fall outside the source image
      if (y_c < 0 || y_c >= input_shape.y) {
        continue;
      }
      float4 src = read_imagef(input, sample_none, (int2)(x_c, y_c * input_shape.w + Z));
      maximum = max(src, maximum);
    }
  }
  write_imagef(output, (int2)(X, Y * output_shape.w + Z), maximum);
}
\ No newline at end of file
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "src/runtime/kernel/opencl/kernel/arithmetic.h"
#include <set> #include <set>
#include <vector> #include <vector>
#include <string> #include <string>
#include "schema/model_generated.h" #include "schema/model_generated.h"
#include "src/kernel_registry.h" #include "src/kernel_registry.h"
#include "src/runtime/kernel/opencl/utils.h" #include "src/runtime/kernel/opencl/utils.h"
#include "src/runtime/kernel/opencl/kernel/arithmetic.h"
#ifndef PROGRAM_WITH_IL #ifndef PROGRAM_WITH_IL
#include "src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl.inc" #include "src/runtime/kernel/opencl/cl/fp32/arithmetic_buffer.cl.inc"
#include "src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl.inc" #include "src/runtime/kernel/opencl/cl/fp32/arithmetic_image2d.cl.inc"
...@@ -41,8 +42,8 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const { ...@@ -41,8 +42,8 @@ std::vector<size_t> ArithmeticOpenCLKernel::InitGlobalSize() const {
void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() { void ArithmeticOpenCLKernel::Image2dGetWorkGroupSize() {
global_size_ = InitGlobalSize(); global_size_ = InitGlobalSize();
int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())()); int max_work_group_size = runtime_->GetKernelMaxWorkGroupSize(kernel_(), (*runtime_->Device())());
local_size_ = GetLocalSize(global_size_, max_work_group_size); local_size_ = GetCommonLocalSize(global_size_, max_work_group_size);
global_size_ = GetGlobalSize(local_size_, global_size_); global_size_ = GetCommonGlobalSize(local_size_, global_size_);
} }
void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() { void ArithmeticOpenCLKernel::BufferGetWorkGroupSize() {
......
...@@ -31,12 +31,11 @@ ...@@ -31,12 +31,11 @@
#endif #endif
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
using mindspore::kernel::KERNEL_ARCH::kGPU; using mindspore::kernel::KERNEL_ARCH::kGPU;
using mindspore::lite::KernelRegistrar; using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR; using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK; using mindspore::lite::RET_OK;
using mindspore::schema::PrimitiveType_DepthwiseConv2D;
namespace mindspore::kernel { namespace mindspore::kernel {
...@@ -117,11 +116,9 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() { ...@@ -117,11 +116,9 @@ int DepthwiseConv2dOpenCLKernel::InitBuffer() {
return RET_OK; return RET_OK;
} }
// Resize hook: performs no work and always reports success.
int DepthwiseConv2dOpenCLKernel::ReSize() {
  return RET_OK;
}
int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* img_size) { int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
size_t im_dst_x, im_dst_y; size_t im_dst_x, im_dst_y;
if (inputs_[0]->GetFormat() == schema::Format_NHWC4) { if (inputs_[0]->GetFormat() == schema::Format_NHWC4) {
...@@ -141,16 +138,18 @@ int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* i ...@@ -141,16 +138,18 @@ int DepthwiseConv2dOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t>* i
*img_size = vec; *img_size = vec;
return RET_OK; return RET_OK;
} }
int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector<size_t>* global_size) {
int DepthwiseConv2dOpenCLKernel::GetGlobalSize(size_t idx, std::vector<size_t> *global_size) {
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4};
*global_size = std::move(global); *global_size = std::move(global);
return RET_OK; return RET_OK;
} }
int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector<size_t>& global_size,
std::vector<size_t>* local_size) { int DepthwiseConv2dOpenCLKernel::GetLocalSize(size_t idx, const std::vector<size_t> &global_size,
std::vector<size_t> *local_size) {
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
std::vector <size_t> local = {1, 1, CO4}; std::vector<size_t> local = {1, 1, CO4};
*local_size = std::move(local); *local_size = std::move(local);
return RET_OK; return RET_OK;
} }
...@@ -161,8 +160,8 @@ int DepthwiseConv2dOpenCLKernel::Run() { ...@@ -161,8 +160,8 @@ int DepthwiseConv2dOpenCLKernel::Run() {
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM); size_t CO4 = UP_DIV(outputs_[0]->Channel(), C4NUM);
size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM); size_t CI4 = UP_DIV(inputs_[0]->Channel(), C4NUM);
std::vector <size_t> global = {(size_t) outputs_[0]->Width(), (size_t) outputs_[0]->Height(), CO4}; std::vector<size_t> global = {(size_t)outputs_[0]->Width(), (size_t)outputs_[0]->Height(), CO4};
std::vector <size_t> local; std::vector<size_t> local;
GetLocalSize(0, global, &local); GetLocalSize(0, global, &local);
float relu_clip1 = 6.0; float relu_clip1 = 6.0;
......
...@@ -28,11 +28,10 @@ namespace mindspore::kernel { ...@@ -28,11 +28,10 @@ namespace mindspore::kernel {
class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
public: public:
explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, explicit DepthwiseConv2dOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs) const std::vector<lite::tensor::Tensor *> &outputs)
: OpenCLKernel(parameter, inputs, outputs), : OpenCLKernel(parameter, inputs, outputs), packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {}
packed_weight_(nullptr), bias_data_(nullptr), kernel_(nullptr) {}
~DepthwiseConv2dOpenCLKernel() override {}; ~DepthwiseConv2dOpenCLKernel() override{};
int Init() override; int Init() override;
...@@ -42,20 +41,16 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel { ...@@ -42,20 +41,16 @@ class DepthwiseConv2dOpenCLKernel : public OpenCLKernel {
int InitBuffer(); int InitBuffer();
int GetImageSize(size_t idx, std::vector<size_t>* img_size) override; int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
int GetGlobalSize(size_t idx, std::vector<size_t>* global_size) override; int GetGlobalSize(size_t idx, std::vector<size_t> *global_size) override;
int GetLocalSize(size_t idx, const std::vector<size_t>& global_size, int GetLocalSize(size_t idx, const std::vector<size_t> &global_size, std::vector<size_t> *local_size) override;
std::vector<size_t>* local_size) override;
private: private:
FLOAT_t *packed_weight_; FLOAT_t *packed_weight_;
FLOAT_t *bias_data_; FLOAT_t *bias_data_;
cl::Kernel kernel_; cl::Kernel kernel_;
enum class MEM_TYPE { enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG};
BUF, IMG
} mem_type_{MEM_TYPE::IMG};
}; };
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_ #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_DEPTHWISE_H_
...@@ -64,12 +64,18 @@ int PoolingOpenCLKernel::Init() { ...@@ -64,12 +64,18 @@ int PoolingOpenCLKernel::Init() {
#ifdef PROGRAM_WITH_IL #ifdef PROGRAM_WITH_IL
ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name); ocl_runtime->CreateKernelFromIL(kernel_(), kernel_name);
#else #else
if (mem_type_ == MEM_TYPE::BUF) {
kernel_name += "_BUF";
} else {
kernel_name += "_IMG";
}
std::set<std::string> build_options; std::set<std::string> build_options;
ocl_runtime->LoadSource(program_name, source); ocl_runtime->LoadSource(program_name, source);
ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options); ocl_runtime->BuildKernel(kernel_, program_name, kernel_name, build_options);
#endif #endif
outputs_[0]->SetFormat(schema::Format_NHWC4); outputs_[0]->SetFormat(schema::Format_NHWC4);
MS_LOG(DEBUG) << kernel_name << " Init Done!"; MS_LOG(DEBUG) << kernel_name << " Init Done!";
return RET_OK; return RET_OK;
} }
...@@ -81,8 +87,30 @@ std::vector<size_t> PoolingOpenCLKernel::InitGlobalSize() const { ...@@ -81,8 +87,30 @@ std::vector<size_t> PoolingOpenCLKernel::InitGlobalSize() const {
return global; return global;
} }
// Fills *img_size with {image x extent, image y extent, channel data type} for the output image.
//   - input format NHWC4: x = output height, y = output width * channel slices
//   - any other format  : x = output height * channel slices, y = output width
// where channel slices = UP_DIV(output channels, C4NUM).
// The data type is CL_HALF_FLOAT when built with ENABLE_FP16, CL_FLOAT otherwise.
// `idx` is not used by this implementation. Always returns RET_OK.
int PoolingOpenCLKernel::GetImageSize(size_t idx, std::vector<size_t> *img_size) {
  const size_t channel_slices = UP_DIV(outputs_[0]->Channel(), C4NUM);
  const size_t out_height = outputs_[0]->Height();
  const size_t out_width = outputs_[0]->Width();
  const bool input_is_nhwc4 = inputs_[0]->GetFormat() == schema::Format_NHWC4;

  const size_t im_dst_x = input_is_nhwc4 ? out_height : out_height * channel_slices;
  const size_t im_dst_y = input_is_nhwc4 ? out_width * channel_slices : out_width;
#ifdef ENABLE_FP16
  const size_t img_dtype = CL_HALF_FLOAT;
#else
  const size_t img_dtype = CL_FLOAT;
#endif
  // Assignment replaces any previous contents of the caller's vector.
  *img_size = {im_dst_x, im_dst_y, img_dtype};
  return RET_OK;
}
// Buffer-initialisation hook: performs no work and always reports success.
int PoolingOpenCLKernel::InitBuffer() {
  return RET_OK;
}
// Resize hook: performs no work and always reports success.
int PoolingOpenCLKernel::ReSize() {
  return RET_OK;
}
int PoolingOpenCLKernel::Run() { int PoolingOpenCLKernel::Run() {
MS_LOG(DEBUG) << this->Name() << " Running!"; MS_LOG(DEBUG) << this->Name() << " Running!";
...@@ -110,12 +138,11 @@ int PoolingOpenCLKernel::Run() { ...@@ -110,12 +138,11 @@ int PoolingOpenCLKernel::Run() {
std::vector<size_t> local_size; std::vector<size_t> local_size;
std::vector<size_t> global_size = InitGlobalSize(); std::vector<size_t> global_size = InitGlobalSize();
int max_work_group_size = ocl_runtime->GetKernelMaxWorkGroupSize(kernel_(), (*ocl_runtime->Device())()); int max_work_group_size = ocl_runtime->GetKernelMaxWorkGroupSize(kernel_(), (*ocl_runtime->Device())());
local_size = GetLocalSize(global_size, max_work_group_size); local_size = GetCommonLocalSize(global_size, max_work_group_size);
global_size = GetGlobalSize(local_size, global_size); global_size = GetCommonGlobalSize(local_size, global_size);
// run opencl kernel // run opencl kernel
ocl_runtime->RunKernel(kernel_, global_size, local_size, nullptr); ocl_runtime->RunKernel(kernel_, global_size, local_size, nullptr);
return RET_OK; return RET_OK;
} }
......
...@@ -25,11 +25,11 @@ ...@@ -25,11 +25,11 @@
namespace mindspore::kernel { namespace mindspore::kernel {
class PoolingOpenCLKernel : public LiteKernel { class PoolingOpenCLKernel : public OpenCLKernel {
public: public:
explicit PoolingOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs, explicit PoolingOpenCLKernel(OpParameter *parameter, const std::vector<lite::tensor::Tensor *> &inputs,
const std::vector<lite::tensor::Tensor *> &outputs) const std::vector<lite::tensor::Tensor *> &outputs)
: LiteKernel(parameter, inputs, outputs) { : OpenCLKernel(parameter, inputs, outputs) {
parameter_ = reinterpret_cast<PoolingParameter *>(parameter); parameter_ = reinterpret_cast<PoolingParameter *>(parameter);
} }
~PoolingOpenCLKernel() override{}; ~PoolingOpenCLKernel() override{};
...@@ -38,10 +38,11 @@ class PoolingOpenCLKernel : public LiteKernel { ...@@ -38,10 +38,11 @@ class PoolingOpenCLKernel : public LiteKernel {
int ReSize() override; int ReSize() override;
int Run() override; int Run() override;
int InitBuffer(); int InitBuffer();
int GetImageSize(size_t idx, std::vector<size_t> *img_size) override;
private: private:
std::vector<size_t> InitGlobalSize() const; std::vector<size_t> InitGlobalSize() const;
enum class MEM_TYPE { BUF, IMG } mem_type_{MEM_TYPE::IMG};
PoolingParameter *parameter_; PoolingParameter *parameter_;
cl::Kernel kernel_; cl::Kernel kernel_;
}; };
...@@ -49,4 +50,3 @@ class PoolingOpenCLKernel : public LiteKernel { ...@@ -49,4 +50,3 @@ class PoolingOpenCLKernel : public LiteKernel {
} // namespace mindspore::kernel } // namespace mindspore::kernel
#endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_ #endif // MINDSPORE_LITE_SRC_BACKEND_OPENCL_POOLING_H_
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global) { std::vector<size_t> GetCommonGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global) {
std::vector<size_t> result(3, 1); std::vector<size_t> result(3, 1);
for (int i = 0; i < 3; ++i) { for (int i = 0; i < 3; ++i) {
result[i] = AlignByN(global[i], local[i]); result[i] = AlignByN(global[i], local[i]);
...@@ -30,7 +30,7 @@ std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::v ...@@ -30,7 +30,7 @@ std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::v
return result; return result;
} }
std::vector<size_t> GetLocalSize(const std::vector<size_t> &global, int max_size) { std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int max_size) {
size_t wg_z = GetBiggestDividerWithPriority(global[2], 8); size_t wg_z = GetBiggestDividerWithPriority(global[2], 8);
size_t wg_xy_size = max_size / wg_z; size_t wg_xy_size = max_size / wg_z;
size_t wg_x = std::min(DivideRoundUp(global[0], 2), wg_xy_size); size_t wg_x = std::min(DivideRoundUp(global[0], 2), wg_xy_size);
......
...@@ -75,10 +75,10 @@ T AlignByN(T number, N n) { ...@@ -75,10 +75,10 @@ T AlignByN(T number, N n) {
} }
// GetGlobalSize // GetGlobalSize
std::vector<size_t> GetGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global); std::vector<size_t> GetCommonGlobalSize(const std::vector<size_t> &local, const std::vector<size_t> &global);
// GetLocalSize // GetLocalSize
std::vector<size_t> GetLocalSize(const std::vector<size_t> &global, int max_size); std::vector<size_t> GetCommonLocalSize(const std::vector<size_t> &global, int max_size);
std::string CLErrorCode(cl_int error_code); std::string CLErrorCode(cl_int error_code);
......
...@@ -43,33 +43,42 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { ...@@ -43,33 +43,42 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) {
MS_LOG(INFO) << "ocl runtime"; MS_LOG(INFO) << "ocl runtime";
auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance(); auto ocl_runtime = lite::opencl::OpenCLRuntime::GetInstance();
ocl_runtime->Init(); ocl_runtime->Init();
auto allocator = ocl_runtime->GetAllocator();
MS_LOG(INFO) << "PoolingParameter"; MS_LOG(INFO) << "PoolingParameter";
auto param = new PoolingParameter; auto param = new PoolingParameter;
InitParameter(param); InitParameter(param);
// define tensor // define tensor
MS_LOG(INFO) << "define tensor"; MS_LOG(INFO) << "define tensor1";
std::vector<int> input_shape = {1, 16, 256, 192}; std::vector<int> input_shape = {1, 16, 256, 192};
std::vector<int> output_shape = {1, 8, 128, 192}; std::vector<int> output_shape = {1, 8, 128, 192};
auto data_type = kNumberTypeFloat32; auto data_type = kNumberTypeFloat32;
auto tensorType = schema::NodeType_ValueNode; auto tensorType = schema::NodeType_ValueNode;
MS_LOG(INFO) << "define tensor2";
auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensorType); auto input_tensor = new lite::tensor::Tensor(data_type, input_shape, schema::Format_NHWC4, tensorType);
auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensorType); auto output_tensor = new lite::tensor::Tensor(data_type, output_shape, schema::Format_NHWC4, tensorType);
MS_LOG(INFO) << "define input";
std::vector<lite::tensor::Tensor *> inputs{input_tensor}; std::vector<lite::tensor::Tensor *> inputs{input_tensor};
std::vector<lite::tensor::Tensor *> outputs{output_tensor}; std::vector<lite::tensor::Tensor *> outputs{output_tensor};
// run // run
MS_LOG(INFO) << "pooling_kernel";
auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs); auto *pooling_kernel = new kernel::PoolingOpenCLKernel(reinterpret_cast<OpParameter *>(param), inputs, outputs);
MS_LOG(INFO) << "pooling_kernel init";
pooling_kernel->Init(); pooling_kernel->Init();
std::vector<kernel::LiteKernel *> kernels{pooling_kernel}; std::vector<kernel::LiteKernel *> kernels{pooling_kernel};
inputs[0]->MallocData(allocator);
auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels); auto *pGraph = new kernel::SubGraphOpenCLKernel(inputs, outputs, kernels, kernels, kernels);
MS_LOG(INFO) << "pGraph init";
pGraph->Init(); pGraph->Init();
// load data // load data
MS_LOG(INFO) << "load data"; MS_LOG(INFO) << "load data1";
std::string input_file = "maxpool_in.bin"; std::string input_file = "maxpool_in.bin";
std::string expect_file = "maxpool_out.bin"; std::string expect_file = "maxpool_out.bin";
MS_LOG(INFO) << "load data2";
LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file); LoadTestData(input_tensor->Data(), input_tensor->Size(), input_file);
auto *input_data = reinterpret_cast<float *>(input_tensor->Data()); auto *input_data = reinterpret_cast<float *>(input_tensor->Data());
printf("input[0:10]:"); printf("input[0:10]:");
...@@ -81,6 +90,7 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) { ...@@ -81,6 +90,7 @@ TEST_F(TestMaxPoolingOpenCL, MaxPool_1_32_512_96) {
pGraph->Run(); pGraph->Run();
MS_LOG(INFO) << "compare result"; MS_LOG(INFO) << "compare result";
std::cout << "compare result" << std::endl;
CompareOutput(output_tensor, expect_file); CompareOutput(output_tensor, expect_file);
} }
......
...@@ -24,9 +24,14 @@ namespace mindspore { ...@@ -24,9 +24,14 @@ namespace mindspore {
void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) { void LoadTestData(void *dst, size_t dst_size, const std::string &file_path) {
if (file_path.empty()) { if (file_path.empty()) {
memset(dst, dst_size, dst_size); memset(dst, 0x00, dst_size);
} else { } else {
memcpy(dst, reinterpret_cast<const void *>(dst_size), dst_size); auto src_data = reinterpret_cast<float *>(mindspore::lite::ReadFile(file_path.c_str(), &dst_size));
if (src_data != nullptr) {
memcpy(dst, src_data, dst_size);
} else {
MS_LOG(ERROR) << "read file empty.";
}
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册