Commit 76f29748 authored by 李滨

Merge branch 'reduce' into 'master'

opt: optimize the performance of `Reduce` OP

See merge request applied-machine-learning/sysml/mace!1311
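The change below replaces the old single-launch reduce kernel, which accumulated partial results in local memory and then had work item (0, 0) of each group finish the reduction serially, with a tiled multi-pass scheme: each work item reduces a roughly TILE_SIZE x TILE_SIZE (16x16) tile of input positions, and the host re-launches the same kernel on intermediate scratch images until the spatial size reaches 1x1. As a rough illustration of that pass structure only, here is a minimal standalone C++ sketch; round_up_div and the hard-coded 224x224 example are assumptions for the sketch, not code from this merge request.

#include <cstdint>
#include <iostream>

// Illustrative stand-in for MACE's RoundUpDiv; local to this sketch.
static int64_t round_up_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

int main() {
  const int64_t kTileSize = 16;   // matches TILE_SIZE in the new reduce.cc
  int64_t h = 224, w = 224;       // example input plane size (an assumption)
  int passes = 0;
  while (h > 1 || w > 1) {
    // A pass shrinks each dimension by roughly a factor of kTileSize, except
    // the last pass, which reduces whatever remains directly to 1x1.
    const bool last = (h <= kTileSize && w <= kTileSize);
    const int64_t oh = last ? 1 : round_up_div(h, kTileSize);
    const int64_t ow = last ? 1 : round_up_div(w, kTileSize);
    std::cout << h << "x" << w << " -> " << oh << "x" << ow << "\n";
    h = oh;
    w = ow;
    ++passes;
  }
  std::cout << "passes: " << passes << "\n";  // 224x224 -> 14x14 -> 1x1: 2 passes
  return 0;
}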
reduce.cl, before this change:

#include <common.h>

__kernel void reduce(OUT_OF_RANGE_PARAMS
                     GLOBAL_WORK_GROUP_SIZE_DIM3
                     __read_only image2d_t input,
                     __local float4 *local_buffer,
                     __private const int group_num,
                     __private const int compute_size,
                     __private const int last_index,
                     __private const int in_height,
                     __private const int in_width,
                     __private const float scale,
                     __private const int channel_blocks,
                     __write_only image2d_t output) {
  const int w = get_local_id(0);
  const int h = get_local_id(1);
  const int bc = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
  if (bc >= global_size_dim2)
    return;
#endif
  const int width = get_local_size(0);
  const int index = mad24(h, width, w);
  const int b = bc / channel_blocks;
  const int ch = mad24(b, -channel_blocks, bc);

  DATA_TYPE4 in;
#if REDUCE_TYPE == 1
  DATA_TYPE4 part_result = (DATA_TYPE4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT};
#elif REDUCE_TYPE == 2
  DATA_TYPE4 part_result = (DATA_TYPE4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT};
#elif REDUCE_TYPE == 3
  DATA_TYPE4 part_result = (DATA_TYPE4){1, 1, 1, 1};
#else
  DATA_TYPE4 part_result = (DATA_TYPE4){0, 0, 0, 0};
#endif
  const bool after_last = (last_index > 0 && index >= last_index);
  // After last index, each kernel only computes (compute_size - 1) elements.
  const int actual_compute_size = select(compute_size,
                                         compute_size - 1,
                                         after_last);
  const int base_offset = mul24(index, actual_compute_size);
  const int offset = select(base_offset,
                            base_offset + last_index,
                            after_last);
#pragma unroll
  for (int i = 0; i < actual_compute_size; ++i) {
    int element_idx = offset + i;
    int h_idx = element_idx / in_width;
    int w_idx = mad24(h_idx, -in_width, element_idx);
    int pos_x = mad24(ch, in_width, w_idx);
    int pos_y = mad24(b, in_height, h_idx);
    in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
    // MIN
#if REDUCE_TYPE == 1
    part_result = fmin(part_result, in);
    // MAX
#elif REDUCE_TYPE == 2
    part_result = fmax(part_result, in);
    // PROD
#elif REDUCE_TYPE == 3
    part_result = part_result * in;
    // MEAN or SUM
#else
    part_result = part_result + in;
#endif
  }

#if REDUCE_TYPE == 0
  part_result = part_result * scale;
#endif
  local_buffer[index] = part_result;
  barrier(CLK_LOCAL_MEM_FENCE);

  if (w == 0 && h == 0) {
#if REDUCE_TYPE == 1
    DATA_TYPE4 out = (DATA_TYPE4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT};
#elif REDUCE_TYPE == 2
    DATA_TYPE4 out = (DATA_TYPE4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT};
#elif REDUCE_TYPE == 3
    DATA_TYPE4 out = (DATA_TYPE4){1, 1, 1, 1};
#else
    DATA_TYPE4 out = (DATA_TYPE4){0, 0, 0, 0};
#endif
#pragma unroll
    for (int i = 0; i < group_num; ++i) {
#if REDUCE_TYPE == 1
      out = fmin(out, local_buffer[i]);
#elif REDUCE_TYPE == 2
      out = fmax(out, local_buffer[i]);
#elif REDUCE_TYPE == 3
      out = out * local_buffer[i];
#else
      out = out + local_buffer[i];
#endif
    }
    WRITE_IMAGET(output, (int2)(ch, b), out);
  }
}

reduce.cl, after this change:

#include <common.h>

#if REDUCE_TYPE == 1
#define INIT_REDUCE_VALUE (DATA_TYPE4){MAXFLOAT, MAXFLOAT, MAXFLOAT, MAXFLOAT}
#define REDUCE_VALUE(x, y) fmin(x, y)
#elif REDUCE_TYPE == 2  // MAX
#define INIT_REDUCE_VALUE (DATA_TYPE4){-MAXFLOAT, -MAXFLOAT, -MAXFLOAT, -MAXFLOAT}
#define REDUCE_VALUE(x, y) fmax(x, y)
#elif REDUCE_TYPE == 3  // PROD
#define INIT_REDUCE_VALUE (DATA_TYPE4){1, 1, 1, 1}
#define REDUCE_VALUE(x, y) (x * y)
#else  // MEAN or SUM
#define INIT_REDUCE_VALUE (DATA_TYPE4){0, 0, 0, 0}
#define REDUCE_VALUE(x, y) (x + y)
#endif

__kernel void reduce(OUT_OF_RANGE_PARAMS
                     GLOBAL_WORK_GROUP_SIZE_DIM3
                     __read_only image2d_t input,
                     __private const int out_height,
                     __private const int out_width,
                     __private const int in_height,
                     __private const int in_width,
                     __private const int org_height,
                     __private const int org_width,
                     __private const int channel_blocks,
                     __write_only image2d_t output) {
  const int ow = get_global_id(0);
  const int oh = get_global_id(1);
  const int bc = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
  if (bc >= global_size_dim2)
    return;
#endif
  const int b = bc / channel_blocks;
  const int c = bc % channel_blocks;
  const int tile_w = in_width / out_width;
  const int tile_h = in_height / out_height;
  const int start_w = tile_w * ow;
  const int start_h = tile_h * oh;
  const int size_w = select(tile_w, in_width - start_w, ow >= out_width - 1);
  const int size_h = select(tile_h, in_height - start_h, oh >= out_height - 1);
  const int end_h = start_h + size_h;
  const int end_w = start_w + size_w;

  DATA_TYPE4 in;
  DATA_TYPE4 out = INIT_REDUCE_VALUE;
#pragma unroll
  for (int h = start_h; h < end_h; ++h) {
    for (int w = start_w; w < end_w; ++w) {
      int pos_x = mad24(c, in_width, w);
      int pos_y = mad24(b, in_height, h);
      in = READ_IMAGET(input, SAMPLER, (int2)(pos_x, pos_y));
      out = REDUCE_VALUE(out, in);
    }
  }

#if REDUCE_TYPE == 0
  if (out_height == 1 && out_width == 1) {
    out = out / (org_height * org_width);
  }
#endif

  int pos_x = mad24(c, out_width, ow);
  int pos_y = mad24(b, out_height, oh);
  WRITE_IMAGET(output, (int2)(pos_x, pos_y), out);
}
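One detail worth noting in the new kernel: for MEAN (REDUCE_TYPE == 0), intermediate passes only accumulate sums, and the division by the original element count (org_height * org_width) happens only in the pass whose output is 1x1. The following host-side C++ sketch mirrors that behaviour with plain floats; reduce_value and the 4-element example are illustrative stand-ins, not part of the kernel.

#include <algorithm>
#include <cstdio>
#include <vector>

// Host-side reference of the per-element combine step, mirroring the
// REDUCE_VALUE macros: 0 = MEAN/SUM, 1 = MIN, 2 = MAX, 3 = PROD.
static float reduce_value(int reduce_type, float acc, float v) {
  switch (reduce_type) {
    case 1: return std::min(acc, v);
    case 2: return std::max(acc, v);
    case 3: return acc * v;
    default: return acc + v;
  }
}

int main() {
  // Two-pass MEAN over a 4-element row split into two 2-element tiles.
  // Intermediate passes accumulate sums; the division by the original element
  // count happens only once, in the final (1x1-output) pass.
  const std::vector<float> input = {1.f, 2.f, 3.f, 4.f};
  const float tile0 = reduce_value(0, reduce_value(0, 0.f, input[0]), input[1]);  // 3
  const float tile1 = reduce_value(0, reduce_value(0, 0.f, input[2]), input[3]);  // 7
  const float total = reduce_value(0, reduce_value(0, 0.f, tile0), tile1);        // 10
  const float mean = total / static_cast<float>(input.size());                    // 2.5
  std::printf("mean = %.2f\n", mean);
  return 0;
}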
@@ -15,124 +15,164 @@

reduce.cc, before this change:

#include "mace/ops/opencl/image/reduce.h"

#include <algorithm>

namespace mace {
namespace ops {
namespace opencl {
namespace image {

MaceStatus ReduceKernel::Compute(
    OpContext *context,
    const Tensor *input,
    Tensor *output) {
  MACE_CHECK_NOTNULL(input);
  index_t batch = input->dim(0);
  const index_t in_height = input->dim(1);
  const index_t in_width = input->dim(2);
  const index_t channels = input->dim(3);
  const index_t channel_blocks = RoundUpDiv4(channels);
  const uint32_t image_size = static_cast<uint32_t>(in_height * in_width);

  std::vector<uint32_t> gws(3);
  std::vector<uint32_t> lws(3);
  std::vector<index_t> output_shape{batch, 1, 1, channels};
  std::vector<size_t> output_image_shape;
  OpenCLUtil::CalImage2DShape(output_shape, OpenCLBufferType::IN_OUT_CHANNEL,
                              &output_image_shape);
  MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));

  auto runtime = context->device()->gpu_runtime()->opencl_runtime();
  MACE_OUT_OF_RANGE_DEFINITION;

  if (kernel_.get() == nullptr) {
    std::set<std::string> built_options;
    MACE_OUT_OF_RANGE_CONFIG;
    MACE_NON_UNIFORM_WG_CONFIG;
    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce");
    built_options.emplace("-Dreduce=" + kernel_name);
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
    built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
    built_options.emplace(MakeString("-DREDUCE_TYPE=", reduce_type_));
    if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) {
      built_options.emplace("-DNON_QUALCOMM_ADRENO");
    }
    MACE_RETURN_IF_ERROR(runtime->BuildKernel("reduce",
                                              kernel_name,
                                              built_options,
                                              &kernel_));
    kwg_size_ =
        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
  }

  // In the reduce.cl file, the computation is divided into two steps.
  // The first step computes `compute_size` times in parallel, and the second
  // step computes `group_num` times. In order to speed up the computation, we
  // make the computation times of these two steps as uniform as possible.
  uint32_t local_wg_size = static_cast<uint32_t>(sqrt(in_height * in_width));
  // Increase the times of the second step because it is not parallel.
  local_wg_size *= 2;
  local_wg_size = std::min(local_wg_size, kwg_size_);
  gws = {4, local_wg_size / 4, static_cast<uint32_t>(batch * channel_blocks)};
  if (gws[1] == 0) {
    gws[1] = 1;
  }
  lws = {gws[0], gws[1], 1};
  const int group_num = lws[0] * lws[1] * lws[2];
  // Each kernel intends to compute compute_size elements.
  const int compute_size = (image_size + group_num - 1) / group_num;
  const int last_index = image_size % group_num;
  const float scale = 1.f / (in_width * in_height);

  MACE_OUT_OF_RANGE_INIT(kernel_);
  if (!IsVecEqual(input_shape_, input->shape())) {
    uint32_t idx = 0;
    MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
    MACE_SET_3D_GWS_ARGS(kernel_, gws);
    kernel_.setArg(idx++, *(input->opencl_image()));
    kernel_.setArg(idx++, (group_num * 4 * sizeof(float)),
                   nullptr);
    kernel_.setArg(idx++, static_cast<int32_t>(group_num));
    kernel_.setArg(idx++, static_cast<int32_t>(compute_size));
    kernel_.setArg(idx++, static_cast<int32_t>(last_index));
    kernel_.setArg(idx++, static_cast<int32_t>(in_height));
    kernel_.setArg(idx++, static_cast<int32_t>(in_width));
    kernel_.setArg(idx++, scale);
    kernel_.setArg(idx++, static_cast<int32_t>(channel_blocks));
    kernel_.setArg(idx++, *(output->opencl_image()));

    input_shape_ = input->shape();
  }

  cl::Event event;
  cl_int error;
  if (runtime->IsNonUniformWorkgroupsSupported()) {
    error = runtime->command_queue().enqueueNDRangeKernel(
        kernel_, cl::NullRange, cl::NDRange(gws[0], gws[1], gws[2]),
        cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
  } else {
    std::vector<uint32_t> roundup_gws(lws.size());
    for (size_t i = 0; i < lws.size(); ++i) {
      roundup_gws[i] = RoundUp(gws[i], lws[i]);
    }
    error = runtime->command_queue().enqueueNDRangeKernel(
        kernel_, cl::NullRange,
        cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
        cl::NDRange(lws[0], lws[1], lws[2]), nullptr, &event);
  }
  MACE_CL_RET_STATUS(error);
  MACE_OUT_OF_RANGE_VALIDATION;

  if (context->future() != nullptr) {
    context->future()->wait_fn = [runtime, event](CallStats *stats) {
      event.wait();
      if (stats != nullptr) {
        runtime->GetCallStats(event, stats);
      }
    };
  }

  return MaceStatus::MACE_SUCCESS;
}

}  // namespace image

reduce.cc, after this change:

#include "mace/ops/opencl/image/reduce.h"

#include <algorithm>
#include <utility>
#include <vector>

namespace mace {
namespace ops {
namespace opencl {
namespace image {

namespace {
const index_t TILE_SIZE = 16;

cl::Image *InitScratchImageAndGetPointer(OpContext *context, DataType dtype,
                                         ScratchImage *scratch_image,
                                         const std::vector<index_t> &shape) {
  std::vector<size_t> image_shape;
  OpenCLUtil::CalImage2DShape(shape, OpenCLBufferType::IN_OUT_CHANNEL,
                              &image_shape);
  auto mace_image = scratch_image->Scratch(
      context->device()->allocator(), image_shape, dtype);
  cl::Image *image = static_cast<cl::Image *>(mace_image->buffer());
  return image;
}
}  // namespace

MaceStatus ReduceKernel::BuildReduceKernel(OpenCLRuntime *runtime) {
  std::set<std::string> built_options;
  MACE_OUT_OF_RANGE_CONFIG;
  MACE_NON_UNIFORM_WG_CONFIG;
  std::string kernel_name = MACE_OBFUSCATE_SYMBOL("reduce");
  built_options.emplace("-Dreduce=" + kernel_name);
  built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
  built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
  built_options.emplace(MakeString("-DREDUCE_TYPE=", reduce_type_));
  MACE_RETURN_IF_ERROR(runtime->BuildKernel(
      "reduce", kernel_name, built_options, &kernel_));
  kwg_size_ =
      static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));

  return MaceStatus::MACE_SUCCESS;
}

MaceStatus ReduceKernel::GraduallyComputeReduce(
    OpContext *context, const index_t batch, const index_t channel_blocks,
    const index_t in_height, const index_t in_width,
    const index_t out_height, const index_t out_width,
    const index_t org_height, const index_t org_width,
    const cl::Image *input, cl::Image *output) {
  MACE_OUT_OF_RANGE_DEFINITION;
  auto runtime = context->device()->gpu_runtime()->opencl_runtime();
  if (kernel_.get() == nullptr) {
    MACE_RETURN_IF_ERROR(BuildReduceKernel(runtime));
  }

  const uint32_t gws[3] = {static_cast<uint32_t>(out_width),
                           static_cast<uint32_t>(out_height),
                           static_cast<uint32_t>(batch * channel_blocks)};
  std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_);
  MACE_OUT_OF_RANGE_INIT(kernel_);
  uint32_t idx = 0;
  MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
  MACE_SET_3D_GWS_ARGS(kernel_, gws);
  kernel_.setArg(idx++, *input);
  kernel_.setArg(idx++, static_cast<int>(out_height));
  kernel_.setArg(idx++, static_cast<int>(out_width));
  kernel_.setArg(idx++, static_cast<int>(in_height));
  kernel_.setArg(idx++, static_cast<int>(in_width));
  kernel_.setArg(idx++, static_cast<int>(org_height));
  kernel_.setArg(idx++, static_cast<int>(org_width));
  kernel_.setArg(idx++, static_cast<int>(channel_blocks));
  kernel_.setArg(idx++, *output);

  std::string tuning_key = Concat(
      "reduce_opencl_kernel", gws[0], gws[1], gws[2]);
  MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_, tuning_key,
                                           gws, lws, context->future()));

  MACE_OUT_OF_RANGE_VALIDATION;
  return MaceStatus::MACE_SUCCESS;
}

MaceStatus ReduceKernel::Compute(
    OpContext *context,
    const Tensor *input,
    Tensor *output) {
  MACE_CHECK_NOTNULL(input);
  const index_t batch = input->dim(0);
  const index_t org_height = input->dim(1);
  const index_t org_width = input->dim(2);
  index_t in_height = org_height;
  index_t in_width = org_width;
  const index_t channels = input->dim(3);
  const index_t channel_blocks = RoundUpDiv4(channels);

  std::vector<index_t> output_shape{batch, 1, 1, channels};
  std::vector<size_t> output_image_shape;
  OpenCLUtil::CalImage2DShape(output_shape, OpenCLBufferType::IN_OUT_CHANNEL,
                              &output_image_shape);
  MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));

  MaceStatus result = MaceStatus::MACE_RUNTIME_ERROR;
  if (in_height <= TILE_SIZE && in_width <= TILE_SIZE) {
    result = GraduallyComputeReduce(context, batch, channel_blocks, in_height,
                                    in_width, 1, 1, org_height, org_width,
                                    input->opencl_image(),
                                    output->opencl_image());
  } else {
    ScratchImageManager *scratch_manager =
        context->device()->gpu_runtime()->scratch_image_manager();
    ScratchImage scratch_inter_image(scratch_manager);
    auto out_height = RoundUpDiv(in_height, TILE_SIZE);
    auto out_width = RoundUpDiv(in_width, TILE_SIZE);
    const std::vector<index_t> inter_shape =
        {{batch, out_height, out_width, channels}};
    cl::Image *inter_image = InitScratchImageAndGetPointer(
        context, input->dtype(), &scratch_inter_image, inter_shape);
    result = GraduallyComputeReduce(context, batch, channel_blocks, in_height,
                                    in_width, out_height, out_width,
                                    org_height, org_width,
                                    input->opencl_image(), inter_image);
    MACE_RETURN_IF_ERROR(result);

    in_height = out_height;
    in_width = out_width;
    out_height = RoundUpDiv(in_height, TILE_SIZE);
    out_width = RoundUpDiv(in_width, TILE_SIZE);

    if (in_height > TILE_SIZE || in_width > TILE_SIZE) {
      ScratchImage scratch_inter2_image(scratch_manager);
      const std::vector<index_t> inter2_shape =
          {{batch, out_height, out_width, channels}};
      cl::Image *inter2_image = InitScratchImageAndGetPointer(
          context, input->dtype(), &scratch_inter2_image, inter2_shape);

      while (out_height > 1 || out_width > 1) {
        result = GraduallyComputeReduce(context, batch, channel_blocks,
                                        in_height, in_width, out_height,
                                        out_width, org_height, org_width,
                                        inter_image, inter2_image);
        MACE_RETURN_IF_ERROR(result);
        in_height = out_height;
        in_width = out_width;
        out_height = RoundUpDiv(in_height, TILE_SIZE);
        out_width = RoundUpDiv(in_width, TILE_SIZE);
        std::swap(inter_image, inter2_image);
      }
    }

    result = GraduallyComputeReduce(context, batch, channel_blocks, in_height,
                                    in_width, 1, 1, org_height, org_width,
                                    inter_image, output->opencl_image());
  }

  return result;
}

}  // namespace image
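The new Compute() ping-pongs between two scratch images (inter_image and inter2_image, exchanged with std::swap after each pass), so GPU memory stays bounded by two intermediate images no matter how many passes the input needs. A minimal C++ sketch of the same double-buffering idea on plain vectors, with a made-up sum_tiles helper standing in for one kernel launch:

#include <cstdio>
#include <utility>
#include <vector>

// Made-up helper standing in for one reduce pass: sums fixed-size tiles.
static std::vector<float> sum_tiles(const std::vector<float> &in, size_t tile) {
  std::vector<float> out((in.size() + tile - 1) / tile, 0.f);
  for (size_t i = 0; i < in.size(); ++i) {
    out[i / tile] += in[i];
  }
  return out;
}

int main() {
  std::vector<float> src(100, 1.f);              // 100 ones; expected total 100
  std::vector<float> ping = sum_tiles(src, 16);  // first pass: 100 -> 7 partial sums
  std::vector<float> pong;
  while (ping.size() > 1) {
    pong = sum_tiles(ping, 16);  // reduce the previous partial results again
    std::swap(ping, pong);       // reuse the two buffers, like the image swap
  }
  std::printf("total = %.1f\n", ping[0]);  // prints 100.0
  return 0;
}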
@@ -42,6 +42,15 @@ class ReduceKernel : public OpenCLReduceKernel

reduce.h, after this change (the first private: block, declaring the two new helpers, is added by this commit):

      const Tensor *input,
      Tensor *output) override;

 private:
  MaceStatus BuildReduceKernel(OpenCLRuntime *runtime);
  MaceStatus GraduallyComputeReduce(
      OpContext *context, const index_t batch, const index_t channel_blocks,
      const index_t in_height, const index_t in_width,
      const index_t out_height, const index_t out_width,
      const index_t org_height, const index_t org_width,
      const cl::Image *input, cl::Image *output);

 private:
  ReduceType reduce_type_;
  const std::vector<int> axis_;