未验证 提交 89da9953 编写于 作者: X xiaogang 提交者: GitHub

【opencl】add dropout opencl kernel (#3141)

* feat: add dropout opencl kernel
上级 b3653b7e
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cl_common.h>
__kernel void dropout(__read_only image2d_t input_image,
__write_only image2d_t output_image,
__private const int out_W,
__private const float dropoutPro) {
const int out_c = get_global_id(0);
const int out_w = get_global_id(1);
const int out_nh = get_global_id(2);
int2 output_pos;
output_pos.x = out_c * out_W + out_w;
output_pos.y = out_nh;
const sampler_t sampler = CLK_NORMALIZED_COORDS_TRUE |
CLK_ADDRESS_CLAMP |
CLK_FILTER_NEAREST;
half4 input;
half4 output;
input = READ_IMG_TYPE(CL_DTYPE_CHAR, input_image, sampler,output_pos);
half4 dropout = (half4)(1 - dropoutPro);
output = dropout * input;
write_imageh(output_image, output_pos, output);
}
......@@ -31,6 +31,7 @@ add_kernel(lrn_opencl OPENCL basic SRCS lrn_image_compute.cc DEPS ${cl_kernel_de
add_kernel(bilinear_interp_opencl OPENCL basic SRCS bilinear_interp_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(slice_opencl OPENCL basic SRCS slice_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(instance_norm_opencl OPENCL basic SRCS instance_norm_image_compute.cc DEPS ${cl_kernel_deps})
add_kernel(dropout_opencl OPENCL basic SRCS dropout_image_compute.cc DEPS ${cl_kernel_deps})
# extra
# wait to add ...
......@@ -89,6 +90,9 @@ lite_cc_test(test_slice_image_opencl SRCS slice_image_compute_test.cc
lite_cc_test(test_instance_norm_image_opencl SRCS instance_norm_image_compute_test.cc
DEPS instance_norm_opencl op_registry program context)
lite_cc_test(test_dropout_image_opencl SRCS dropout_image_compute_test.cc
DEPS dropout_opencl op_registry program context)
######################
# buffer kernel #
######################
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/opencl/cl_half.h"
#include "lite/backends/opencl/cl_include.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/opencl/image_helper.h"
#include "lite/operators/op_params.h"
#include "lite/utils/replace_stl/stream.h"
#include "lite/utils/string.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace opencl {
class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault)> {
public:
using param_t = operators::DropoutParam;
std::string doc() const override {
return "Dropout using cl::Image2D, kFP16";
}
void PrepareForRun() override {
auto& context = ctx_->As<OpenCLContext>();
VLOG(1) << "kernel_func_name_:" << kernel_func_name_;
context.cl_context()->AddKernel(
kernel_func_name_, "image/dropout_kernel.cl", build_options_);
}
void Run() override {
const auto& param = *param_.get_mutable<param_t>();
const auto& in_dims = param.x->dims();
const auto& out_dims = param.output->dims();
auto* x_img = param.x->data<half_t, cl::Image2D>();
const float dropout_prob = param.dropout_prob;
int input_dims[4] = {1, 1, 1, 1};
for (int i = 0; i < in_dims.size(); i++) {
input_dims[4 - in_dims.size() + i] = in_dims[i];
}
int out_w = input_dims[3];
auto out_image_shape = InitImageDimInfoWith(out_dims);
auto* out_img = param.output->mutable_data<half_t, cl::Image2D>(
out_image_shape["width"], out_image_shape["height"]);
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
STL::stringstream kernel_key;
kernel_key << kernel_func_name_ << build_options_;
auto kernel = context.cl_context()->GetKernel(kernel_key.str());
cl_int status;
int arg_idx = 0;
status = kernel.setArg(arg_idx, *x_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, *out_img);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, out_w);
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, dropout_prob);
CL_CHECK_FATAL(status);
const std::vector<size_t>& default_work_size =
DefaultWorkSize(out_dims,
DDim(std::vector<DDim::value_type>{
static_cast<int64_t>(out_image_shape["width"]),
static_cast<int64_t>(out_image_shape["height"])}));
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(default_work_size.data()[0]),
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
cl::NullRange,
nullptr,
event_.get());
CL_CHECK_FATAL(status);
context.cl_wait_list()->emplace(out_img, event_);
}
private:
std::string kernel_func_name_{"dropout"};
std::string build_options_{"-DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
};
} // namespace opencl
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(dropout,
kOpenCL,
kFP16,
kImageDefault,
paddle::lite::kernels::opencl::DropoutComputeImage2D,
image2d)
.BindInput("X",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kOpenCL),
PRECISION(kFP16),
DATALAYOUT(kImageDefault))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include "lite/backends/opencl/target_wrapper.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
#include "lite/kernels/opencl/test_helper.h"
#define FP16_MAX_DIFF (5e-1)
namespace paddle {
namespace lite {
void dropout(const float* input_data,
const DDim& in_dim,
float* output_data,
const float prob) {
for (int i = 0; i < in_dim.production(); i++) {
output_data[i] = input_data[i] * (1 - prob);
}
}
TEST(dropout_image2d_fp16, compute) {
LOG(INFO) << "to get kernel ...";
auto kernels = KernelRegistry::Global().Create(
"dropout", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault));
ASSERT_FALSE(kernels.empty());
auto kernel = std::move(kernels.front());
LOG(INFO) << "get kernel:" << kernel->doc();
lite::Tensor x, out;
operators::DropoutParam param;
param.x = &x;
param.output = &out;
param.dropout_prob = 0.6;
std::unique_ptr<KernelContext> context(new KernelContext);
context->As<OpenCLContext>().InitOnce();
kernel->SetParam(param);
std::unique_ptr<KernelContext> dropout_context(new KernelContext);
context->As<OpenCLContext>().CopySharedTo(
&(dropout_context->As<OpenCLContext>()));
kernel->SetContext(std::move(dropout_context));
const DDim in_dim = DDim(std::vector<DDim::value_type>{4, 11, 107, 107});
const DDim out_dim = DDim(std::vector<DDim::value_type>{4, 11, 107, 107});
x.Resize(in_dim);
out.Resize(out_dim);
std::default_random_engine engine;
std::uniform_real_distribution<float> dist(-5, 5);
std::vector<float> input_v(4 * 11 * 107 * 107);
for (auto& i : input_v) {
i = dist(engine);
}
LOG(INFO) << "prepare input";
CLImageConverterDefault* default_converter = new CLImageConverterDefault();
DDim image_shape = default_converter->InitImageDimInfoWith(in_dim);
LOG(INFO) << "image_shape = " << image_shape[0] << " " << image_shape[1];
std::vector<half_t> x_image_data(image_shape.production() * 4); // 4 : RGBA
default_converter->NCHWToImage(input_v.data(), x_image_data.data(), in_dim);
auto* x_image = x.mutable_data<half_t, cl::Image2D>(
image_shape[0], image_shape[1], x_image_data.data());
LOG(INFO) << "x_image:" << x_image;
auto* out_image =
out.mutable_data<half_t, cl::Image2D>(image_shape[0], image_shape[1]);
LOG(INFO) << "out_image:" << out_image;
kernel->Launch();
auto* wait_list = context->As<OpenCLContext>().cl_wait_list();
auto* out_ptr = param.output->data<half_t, cl::Image2D>();
auto it = wait_list->find(out_ptr);
if (it != wait_list->end()) {
VLOG(4) << "--- Find the sync event for the target cl tensor. ---";
auto& event = *(it->second);
event.wait();
} else {
LOG(FATAL) << "Could not find the sync event for the target cl tensor.";
}
std::unique_ptr<float[]> out_ref(new float[out_dim.production()]);
dropout(input_v.data(), in_dim, out_ref.get(), 0.6);
const size_t cl_image2d_row_pitch{0};
const size_t cl_image2d_slice_pitch{0};
half_t* out_image_data = new half_t[image_shape.production() * 4];
TargetWrapperCL::ImgcpySync(out_image_data,
out_image,
image_shape[0],
image_shape[1],
cl_image2d_row_pitch,
cl_image2d_slice_pitch,
IoDirection::DtoH);
float* out_data = new float[image_shape.production() * 4];
default_converter->ImageToNCHW(
out_image_data, out_data, image_shape, out_dim);
for (int i = 0; i < out_dim.production(); i++) {
auto abs_diff = abs(out_data[i] - out_ref[i]);
auto relative_diff = COMPUTE_RELATIVE_DIFF(out_data[i], out_ref[i]);
EXPECT_EQ((relative_diff <= FP16_MAX_DIFF) || (abs_diff <= FP16_MAX_DIFF),
true);
if ((relative_diff > FP16_MAX_DIFF) && (abs_diff > FP16_MAX_DIFF)) {
LOG(ERROR) << "error idx:" << i << " out_data[" << i
<< "]:" << out_data[i] << " "
"out_ref["
<< i << "]:" << out_ref[i] << " abs_diff:" << abs_diff
<< " relative_diff:" << relative_diff
<< " FP16_MAX_DIFF:" << FP16_MAX_DIFF;
}
}
}
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(dropout, kOpenCL, kFP16, kImageDefault, image2d);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册