提交 2da7cba9 编写于 作者: xiebaiyuan's avatar xiebaiyuan

[OPENCL] softmax with test, test=develop

上级 500dbb62
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
/*
 * Softmax over a run of `out_W` consecutive texels in one image row.
 *
 * Work-item mapping (from the index arithmetic below):
 *   global id 0 -> out_c   (block index; block k starts at x = k * out_W)
 *   global id 1 -> out_w   (position inside the block)
 *   global id 2 -> out_nh  (image row)
 * Each work-item normalises the texel at (out_c * out_W + out_w, out_nh)
 * against all `out_W` texels of the same block and row. The four vector
 * lanes are processed independently of each other.
 *
 * CL_DTYPE4 / CL_DTYPE_CHAR / READ_IMG_TYPE / WRITE_IMG_TYPE are expected to
 * come from cl_common.h.
 *
 * input_image  : source image.
 * output_image : destination image, addressed with the same coordinates.
 * out_W        : number of texels in one block (the reduction length).
 */
__kernel void softmax(__read_only image2d_t input_image,
                      __write_only image2d_t output_image,
                      __private const int out_W) {
  const int out_c = get_global_id(0);
  const int out_w = get_global_id(1);
  const int out_nh = get_global_id(2);

  // x coordinate of the first texel of the block this work-item reduces over.
  const int block_x = out_c * out_W;
  const int2 pos = (int2)(block_x + out_w, out_nh);

  // Image reads with integer coordinates are only defined for samplers that
  // use un-normalised coordinates and nearest filtering.
  const sampler_t sampler =
      CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

  // Seed the running maximum with the first texel instead of 0 so that rows
  // whose values are all negative are still shifted by their true maximum;
  // otherwise every exp() may underflow and the sum below becomes 0.
  CL_DTYPE4 input_max = READ_IMG_TYPE(
      CL_DTYPE_CHAR, input_image, sampler, (int2)(block_x, out_nh));
  CL_DTYPE4 input_tmp;
  for (int i = 1; i < out_W; i++) {
    input_tmp = READ_IMG_TYPE(
        CL_DTYPE_CHAR, input_image, sampler, (int2)(block_x + i, out_nh));
    input_max = max(input_max, input_tmp);
  }

  // Denominator: sum of exp(x - max) over the block.
  CL_DTYPE4 sum = (CL_DTYPE4)0.0f;
  for (int i = 0; i < out_W; i++) {
    input_tmp = READ_IMG_TYPE(
        CL_DTYPE_CHAR, input_image, sampler, (int2)(block_x + i, out_nh));
    sum += exp(input_tmp - input_max);
  }

  // Numerator for this work-item's own texel, then normalise and store.
  CL_DTYPE4 input = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler, pos);
  CL_DTYPE4 output = exp(input - input_max) / sum;
  WRITE_IMG_TYPE(CL_DTYPE_CHAR, output_image, pos, output);
}
...@@ -36,6 +36,7 @@ add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kerne ...@@ -36,6 +36,7 @@ add_kernel(pad2d_opencl OPENCL basic SRCS pad2d_image_compute.cc DEPS ${cl_kerne
add_kernel(box_coder_opencl OPENCL basic SRCS box_coder_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(box_coder_opencl OPENCL basic SRCS box_coder_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(pixel_shuffle_opencl OPENCL basic SRCS pixel_shuffle_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(pixel_shuffle_opencl OPENCL basic SRCS pixel_shuffle_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(expand_opencl OPENCL basic SRCS expand_image_compute.cc DEPS ${cl_kernel_deps}) add_kernel(expand_opencl OPENCL basic SRCS expand_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(softmax_opencl OPENCL basic SRCS softmax_image_compute.cc DEPS ${cl_kernel_deps})
# extra # extra
# wait to add ... # wait to add ...
...@@ -82,6 +83,9 @@ lite_cc_test(test_pixel_shuffle_image_opencl SRCS pixel_shuffle_image_compute_te ...@@ -82,6 +83,9 @@ lite_cc_test(test_pixel_shuffle_image_opencl SRCS pixel_shuffle_image_compute_te
lite_cc_test(test_expand_image_opencl SRCS expand_image_compute_test.cc lite_cc_test(test_expand_image_opencl SRCS expand_image_compute_test.cc
DEPS expand_opencl op_registry program context) DEPS expand_opencl op_registry program context)
lite_cc_test(test_softmax_image_opencl SRCS softmax_image_compute_test.cc
DEPS softmax_opencl op_registry program context)
lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc lite_cc_test(test_elementwise_add_image_opencl SRCS elementwise_add_image_compute_test.cc
DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context) DEPS elementwise_add_opencl fusion_elementwise_add_activation_opencl op_registry program context)
lite_cc_test(test_elementwise_sub_image_opencl SRCS elementwise_sub_image_compute_test.cc lite_cc_test(test_elementwise_sub_image_opencl SRCS elementwise_sub_image_compute_test.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
#ifdef LITE_WITH_PROFILE
#include "lite/core/profile/profiler.h"
#endif
#include "lite/backends/opencl/cl_utility.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
// OpenCL softmax kernel operating on FP16 cl::Image2D tensors.
//
// The device kernel ("softmax" in image/softmax_kernel.cl) is launched over a
// 3-D range of {ceil(C / 4), W, N * H} work-items and receives only the output
// width, so it can only normalise along the last (W) axis of a 4-D tensor.
// ReInitWhenNeeded() therefore rejects any other rank or axis.
class SoftmaxComputeImage2D : public KernelLite<TARGET(kOpenCL),
                                                PRECISION(kFP16),
                                                DATALAYOUT(kImageDefault)> {
 public:
  using param_t = operators::SoftmaxParam;

  // Short human-readable description of this kernel.
  std::string doc() const override {
    return "Softmax using cl::Image2D, kFP16";
  }

  // Builds the OpenCL program and caches the resulting cl::Kernel.
  void PrepareForRun() override {
    VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
    auto& context = ctx_->As<OpenCLContext>();
    context.cl_context()->AddKernel(kernel_func_name_,
                                    "image/softmax_kernel.cl",
                                    build_options_,
                                    time_stamp_);

    // The lookup key is the concatenation of name, build options and
    // time stamp, i.e. the same three values passed to AddKernel above.
    STL::stringstream kernel_key;
    kernel_key << kernel_func_name_ << build_options_ << time_stamp_;
    kernel_ = context.cl_context()->GetKernel(kernel_key.str());
  }

  // Validates the parameters and, on the first call or whenever the input
  // dims changed, recomputes the output image shape and the global work size.
  void ReInitWhenNeeded() override {
    VLOG(1) << "ReInitWhenNeeded: " << kernel_func_name_;
    softmax_param_ = param_.get_mutable<param_t>();
    auto x_dims = softmax_param_->x->dims();
    auto out_dims = softmax_param_->output->dims();
    VLOG(1) << "x_dims: " << x_dims;
    VLOG(1) << "out_dims: " << out_dims;
    VLOG(1) << "axis: " << softmax_param_->axis;
    CHECK_EQ(out_dims.size(), 4) << "Softmax only support out_dims.size() == 4"
                                 << out_dims;

    // The device kernel reduces along W only; fail loudly for any other axis
    // instead of silently producing a softmax over the wrong dimension.
    const int axis = softmax_param_->axis;
    CHECK(axis == -1 || axis == 3)
        << "Softmax image2d kernel only supports the last axis, got axis: "
        << axis;

    if (first_epoch_for_reinit_ || x_dims != last_x_dims_) {
      last_x_dims_ = x_dims;
      first_epoch_for_reinit_ = false;

      // compute image shape
      paddle::lite::CLImageConverterDefault default_convertor;
      out_img_shape_ = default_convertor.InitImageDimInfoWith(
          softmax_param_->output->dims());
      VLOG(1) << "out_img_shape_: " << out_img_shape_[0] << " "
              << out_img_shape_[1];

      // compute global work size: {ceil(C / 4), W, N * H}
      const size_t work_size_0 = static_cast<size_t>((out_dims[1] + 3) / 4);
      const size_t work_size_1 = static_cast<size_t>(out_dims[3]);
      const size_t work_size_2 =
          static_cast<size_t>(out_dims[0] * out_dims[2]);
      global_work_size_ = cl::NDRange{work_size_0, work_size_1, work_size_2};
      VLOG(1) << "global_work_size_: " << global_work_size_[0] << " "
              << global_work_size_[1] << " " << global_work_size_[2];
    }
  }

  // Binds the input image, output image and output width to the kernel and
  // enqueues it. Relies on ReInitWhenNeeded() having set softmax_param_,
  // out_img_shape_ and global_work_size_.
  void Run() override {
    auto* x_img = softmax_param_->x->data<half_t, cl::Image2D>();
    auto* out_img = softmax_param_->output->mutable_data<half_t, cl::Image2D>(
        out_img_shape_[0], out_img_shape_[1]);

    auto out_dims = softmax_param_->output->dims();
    int out_w = out_dims[3];

    auto& context = ctx_->As<OpenCLContext>();
    CHECK(context.cl_context() != nullptr);

    auto kernel = kernel_;
    cl_int status;
    status = kernel.setArg(0, *x_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(1, *out_img);
    CL_CHECK_FATAL(status);
    status = kernel.setArg(2, out_w);
    CL_CHECK_FATAL(status);

    status = EnqueueNDRangeKernel(context,
                                  kernel,
                                  cl::NullRange,
                                  global_work_size_,
                                  cl::NullRange,
                                  nullptr,
                                  event_);
    CL_CHECK_FATAL(status);
  }

#ifdef LITE_WITH_PROFILE
  // Reports the kernel function name and the event of the last launch.
  void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
    ch->kernel_func_name = kernel_func_name_;
    ch->cl_event =
        event_;  // `event_` defined in `kernel.h`, valid after kernel::Run
  }
#endif

 private:
  std::string kernel_func_name_{"softmax"};
  std::string build_options_{"-DCL_DTYPE_half"};
  std::string time_stamp_{GetTimeStamp()};

  // Set by ReInitWhenNeeded(); not owned.
  param_t* softmax_param_{nullptr};
  cl::Kernel kernel_;

  // Dim-change tracking used to skip recomputing shapes on every launch.
  bool first_epoch_for_reinit_{true};
  DDim last_x_dims_;

  DDim out_img_shape_ = DDim(std::vector<DDim::value_type>(
      {static_cast<DDim::value_type>(1), static_cast<DDim::value_type>(1)}));
  cl::NDRange global_work_size_ = cl::NDRange{
      static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers SoftmaxComputeImage2D for op "softmax" on the OpenCL target with
// FP16 precision and the ImageDefault layout, under the alias "image2d".
// Both the "X" input and the "Out" output are declared as OpenCL / FP16 /
// ImageDefault tensors.
REGISTER_LITE_KERNEL(softmax,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::SoftmaxComputeImage2D,
image2d)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <random>
#include <gtest/gtest.h>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/test_helper.h"
#define FP16_MAX_DIFF (5e-1)
namespace paddle {
namespace lite {
// Host-side reference softmax used to validate the OpenCL kernel.
//
// Treats the input as [outer, axis_size, inner] (split around `param.axis`,
// negative axes counted from the back) and, for every (outer, inner) pair,
// writes exp(x - max) / sum(exp(x - max)) along the axis into param.output.
template <typename dtype>
void softmax_compute_ref(const operators::SoftmaxParam& param) {
  const dtype* src = param.x->mutable_data<const dtype>();
  dtype* dst = param.output->mutable_data<dtype>();
  DDim in_dims = param.x->dims();
  ASSERT_EQ(in_dims.data(), param.output->dims().data());

  auto rank = in_dims.size();
  int axis = param.axis;
  if (axis < 0) {
    axis += rank;
  }

  int axis_size = in_dims[axis];
  int outer_num = in_dims.Slice(0, axis).production();
  int inner_num = in_dims.Slice(axis + 1, rank).production();

  for (int outer = 0; outer < outer_num; outer++) {
    for (int inner = 0; inner < inner_num; inner++) {
      // Flat index of the first element of this softmax lane; successive
      // elements along the axis are `inner_num` apart.
      int base = (outer * axis_size) * inner_num + inner;

      // Pass 1: lane maximum, for numerical stability.
      dtype peak = std::numeric_limits<dtype>::lowest();
      for (int k = 0; k < axis_size; k++) {
        int pos = base + k * inner_num;
        if (src[pos] > peak) {
          peak = src[pos];
        }
      }

      // Pass 2: shifted exponentials and their sum.
      dtype total = (dtype)0;
      for (int k = 0; k < axis_size; k++) {
        int pos = base + k * inner_num;
        dst[pos] = exp(src[pos] - peak);
        total += dst[pos];
      }

      // Pass 3: normalise.
      for (int k = 0; k < axis_size; k++) {
        dst[base + k * inner_num] /= total;
      }
    }
  }
}
// Runs the OpenCL image2d softmax kernel over a grid of NCHW shapes and
// compares its output, converted back to NCHW floats, with the host
// reference implementation.
TEST(softmax_image2d, compute) {
#if 1
  for (auto n : {1, 3}) {
    for (auto c : {1, 4}) {
      for (auto h : {5, 1}) {
        for (auto w : {1, 6}) {
          for (auto axis : {/*-2,*/ -1 /*, 0, 1, 2*/}) {
#else
  for (auto n : {1, 3, 4, 11}) {
    for (auto c : {1, 3, 11, 4}) {
      for (auto h : {3, 1, 11, 4}) {
        for (auto w : {1, 3, 4, 12}) {
          for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
#endif
            LOG(INFO) << "create kernel ...";
            auto kernels =
                KernelRegistry::Global().Create("softmax",
                                                TARGET(kOpenCL),
                                                PRECISION(kFP16),
                                                DATALAYOUT(kImageDefault));
            ASSERT_FALSE(kernels.empty());
            // prepare opencl kernel params
            auto kernel = std::move(kernels.front());
            LOG(INFO) << "prepare to test kernel ====> " << kernel->doc();
            LOG(INFO) << n << c << h << w;

            operators::SoftmaxParam param;
            lite::Tensor x;
            lite::Tensor output;
            operators::SoftmaxParam param_ref;
            lite::Tensor x_ref;
            lite::Tensor output_ref;

            auto in_dim = DDim(std::vector<int64_t>({n, c, h, w}));
            auto out_dim = DDim(std::vector<int64_t>({n, c, h, w}));
            x.Resize(in_dim);
            x_ref.Resize(in_dim);
            output.Resize(out_dim);
            output_ref.Resize(out_dim);

            param.x = &x;
            param.axis = axis;
            param.output = &output;
            param_ref.x = &x_ref;
            param_ref.axis = axis;
            param_ref.output = &output_ref;
            auto* x_ref_data = x_ref.mutable_data<float>();

            std::unique_ptr<KernelContext> context(new KernelContext);
            context->As<OpenCLContext>().InitOnce();
            kernel->SetParam(param);
            std::unique_ptr<KernelContext> softmax_context(new KernelContext);
            context->As<OpenCLContext>().CopySharedTo(
                &(softmax_context->As<OpenCLContext>()));
            kernel->SetContext(std::move(softmax_context));

            // Input is the ramp 0, 1, 2, ...; the same values feed both the
            // device tensor and the reference tensor.
            std::vector<float> input_v(n * c * h * w);
            int index = 0;
            for (auto& i : input_v) {
              x_ref_data[index] = index;
              i = index++;
            }
            VLOG(1) << "input_v ..... ";
            for (size_t i = 0; i < input_v.size(); i++) {
              VLOG(10) << input_v[i];
            }

            LOG(INFO) << "prepare input";
            // Stack-allocated so nothing leaks across the many loop
            // iterations, including when an ASSERT_* bails out early.
            CLImageConverterDefault default_converter;
            DDim x_image_shape = default_converter.InitImageDimInfoWith(
                DDim(std::vector<int64_t>({n, c, h, w})));
            LOG(INFO) << "x_image_shape = " << x_image_shape[0] << " "
                      << x_image_shape[1];
            std::vector<half_t> x_image_data(x_image_shape.production() *
                                             4);  // 4 : RGBA
            default_converter.NCHWToImage(
                input_v.data(), x_image_data.data(), in_dim);
            // Allocates the input image and uploads the converted data.
            x.mutable_data<half_t, cl::Image2D>(
                x_image_shape[0], x_image_shape[1], x_image_data.data());
            VLOG(1) << "x_image_data ..... ";
            for (size_t i = 0; i < x_image_data.size(); i++) {
              VLOG(10) << Half2Float(x_image_data[i]);
            }

            DDim out_image_shape =
                default_converter.InitImageDimInfoWith(out_dim);
            LOG(INFO) << "out_image_shape = " << out_image_shape[0] << " "
                      << out_image_shape[1];
            auto* out_image = output.mutable_data<half_t, cl::Image2D>(
                out_image_shape[0], out_image_shape[1]);

            // run
            kernel->Launch();
            CLRuntime::Global()->command_queue().finish();

            // handle output
            const size_t cl_image2d_row_pitch{0};
            const size_t cl_image2d_slice_pitch{0};
            // Owned by a vector (previously a leaked `new[]`).
            std::vector<half_t> out_image_data(out_image_shape.production() *
                                               4);
            TargetWrapperCL::ImgcpySync(out_image_data.data(),
                                        out_image,
                                        out_image_shape[0],
                                        out_image_shape[1],
                                        cl_image2d_row_pitch,
                                        cl_image2d_slice_pitch,
                                        IoDirection::DtoH);
            VLOG(1) << "out_image_data ..... ";
            for (size_t i = 0; i < out_image_data.size(); i++) {
              VLOG(10) << Half2Float(out_image_data[i]);
            }

            std::vector<float> out_data(out_image_shape.production() * 4);
            default_converter.ImageToNCHW(out_image_data.data(),
                                          out_data.data(),
                                          out_image_shape,
                                          out_dim);
            VLOG(1) << "out_data ..... ";
            for (int i = 0; i < out_dim.production(); i++) {
              VLOG(10) << out_data[i];
            }

            // Reference result and element-wise comparison.
            auto* output_ref_data = output_ref.mutable_data<float>();
            softmax_compute_ref<float>(param_ref);
            for (int i = 0; i < output.dims().production(); i++) {
              EXPECT_NEAR(out_data[i], output_ref_data[i], 1e-2);
            }
          }
        }
      }
    }
  }
}
} // namespace lite
} // namespace paddle
// References the kernel registered as (softmax, kOpenCL, kFP16, kImageDefault,
// image2d); presumably this keeps its registration linked into the test
// binary so KernelRegistry can create it — confirm against op_registry.h.
USE_LITE_KERNEL(softmax, kOpenCL, kFP16, kImageDefault, image2d);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册