Commit e81515d3 authored by Y yangfei

Fix a bug in the softmax op kernel

Parent 82f9487b
@@ -64,6 +64,17 @@ class CLHelper {
     auto work_size_2 = n * h;
+    return {work_size_0, work_size_1, work_size_2};
+  } else if (image_dim.size() == 2) {
+    auto image_width = image.ImageWidth();
+    auto work_size_0 = image_width / image_dim[1];
+    auto work_size_1 = image_dim[1];
+    auto work_size_2 = image_dim[0];
     return {work_size_0, work_size_1, work_size_2};
   }
   PADDLE_MOBILE_THROW_EXCEPTION("not support this dim, need imp");
......
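The new 2-D branch in CLHelper::DefaultWorkSize can be read in isolation roughly as below. This is a minimal sketch only: the free-function form, the std::vector<int64_t> shape argument, and the explicit image_width parameter are stand-ins for the real CLHelper/CLImage API.

// Standalone sketch of the 2-D work-size computation (hypothetical helper
// name and signature; the real logic lives in CLHelper::DefaultWorkSize).
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

std::vector<size_t> DefaultWorkSize2D(const std::vector<int64_t> &image_dim,
                                      size_t image_width) {
  if (image_dim.size() == 2) {
    // Split the image width by the last dim; the two tensor dims fill the
    // remaining slots, giving a 3-D global work size.
    size_t work_size_0 = image_width / static_cast<size_t>(image_dim[1]);
    size_t work_size_1 = static_cast<size_t>(image_dim[1]);
    size_t work_size_2 = static_cast<size_t>(image_dim[0]);
    return {work_size_0, work_size_1, work_size_2};
  }
  throw std::runtime_error("not support this dim, need imp");
}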
@@ -37,7 +37,7 @@ limitations under the License. */
 #include "framework/cl/cl_image.h"
 #endif
-int debug_to = 4;
+int debug_to = 115;
 namespace paddle_mobile {
 namespace framework {
......
@@ -20,23 +20,23 @@ namespace operators {
 template <>
 bool ReluKernel<GPU_CL, float>::Init(ReluParam<GPU_CL>* param) {
-  this->cl_helper_.AddKernel("relu", "relu.cl");
+  // this->cl_helper_.AddKernel("relu", "relu.cl");
   return true;
 }

 template <>
 void ReluKernel<GPU_CL, float>::Compute(const ReluParam<GPU_CL>& param) {
-  auto kernel = this->cl_helper_.KernelAt(0);
-  const auto* input = param.InputX();
-  auto* output = param.Out();
-  auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
-  auto inputImage = input->GetCLImage();
-  auto outputImage = output->GetCLImage();
-  clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
-  clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
-  const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()};
-  clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
-                         work_size, NULL, 0, NULL, NULL);
+  // auto kernel = this->cl_helper_.KernelAt(0);
+  // const auto* input = param.InputX();
+  // auto* output = param.Out();
+  // auto default_work_size = this->cl_helper_.DefaultWorkSize(*output);
+  // auto inputImage = input->GetCLImage();
+  // auto outputImage = output->GetCLImage();
+  // clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
+  // clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
+  // const size_t work_size[2] = {input->ImageWidth(), input->ImageHeight()};
+  // clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
+  //                        work_size, NULL, 0, NULL, NULL);
 }

 template class ReluKernel<GPU_CL, float>;
......
@@ -36,11 +36,14 @@ void SoftmaxKernel<GPU_CL, float>::Compute(const SoftmaxParam<GPU_CL> &param) {
   clSetKernelArg(kernel, 0, sizeof(cl_mem), &inputImage);
   clSetKernelArg(kernel, 1, sizeof(cl_mem), &outputImage);
   const auto &inputDim = input->dims();
-  int dims[4] = {inputDim[0], inputDim[1], inputDim[2], inputDim[3]};
-  clSetKernelArg(kernel, 2, sizeof(int), dims);
-  clSetKernelArg(kernel, 3, sizeof(int), dims + 1);
-  clSetKernelArg(kernel, 4, sizeof(int), dims + 2);
-  clSetKernelArg(kernel, 5, sizeof(int), dims + 3);
+  int dims[4] = {1, 1, 1, 1};
+  for (int i = 0; i < inputDim.size(); i++) {
+    dims[4 - inputDim.size() + i] = inputDim[i];
+  }
+  clSetKernelArg(kernel, 2, sizeof(int), &dims);
+  clSetKernelArg(kernel, 3, sizeof(int), &dims[1]);
+  clSetKernelArg(kernel, 4, sizeof(int), &dims[2]);
+  clSetKernelArg(kernel, 5, sizeof(int), &dims[3]);
   clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 3, NULL,
                          default_work_size.data(), NULL, 0, NULL, NULL);
......
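The softmax fix above handles inputs with fewer than four dimensions: instead of reading inputDim[0..3] directly, which breaks for a 2-D shape such as {1, 1000}, it right-aligns the shape into a fixed int[4] padded with 1s before passing each element to clSetKernelArg. Below is a minimal sketch of that padding; the helper name PadDims and the plain std::vector input are illustrative stand-ins for the framework's DDim type.

// Standalone sketch of the dimension padding used above (hypothetical
// helper; assumes the input has at most four dimensions, as with the
// NCHW-style tensors this kernel handles).
#include <array>
#include <cstdint>
#include <vector>

std::array<int, 4> PadDims(const std::vector<int64_t> &input_dim) {
  std::array<int, 4> dims = {1, 1, 1, 1};
  // Right-align the real shape inside the 4-element array, e.g. a
  // {1, 1000} softmax input becomes {1, 1, 1, 1000}.
  for (size_t i = 0; i < input_dim.size(); ++i) {
    dims[4 - input_dim.size() + i] = static_cast<int>(input_dim[i]);
  }
  return dims;
}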