提交 89bb3a0b 编写于 作者: L liuruilong

update softmax reshape kernel

上级 71a1cf75
...@@ -49,34 +49,26 @@ class CLHelper { ...@@ -49,34 +49,26 @@ class CLHelper {
cl_context CLContext() { return scope_->Context(); } cl_context CLContext() { return scope_->Context(); }
std::vector<size_t> DefaultWorkSize(const CLImage &image) { std::vector<size_t> DefaultWorkSize(const CLImage &image) {
if (image.GetImageType() == Invalid) {
PADDLE_MOBILE_THROW_EXCEPTION(" not support image type");
}
// n c h w // n c h w
auto image_dim = image.dims(); auto image_dim = image.dims();
if (image_dim.size() == 4) { if (image_dim.size() == 4) {
auto n = image_dim[0]; auto n = image_dim[0];
auto h = image_dim[2]; auto h = image_dim[2];
auto w = image_dim[3]; auto w = image_dim[3];
auto image_width = image.ImageWidth(); auto image_width = image.ImageWidth();
auto work_size_0 = image_width / w; auto work_size_0 = image_width / w;
auto work_size_1 = w; auto work_size_1 = w;
auto work_size_2 = n * h; auto work_size_2 = n * h;
return {work_size_0, work_size_1, work_size_2}; return {work_size_0, work_size_1, work_size_2};
} else if (image_dim.size() == 2) { } else if (image_dim.size() == 2) {
auto image_width = image.ImageWidth(); return {1, image.ImageWidth(), image.ImageHeight()};
} else if (image_dim.size() == 1) {
auto work_size_0 = image_width / image_dim[1]; return {1, image.ImageWidth(), 1};
auto work_size_1 = image_dim[1];
auto work_size_2 = image_dim[0];
return {work_size_0, work_size_1, work_size_2};
} }
PADDLE_MOBILE_THROW_EXCEPTION("not support this dim, need imp"); PADDLE_MOBILE_THROW_EXCEPTION(" not support this dim, need imp ");
} }
private: private:
......
...@@ -119,6 +119,9 @@ void TensorToCLImage(const Tensor *tensor, CLImage *cl_image, ...@@ -119,6 +119,9 @@ void TensorToCLImage(const Tensor *tensor, CLImage *cl_image,
} }
#ifdef PADDLE_MOBILE_DEBUG #ifdef PADDLE_MOBILE_DEBUG
Print &operator<<(Print &printer, const CLImage &cl_image) { Print &operator<<(Print &printer, const CLImage &cl_image) {
if (cl_image.GetImageType() == Invalid) {
PADDLE_MOBILE_THROW_EXCEPTION(" not support image type");
}
printer << " dims: " << cl_image.dims() << "\n"; printer << " dims: " << cl_image.dims() << "\n";
int stride = cl_image.numel() / 20; int stride = cl_image.numel() / 20;
stride = stride > 0 ? stride : 1; stride = stride > 0 ? stride : 1;
......
...@@ -20,6 +20,7 @@ limitations under the License. */ ...@@ -20,6 +20,7 @@ limitations under the License. */
#include "framework/cl/cl_half.h" #include "framework/cl/cl_half.h"
#include "framework/cl/cl_tool.h" #include "framework/cl/cl_tool.h"
#include "framework/cl/cl_deleter.h"
#include "framework/ddim.h" #include "framework/ddim.h"
#include "framework/tensor.h" #include "framework/tensor.h"
...@@ -88,11 +89,20 @@ class CLImage { ...@@ -88,11 +89,20 @@ class CLImage {
" empty image tensor data shouldn't have value"); " empty image tensor data shouldn't have value");
} }
DLOG << " init empty image "; DLOG << " init empty image ";
InitCLImage(context, command_queue, nullptr, dim); if (tensor_dims_.size() <= 2) {
DLOG << " dim <= 2 folder ~~~~~ ";
InitCLImage2C(context, command_queue, tensor_data_, tensor_dims_);
} else {
DLOG << " dim > 2 norm ~~~~~ ";
InitCLImage(context, command_queue, tensor_data_, tensor_dims_);
}
// InitCLImage(context, command_queue, nullptr, dim);
initialized_ = true; initialized_ = true;
} }
cl_mem GetCLImage() const { return cl_image_; } cl_mem GetCLImage() const { return cl_image_.get(); }
const DDim &ImageDims() const { return image_dims_; } const DDim &ImageDims() const { return image_dims_; }
...@@ -201,12 +211,13 @@ class CLImage { ...@@ -201,12 +211,13 @@ class CLImage {
}; };
cid.buffer = nullptr; cid.buffer = nullptr;
cl_int err; cl_int err;
cl_image_ = clCreateImage( cl_mem cl_image = clCreateImage(
context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0), context, CL_MEM_READ_WRITE | (data ? CL_MEM_COPY_HOST_PTR : 0),
&cf, // const cl_image_format *image_format &cf, // const cl_image_format *image_format
&cid, // const cl_image_desc *image_desc &cid, // const cl_image_desc *image_desc
data, // void *host_ptr data, // void *host_ptr
&err); &err);
cl_image_.reset(cl_image);
if (err != CL_SUCCESS) { if (err != CL_SUCCESS) {
CL_CHECK_ERRORS(err); CL_CHECK_ERRORS(err);
PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error "); PADDLE_MOBILE_THROW_EXCEPTION(" create image 2d error ");
...@@ -283,7 +294,7 @@ class CLImage { ...@@ -283,7 +294,7 @@ class CLImage {
} }
bool initialized_ = false; bool initialized_ = false;
cl_mem cl_image_; std::unique_ptr<_cl_mem, CLMemDeleter> cl_image_;
size_t image_width_; size_t image_width_;
size_t width_of_one_block_; size_t width_of_one_block_;
size_t height_of_one_block_; size_t height_of_one_block_;
......
...@@ -37,7 +37,7 @@ limitations under the License. */ ...@@ -37,7 +37,7 @@ limitations under the License. */
#include "framework/cl/cl_image.h" #include "framework/cl/cl_image.h"
#endif #endif
int debug_to = 33; int debug_to = 32;
namespace paddle_mobile { namespace paddle_mobile {
namespace framework { namespace framework {
......
...@@ -14,6 +14,31 @@ limitations under the License. */ ...@@ -14,6 +14,31 @@ limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
/**
 * reshape: copies one texel from `input` to the same (x, y) position in
 * `output`.
 *
 * Work-item mapping: get_global_id(0) is the texel x coordinate and
 * get_global_id(1) is the texel y coordinate.
 *
 * d0..d3 and x0..x3 are part of the kernel signature but are not read by
 * this body; they are kept so the argument indices used by the host stay
 * valid.
 */
__kernel void reshape(__read_only image2d_t input,
                      __write_only image2d_t output,
                      __private const int d0,
                      __private const int d1,
                      __private const int d2,
                      __private const int d3,
                      __private const int x0,
                      __private const int x1,
                      __private const int x2,
                      __private const int x3) {
  const int x = get_global_id(0);
  const int y = get_global_id(1);

  // Integer (int2) coordinates are only defined for samplers that use
  // unnormalized coordinates with nearest filtering; a normalized sampler
  // combined with integer coordinates yields undefined values.
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  const int2 pos = (int2)(x, y);
  const half4 in = read_imageh(input, sampler, pos);
  write_imageh(output, pos, in);
}
/*
__kernel void reshape(__read_only image2d_t input, __kernel void reshape(__read_only image2d_t input,
__write_only image2d_t output, __write_only image2d_t output,
__private const int d0, __private const int d0,
...@@ -49,3 +74,5 @@ __kernel void reshape(__read_only image2d_t input, ...@@ -49,3 +74,5 @@ __kernel void reshape(__read_only image2d_t input,
} }
write_imageh(output, (int2)(x, y), r); write_imageh(output, (int2)(x, y), r);
} }
*/
...@@ -14,6 +14,41 @@ limitations under the License. */ ...@@ -14,6 +14,41 @@ limitations under the License. */
#pragma OPENCL EXTENSION cl_khr_fp16 : enable #pragma OPENCL EXTENSION cl_khr_fp16 : enable
/**
 * softmax: normalizes the texels of `input_image` with a max-subtracted
 * exponential and writes the result to `output_image`.
 *
 * @param input_image   source image, read as half4 texels
 * @param output_image  destination image, written at (out_w, out_nh)
 * @param group         number of texels along x that are reduced; the
 *                      reduction always reads row 0 of `input_image`
 *
 * Work-item mapping: get_global_id(1) is the texel x coordinate and
 * get_global_id(2) is the texel y coordinate. get_global_id(0) is not used
 * by this kernel.
 *
 * NOTE(review): the max/sum reduction reads row 0 only, while the output is
 * taken from row out_nh. This is presumably intended for a single-row
 * image -- confirm before using with images taller than one row.
 * NOTE(review): all four lanes of every texel take part in the reduction,
 * so any padding lanes in the last texel are counted in the sum -- verify
 * against the image layout produced by the host.
 */
__kernel void softmax(__read_only image2d_t input_image,
                      __write_only image2d_t output_image,
                      __private const int group) {
  const int out_w = get_global_id(1);  // texel x coordinate
  const int out_nh = get_global_id(2); // texel y coordinate

  // Integer (int2) coordinates are only defined for samplers that use
  // unnormalized coordinates with nearest filtering; a normalized sampler
  // combined with integer coordinates yields undefined values.
  const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE |
                            CLK_ADDRESS_CLAMP |
                            CLK_FILTER_NEAREST;

  // Seed the running maximum from real data instead of 0 so that inputs
  // that are all negative are still shifted by their true maximum; seeding
  // with 0 lets every exp() underflow in half precision and the sum
  // collapse to zero.
  const half4 first = read_imageh(input_image, sampler, (int2)(0, 0));
  half maxv = max(max(first.x, first.y), max(first.z, first.w));
  for (int i = 1; i < group; ++i) {
    const half4 texel = read_imageh(input_image, sampler, (int2)(i, 0));
    maxv = max(maxv, max(max(texel.x, texel.y), max(texel.z, texel.w)));
  }

  // Accumulate exp(v - max) lane-wise, then fold the four lanes together.
  half4 lane_sum = (half4)(0.0f);
  for (int i = 0; i < group; ++i) {
    const half4 texel = read_imageh(input_image, sampler, (int2)(i, 0));
    lane_sum += exp(texel - maxv);
  }
  const half sum = lane_sum.x + lane_sum.y + lane_sum.z + lane_sum.w;

  const int2 pos = (int2)(out_w, out_nh);
  const half4 value = read_imageh(input_image, sampler, pos);
  const half4 result = exp(value - maxv) / sum;
  write_imageh(output_image, pos, result);
}
/*
__kernel void softmax(__read_only image2d_t input, __kernel void softmax(__read_only image2d_t input,
__write_only image2d_t output, __write_only image2d_t output,
__private const int d0, __private const int d0,
...@@ -42,3 +77,5 @@ __kernel void softmax(__read_only image2d_t input, ...@@ -42,3 +77,5 @@ __kernel void softmax(__read_only image2d_t input,
write_imageh(output, (int2)(z * d3 + x, y), r); write_imageh(output, (int2)(z * d3 + x, y), r);
} }
*/
...@@ -43,21 +43,21 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init( ...@@ -43,21 +43,21 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
const int C = mean->numel(); const int C = mean->numel();
for (int j = 0; j < C; ++j) { // for (int j = 0; j < C; ++j) {
DLOG << " mean - " << j << mean->data<float>()[j]; // DLOG << " mean - " << j << mean->data<float>()[j];
} // }
//
for (int j = 0; j < C; ++j) { // for (int j = 0; j < C; ++j) {
DLOG << " variance - " << j << variance->data<float>()[j]; // DLOG << " variance - " << j << variance->data<float>()[j];
} // }
//
for (int j = 0; j < C; ++j) { // for (int j = 0; j < C; ++j) {
DLOG << " scale - " << j << scale->data<float>()[j]; // DLOG << " scale - " << j << scale->data<float>()[j];
} // }
//
for (int j = 0; j < C; ++j) { // for (int j = 0; j < C; ++j) {
DLOG << " bias - " << j << bias->data<float>()[j]; // DLOG << " bias - " << j << bias->data<float>()[j];
} // }
// //
// DLOG << " climage mean: " << *mean; // DLOG << " climage mean: " << *mean;
...@@ -85,21 +85,21 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init( ...@@ -85,21 +85,21 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
framework::CLImage *new_scale = new framework::CLImage(); framework::CLImage *new_scale = new framework::CLImage();
for (int j = 0; j < C; ++j) { // for (int j = 0; j < C; ++j) {
DLOG << " new scale - " << j << new_scale_ptr[j]; // DLOG << " new scale - " << j << new_scale_ptr[j];
} // }
//
for (int j = 0; j < C; ++j) { // for (int j = 0; j < C; ++j) {
DLOG << " new bias - " << j << new_bias_ptr[j]; // DLOG << " new bias - " << j << new_bias_ptr[j];
} // }
new_scale->SetTensorData(new_scale_ptr, variance->dims()); new_scale->SetTensorData(new_scale_ptr, variance->dims());
new_scale->InitCLImage(this->cl_helper_.CLContext(), new_scale->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue()); cl_helper_.CLCommandQueue());
DLOG << " climage - y bias: " << *(param->Bias()); // DLOG << " climage - y bias: " << *(param->Bias());
//
DLOG << " climage - new scale: " << *new_scale; // DLOG << " climage - new scale: " << *new_scale;
framework::CLImage *new_bias = new framework::CLImage(); framework::CLImage *new_bias = new framework::CLImage();
...@@ -107,9 +107,9 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init( ...@@ -107,9 +107,9 @@ bool ConvAddBNReluKernel<GPU_CL, float>::Init(
new_bias->InitCLImage(this->cl_helper_.CLContext(), new_bias->InitCLImage(this->cl_helper_.CLContext(),
cl_helper_.CLCommandQueue()); cl_helper_.CLCommandQueue());
DLOG << " climage - new bias: " << *new_bias; // DLOG << " climage - new bias: " << *new_bias;
//
DLOG << " climage - filter: " << *(param->Filter()); // DLOG << " climage - filter: " << *(param->Filter());
param->SetNewScale(new_scale); param->SetNewScale(new_scale);
param->SetNewBias(new_bias); param->SetNewBias(new_bias);
...@@ -237,7 +237,7 @@ void ConvAddBNReluKernel<GPU_CL, float>::Compute( ...@@ -237,7 +237,7 @@ void ConvAddBNReluKernel<GPU_CL, float>::Compute(
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
status = status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL); default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
} }
......
...@@ -118,7 +118,7 @@ void ConvAddKernel<GPU_CL, float>::Compute( ...@@ -118,7 +118,7 @@ void ConvAddKernel<GPU_CL, float>::Compute(
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
status = status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL); default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
} }
......
...@@ -155,7 +155,7 @@ void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) { ...@@ -155,7 +155,7 @@ void ConvKernel<GPU_CL, float>::Compute(const ConvParam<GPU_CL> &param) {
DLOG << " begin enqueue "; DLOG << " begin enqueue ";
status = status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL); default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
DLOG << " end enqueue "; DLOG << " end enqueue ";
......
...@@ -77,7 +77,7 @@ void DepthwiseConvKernel<GPU_CL, float>::Compute( ...@@ -77,7 +77,7 @@ void DepthwiseConvKernel<GPU_CL, float>::Compute(
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
status = status =
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL); default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status); CL_CHECK_ERRORS(status);
......
...@@ -36,9 +36,12 @@ void ReshapeKernel<GPU_CL, float>::Compute(const ReshapeParam<GPU_CL> &param) { ...@@ -36,9 +36,12 @@ void ReshapeKernel<GPU_CL, float>::Compute(const ReshapeParam<GPU_CL> &param) {
const auto &outputDim = output->dims(); const auto &outputDim = output->dims();
int dims[4] = {1, 1, 1, 1}; int dims[4] = {1, 1, 1, 1};
int odims[4] = {1, 1, 1, 1}; int odims[4] = {1, 1, 1, 1};
// 1 1000 1 1
for (int i = 0; i < inputDim.size(); i++) { for (int i = 0; i < inputDim.size(); i++) {
dims[4 - inputDim.size() + i] = inputDim[i]; dims[4 - inputDim.size() + i] = inputDim[i];
} }
// 1 1 1 1000
for (int i = 0; i < outputDim.size(); i++) { for (int i = 0; i < outputDim.size(); i++) {
odims[4 - outputDim.size() + i] = outputDim[i]; odims[4 - outputDim.size() + i] = outputDim[i];
} }
......
...@@ -33,20 +33,41 @@ void SoftmaxKernel<GPU_CL, float>::Compute(const SoftmaxParam<GPU_CL> &param) { ...@@ -33,20 +33,41 @@ void SoftmaxKernel<GPU_CL, float>::Compute(const SoftmaxParam<GPU_CL> &param) {
auto *output = param.Out(); auto *output = param.Out();
auto inputImage = input->GetCLImage(); auto inputImage = input->GetCLImage();
auto outputImage = output->GetCLImage(); auto outputImage = output->GetCLImage();
clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage); DLOG << " softmax - output image dim " << output->ImageDims();
const auto &inputDim = input->dims(); DLOG << " softmax - output image tensor dim " << output->dims();
int dims[4] = {1, 1, 1, 1};
for (int i = 0; i < inputDim.size(); i++) { int group = output->ImageWidth();
dims[4 - inputDim.size() + i] = inputDim[i];
} cl_int status;
clSetKernelArg(kernel, 2, sizeof(int), &dims);
clSetKernelArg(kernel, 3, sizeof(int), &dims[1]); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
clSetKernelArg(kernel, 4, sizeof(int), &dims[2]); CL_CHECK_ERRORS(status);
clSetKernelArg(kernel, 5, sizeof(int), &dims[3]);
status = clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL, CL_CHECK_ERRORS(status);
status = clSetKernelArg(kernel, 2, sizeof(int), &group);
CL_CHECK_ERRORS(status);
// const auto &inputDim = input->dims();
//
// int dims[4] = {1, 1, 1, 1};
//
// for (int i = 0; i < inputDim.size(); i++) {
// dims[4 - inputDim.size() + i] = inputDim[i];
// }
//
// clSetKernelArg(kernel, 2, sizeof(int), &dims);
// clSetKernelArg(kernel, 3, sizeof(int), &dims[1]);
// clSetKernelArg(kernel, 4, sizeof(int), &dims[2]);
// clSetKernelArg(kernel, 5, sizeof(int), &dims[3]);
DLOG << "default_work_size: " << default_work_size;
status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, default_work_size.size(), NULL,
default_work_size.data(), NULL, 0, NULL, NULL); default_work_size.data(), NULL, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
} }
template class SoftmaxKernel<GPU_CL, float>; template class SoftmaxKernel<GPU_CL, float>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册