Commit fdf938ce authored by Liangliang He

Merge branch 'fc-mali' into 'master'

FC and CWise ops support OpenCL 1.1/1.2.

See merge request !343
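Note for reviewers: OpenCL 1.1/1.2 requires the enqueued global work size in each dimension to be a multiple of the local work size, while OpenCL 2.0 allows non-uniform work-groups. The pattern this MR applies to the FC and CWise paths is: when IsNonUniformWorkgroupsSupported() is false, the host rounds the global size up with RoundUp(gws[i], lws[i]), passes the true bounds as extra __private kernel arguments via the GLOBAL_WORK_GROUP_SIZE_DIM* macros, and each work-item bounds-checks its ids. A minimal kernel-side sketch, assembled from the common.h macros and the cwise guard in this diff (the example_copy kernel itself is illustrative, not part of the MR):

#ifndef NON_UNIFORM_WORK_GROUP
// OpenCL 1.1/1.2 path: gws was rounded up on the host, so the true
// global sizes arrive as extra kernel arguments.
#define GLOBAL_WORK_GROUP_SIZE_DIM2 \
    __private const int global_size_dim0, \
    __private const int global_size_dim1,
#else
// OpenCL 2.0 path: non-uniform work-groups, no extra arguments.
#define GLOBAL_WORK_GROUP_SIZE_DIM2
#endif

__constant sampler_t kSampler =
    CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;

// Illustrative 2D copy kernel using the same guard as cwise below.
__kernel void example_copy(GLOBAL_WORK_GROUP_SIZE_DIM2
                           __read_only image2d_t input,
                           __write_only image2d_t output) {
  const int w = get_global_id(0);
  const int hb = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
  // Work-items in the rounded-up padding region must exit before
  // touching memory.
  if (w >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
  float4 v = read_imagef(input, kSampler, (int2)(w, hb));
  write_imagef(output, (int2)(w, hb), v);
}

The host side mirrors the same condition in FCWXKernel and the CWise functor below: enqueue the exact gws when non-uniform work-groups are available, otherwise enqueue the rounded-up sizes and rely on the in-kernel guard.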
......@@ -12,6 +12,7 @@
#include "gflags/gflags.h"
#include "mace/public/mace.h"
#include "mace/public/mace_runtime.h"
#include "mace/utils/logging.h"
#include "mace/benchmark/stat_summarizer.h"
......@@ -95,9 +96,23 @@ inline int64_t NowMicros() {
return static_cast<int64_t>(tv.tv_sec) * 1000000 + tv.tv_usec;
}
+DeviceType ParseDeviceType(const std::string &device_str) {
+  if (device_str.compare("CPU") == 0) {
+    return DeviceType::CPU;
+  } else if (device_str.compare("NEON") == 0) {
+    return DeviceType::NEON;
+  } else if (device_str.compare("OPENCL") == 0) {
+    return DeviceType::OPENCL;
+  } else if (device_str.compare("HEXAGON") == 0) {
+    return DeviceType::HEXAGON;
+  } else {
+    return DeviceType::CPU;
+  }
+}
bool RunInference(MaceEngine *engine,
-                  const std::vector<mace::MaceInputInfo> &input_infos,
-                  std::map<std::string, float*> *output_infos,
+                  const std::map<std::string, mace::MaceTensor> &input_infos,
+                  std::map<std::string, mace::MaceTensor> *output_infos,
StatSummarizer *summarizer,
int64_t *inference_time_us) {
MACE_CHECK_NOTNULL(output_infos);
......@@ -106,28 +121,16 @@ bool RunInference(MaceEngine *engine,
if (summarizer) {
run_metadata_ptr = &run_metadata;
}
-  if (input_infos.size() == 1 && output_infos->size() == 1) {
-    const int64_t start_time = NowMicros();
-    bool s = engine->Run(input_infos[0].data, input_infos[0].shape,
-                         output_infos->begin()->second, run_metadata_ptr);
-    const int64_t end_time = NowMicros();
-    if (!s) {
-      LOG(ERROR) << "Error during inference.";
-      return s;
-    }
-    *inference_time_us = end_time - start_time;
-  } else {
-    const int64_t start_time = NowMicros();
-    bool s = engine->Run(input_infos, *output_infos, run_metadata_ptr);
-    const int64_t end_time = NowMicros();
-    if (!s) {
-      LOG(ERROR) << "Error during inference.";
-      return s;
-    }
-    *inference_time_us = end_time - start_time;
-  }
+  const int64_t start_time = NowMicros();
+  mace::MaceStatus s = engine->Run(input_infos, output_infos, run_metadata_ptr);
+  const int64_t end_time = NowMicros();
+
+  if (s != mace::MaceStatus::MACE_SUCCESS) {
+    LOG(ERROR) << "Error during inference.";
+    return false;
+  }
+  *inference_time_us = end_time - start_time;
if (summarizer != nullptr) {
summarizer->ProcessMetadata(run_metadata);
......@@ -137,8 +140,8 @@ bool RunInference(MaceEngine *engine,
}
bool Run(MaceEngine *engine,
-         const std::vector<mace::MaceInputInfo> &input_infos,
-         std::map<std::string, float*> *output_infos,
+         const std::map<std::string, mace::MaceTensor> &input_infos,
+         std::map<std::string, mace::MaceTensor> *output_infos,
StatSummarizer *summarizer,
int num_runs,
double max_time_sec,
......@@ -261,12 +264,7 @@ int Main(int argc, char **argv) {
stats_options.show_summary = FLAGS_show_summary;
stats.reset(new StatSummarizer(stats_options));
-  DeviceType device_type = CPU;
-  if (FLAGS_device == "OPENCL") {
-    device_type = OPENCL;
-  } else if (FLAGS_device == "NEON") {
-    device_type = NEON;
-  }
+  mace::DeviceType device_type = ParseDeviceType(FLAGS_device);
// config runtime
mace::ConfigOmpThreads(FLAGS_omp_num_threads);
......@@ -302,50 +300,45 @@ int Main(int argc, char **argv) {
mace::MACE_MODEL_TAG::LoadModelData(FLAGS_model_data_file.c_str());
NetDef net_def = mace::MACE_MODEL_TAG::CreateNet(model_data);
-  std::vector<mace::MaceInputInfo> input_infos(input_count);
-  std::map<std::string, float*> output_infos;
-  std::vector<std::unique_ptr<float[]>> input_datas(input_count);
-  std::vector<std::unique_ptr<float[]>> output_datas(output_count);
+  std::map<std::string, mace::MaceTensor> inputs;
+  std::map<std::string, mace::MaceTensor> outputs;
for (size_t i = 0; i < input_count; ++i) {
-    int64_t input_size = std::accumulate(input_shape_vec[i].begin(),
-                                         input_shape_vec[i].end(), 1,
-                                         std::multiplies<int64_t>());
-    input_datas[i].reset(new float[input_size]);
+    // Allocate input and output
+    int64_t input_size =
+        std::accumulate(input_shape_vec[i].begin(), input_shape_vec[i].end(), 1,
+                        std::multiplies<int64_t>());
+    auto buffer_in = std::shared_ptr<float>(new float[input_size],
+                                            std::default_delete<float[]>());
// load input
std::ifstream in_file(FLAGS_input_file + "_" + FormatName(input_names[i]),
std::ios::in | std::ios::binary);
if (in_file.is_open()) {
-      in_file.read(reinterpret_cast<char *>(input_datas[i].get()),
+      in_file.read(reinterpret_cast<char *>(buffer_in.get()),
input_size * sizeof(float));
in_file.close();
} else {
LOG(INFO) << "Open input file failed";
return -1;
}
-    input_infos[i].name = input_names[i];
-    input_infos[i].shape = input_shape_vec[i];
-    input_infos[i].data = input_datas[i].get();
+    inputs[input_names[i]] = mace::MaceTensor(input_shape_vec[i], buffer_in);
}
for (size_t i = 0; i < output_count; ++i) {
-    int64_t output_size = std::accumulate(output_shape_vec[i].begin(),
-                                          output_shape_vec[i].end(), 1,
-                                          std::multiplies<int64_t>());
-    output_datas[i].reset(new float[output_size]);
-    output_infos[output_names[i]] = output_datas[i].get();
+    int64_t output_size =
+        std::accumulate(output_shape_vec[i].begin(),
+                        output_shape_vec[i].end(), 1,
+                        std::multiplies<int64_t>());
+    auto buffer_out = std::shared_ptr<float>(new float[output_size],
+                                             std::default_delete<float[]>());
+    outputs[output_names[i]] = mace::MaceTensor(output_shape_vec[i],
+                                                buffer_out);
}
// Init model
LOG(INFO) << "Run init";
-  std::unique_ptr<mace::MaceEngine> engine_ptr;
-  if (input_count == 1 && output_count == 1) {
-    engine_ptr.reset(new mace::MaceEngine(&net_def, device_type));
-  } else {
-    engine_ptr.reset(new mace::MaceEngine(&net_def, device_type,
-                                          input_names, output_names));
-  }
-  if (device_type == DeviceType::OPENCL) {
+  std::unique_ptr<mace::MaceEngine> engine_ptr(
+      new mace::MaceEngine(&net_def, device_type, input_names, output_names));
+  if (device_type == DeviceType::OPENCL || device_type == DeviceType::HEXAGON) {
mace::MACE_MODEL_TAG::UnloadModelData(model_data);
}
......@@ -355,7 +348,7 @@ int Main(int argc, char **argv) {
int64_t num_warmup_runs = 0;
if (FLAGS_warmup_runs > 0) {
bool status =
-        Run(engine_ptr.get(), input_infos, &output_infos, nullptr,
+        Run(engine_ptr.get(), inputs, &outputs, nullptr,
FLAGS_warmup_runs, -1.0,
inter_inference_sleep_seconds, &warmup_time_us, &num_warmup_runs);
if (!status) {
......@@ -370,7 +363,7 @@ int Main(int argc, char **argv) {
int64_t no_stat_time_us = 0;
int64_t no_stat_runs = 0;
bool status =
-        Run(engine_ptr.get(), input_infos, &output_infos,
+        Run(engine_ptr.get(), inputs, &outputs,
nullptr, FLAGS_max_num_runs, max_benchmark_time_seconds,
inter_inference_sleep_seconds, &no_stat_time_us, &no_stat_runs);
if (!status) {
......@@ -379,7 +372,7 @@ int Main(int argc, char **argv) {
int64_t stat_time_us = 0;
int64_t stat_runs = 0;
-  status = Run(engine_ptr.get(), input_infos, &output_infos,
+  status = Run(engine_ptr.get(), inputs, &outputs,
stats.get(), FLAGS_max_num_runs, max_benchmark_time_seconds,
inter_inference_sleep_seconds, &stat_time_us, &stat_runs);
if (!status) {
......
......@@ -480,12 +480,12 @@ uint64_t OpenCLRuntime::GetKernelWaveSize(const cl::Kernel &kernel) {
}
const bool OpenCLRuntime::IsNonUniformWorkgroupsSupported() {
-  if (gpu_type_ == GPUType::QUALCOMM_ADRENO &&
-      opencl_version_ == "2.0") {
-    return true;
-  } else {
-    return false;
-  }
+  return (gpu_type_ == GPUType::QUALCOMM_ADRENO &&
+          opencl_version_ == "2.0");
}
+const GPUType OpenCLRuntime::gpu_type() const {
+  return gpu_type_;
+}
const GPUType OpenCLRuntime::ParseGPUTypeFromDeviceName(
......
......@@ -66,6 +66,7 @@ class OpenCLRuntime {
uint64_t GetKernelWaveSize(const cl::Kernel &kernel);
const bool IsNonUniformWorkgroupsSupported();
const GPUType ParseGPUTypeFromDeviceName(const std::string &device_name);
+  const GPUType gpu_type() const;
cl::Kernel BuildKernel(const std::string &program_name,
const std::string &kernel_name,
const std::set<std::string> &build_options);
......
......@@ -114,6 +114,7 @@ struct CWiseFunctor<DeviceType::OPENCL, T> : CWiseFunctorBase {
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
+  std::vector<index_t> input_shape_;
};
......
#include <common.h>
-__kernel void activation(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void activation(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
#ifdef USE_PRELU
__read_only image2d_t alpha,
......
#include <common.h>
-__kernel void addn(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void addn(GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t input1,
#if INPUT_NUM > 2
......
#include <common.h>
// Supported data types: half/float
-__kernel void batch_norm(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void batch_norm(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t scale,
__read_only image2d_t offset,
......
#include <common.h>
// Supported data types: half/float
-__kernel void bias_add(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void bias_add(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__read_only image2d_t bias,
__write_only image2d_t output) {
......
#include <common.h>
-__kernel void filter_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void filter_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* h, w, oc, ic */
__private const int input_offset,
__private const int filter_h,
......@@ -53,8 +52,7 @@ __kernel void filter_buffer_to_image(
WRITE_IMAGET(output, coord, values);
}
-__kernel void filter_image_to_buffer(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void filter_image_to_buffer(GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, /* h, w, oc, ic */
__private const int filter_h,
__private const int filter_w,
......@@ -102,8 +100,7 @@ __kernel void filter_image_to_buffer(
}
}
-__kernel void dw_filter_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void dw_filter_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* h, w, ic, m */
__private const int input_offset,
__private const int filter_w,
......@@ -160,8 +157,7 @@ __kernel void dw_filter_buffer_to_image(
WRITE_IMAGET(output, coord, values);
}
-__kernel void in_out_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void in_out_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* nhwc */
__private const int input_offset,
__private const int height,
......@@ -202,8 +198,7 @@ __kernel void in_out_buffer_to_image(
WRITE_IMAGET(output, coord, values);
}
-__kernel void in_out_image_to_buffer(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void in_out_image_to_buffer(GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, /* nhwc */
__private const int height,
__private const int width,
......@@ -242,8 +237,7 @@ __kernel void in_out_image_to_buffer(
}
}
-__kernel void arg_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void arg_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* nhwc */
__private const int input_offset,
__private const int count,
......@@ -278,8 +272,7 @@ __kernel void arg_buffer_to_image(
WRITE_IMAGET(output, coord, values);
}
-__kernel void arg_image_to_buffer(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void arg_image_to_buffer(GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, /* nhwc */
__private const int count,
__read_only image2d_t input) {
......@@ -312,8 +305,7 @@ __kernel void arg_image_to_buffer(
}
-__kernel void in_out_height_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void in_out_height_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, //nhwc
__private const int input_offset,
__private const int height,
......@@ -355,8 +347,7 @@ __kernel void in_out_height_buffer_to_image(
WRITE_IMAGET(output, coord, values);
}
-__kernel void in_out_height_image_to_buffer(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void in_out_height_image_to_buffer(GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, //nhwc
__private const int height,
__private const int width,
......@@ -394,8 +385,7 @@ __kernel void in_out_height_image_to_buffer(
}
-__kernel void in_out_width_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void in_out_width_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* nhwc */
__private const int input_offset,
__private const int height,
......@@ -437,8 +427,7 @@ __kernel void in_out_width_buffer_to_image(
}
// only support 3x3 now
-__kernel void winograd_filter_buffer_to_image(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void winograd_filter_buffer_to_image(GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, //Oc, Ic, H, W
__private const int input_offset,
__private const int in_channels,
......@@ -529,8 +518,7 @@ __kernel void winograd_filter_buffer_to_image(
}
// only support 3x3 now
-__kernel void winograd_filter_image_to_buffer(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void winograd_filter_image_to_buffer(GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, //Oc, Ic, H, W
__private const int height,
__private const int width,
......
#include <common.h>
// assume channels_per_group mod 4 == 0 && groups mod 4 == 0
-__kernel void channel_shuffle(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void channel_shuffle(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int groups,
__private const int channels_per_group,
......
......@@ -19,18 +19,18 @@
#ifndef NON_UNIFORM_WORK_GROUP
-#define UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2 \
+#define GLOBAL_WORK_GROUP_SIZE_DIM2 \
     __private const int global_size_dim0, \
     __private const int global_size_dim1,
-#define UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3 \
+#define GLOBAL_WORK_GROUP_SIZE_DIM3 \
     __private const int global_size_dim0, \
     __private const int global_size_dim1, \
     __private const int global_size_dim2,
 #else
-#define UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
-#define UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+#define GLOBAL_WORK_GROUP_SIZE_DIM2
+#define GLOBAL_WORK_GROUP_SIZE_DIM3
#endif
......
......@@ -22,8 +22,7 @@ DATA_TYPE4 stitch_vector(DATA_TYPE4 left,
}
// Supported data type: half/float
-__kernel void concat_channel(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void concat_channel(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input0,
__read_only image2d_t input1,
__private const int input0_chan,
......@@ -84,8 +83,7 @@ __kernel void concat_channel(
}
// Required: All input channels are divisible by 4
-__kernel void concat_channel_multi(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void concat_channel_multi(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int chan_blk_offset,
__write_only image2d_t output) {
......
#include <common.h>
-__kernel void conv_2d(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void conv_2d(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin, kh * kw * cout/4 */
#ifdef BIAS
......
#include <common.h>
-__kernel void conv_2d_1x1(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void conv_2d_1x1(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin, cout/4 */
#ifdef BIAS
......
#include <common.h>
-__kernel void conv_2d_3x3(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void conv_2d_3x3(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * cin , kh * kw * cout/4 */
#ifdef BIAS
......
#include <common.h>
-__kernel void cwise(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
-                    __private const float value,
-                    __write_only image2d_t output) {
+__kernel void cwise(GLOBAL_WORK_GROUP_SIZE_DIM2
+                    __read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
+                    __private const float value,
+                    __write_only image2d_t output) {
const int w = get_global_id(0);
const int hb = get_global_id(1);
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (w >= global_size_dim0 || hb >= global_size_dim1) return;
+#endif
DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value};
DATA_TYPE4 out;
......
#include <common.h>
-__kernel void depth_to_space(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void depth_to_space(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int block_size,
__private const int input_height,
......@@ -36,7 +35,7 @@ __kernel void depth_to_space(
}
__kernel void space_to_depth(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+    GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int block_size,
__private const int input_height,
......
#include <common.h>
// Only multiplier = 1 is supported
-__kernel void depthwise_conv2d(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void depthwise_conv2d(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */
#ifdef BIAS
......@@ -138,8 +137,7 @@ __kernel void depthwise_conv2d(
WRITE_IMAGET(output, (int2)(out_x_base + w, out_hb), out3);
}
-__kernel void depthwise_conv2d_s1(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void depthwise_conv2d_s1(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t filter, /* cout%4 * kh * kw * m, cin/4 */
#ifdef BIAS
......
#include <common.h>
-__kernel void eltwise(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void eltwise(GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input0, /* [c%4 * w * c/4, h * b] */
__read_only image2d_t input1,
#ifdef COEFF_SUM
......
#include <common.h>
// output = weight * input + bias
-__kernel void fully_connected(__read_only image2d_t input,
+__kernel void fully_connected(GLOBAL_WORK_GROUP_SIZE_DIM2
+                              __read_only image2d_t input,
__read_only image2d_t weight,
#ifdef BIAS
__read_only image2d_t bias,
......@@ -15,6 +16,10 @@ __kernel void fully_connected(__read_only image2d_t input,
const int out_blk_idx = get_global_id(1);
const int input_chan_blk = (input_channel + 3) >> 2;
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (batch_idx >= global_size_dim0 || out_blk_idx >= global_size_dim1) return;
+#endif
float4 input_value;
float4 w0, w1, w2, w3;
......@@ -57,7 +62,8 @@ __kernel void fully_connected(__read_only image2d_t input,
}
// output = weight * input + bias
-__kernel void fully_connected_width(__read_only image2d_t input,
+__kernel void fully_connected_width(GLOBAL_WORK_GROUP_SIZE_DIM3
+                                    __read_only image2d_t input,
__read_only image2d_t weight,
#ifdef BIAS
__read_only image2d_t bias,
......@@ -73,6 +79,7 @@ __kernel void fully_connected_width(__read_only image2d_t input,
const int width_blk_idx = get_global_id(1);
const int width_blk_count = get_global_size(1);
const int batch_out_blk_idx = get_global_id(2);
const int batch_idx = batch_out_blk_idx / out_blks;
const int out_blk_idx = batch_out_blk_idx % out_blks;
......@@ -115,6 +122,16 @@ __kernel void fully_connected_width(__read_only image2d_t input,
short inter_idx = mad24((short)get_local_id(2), local_size, inter_out_offset);
intermediate_output[inter_idx] = sum;
+#ifdef NON_QUALCOMM_ADRENO
+  barrier(CLK_LOCAL_MEM_FENCE);
+#endif
+
+#ifndef NON_UNIFORM_WORK_GROUP
+  if (batch_out_blk_idx >= global_size_dim2) {
+    return;
+  }
+#endif
if (inter_out_offset == 0) {
#ifdef BIAS
DATA_TYPE4 result = READ_IMAGET(bias, SAMPLER, (int2)(out_blk_idx, 0));
......@@ -122,7 +139,7 @@ __kernel void fully_connected_width(__read_only image2d_t input,
DATA_TYPE4 result = (DATA_TYPE4)(0, 0, 0, 0);
#endif
-    for(short i = 0; i < local_width_blk_size; ++i) {
+    for (short i = 0; i < local_width_blk_size; ++i) {
result += vload4(0, intermediate_output+inter_idx);
inter_idx += 4;
}
......
#include <common.h>
// C = A * B
-__kernel void matmul(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void matmul(GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t A,
__read_only image2d_t B,
__write_only image2d_t C,
......
......@@ -19,8 +19,7 @@ inline int calculate_avg_block_size(const int pool_size,
}
// Supported data type: half/float
-__kernel void pooling(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void pooling(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int in_height,
__private const int in_width,
......
#include <common.h>
-__kernel void resize_bilinear_nocache(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void resize_bilinear_nocache(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__write_only image2d_t output,
__private const float height_scale,
......
#include <common.h>
-__kernel void slice(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void slice(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int chan_blk_offset,
__write_only image2d_t output) {
......
#include <common.h>
-__kernel void softmax(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void softmax(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int channels,
__private const int remain_channels,
......
#include <common.h>
-__kernel void space_to_batch(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void space_to_batch(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t space_data,
__write_only image2d_t batch_data,
__private const int block_height,
......@@ -48,8 +47,7 @@ __kernel void space_to_batch(
WRITE_IMAGET(batch_data, batch_coord, value);
}
-__kernel void batch_to_space(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_3
+__kernel void batch_to_space(GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t batch_data,
__write_only image2d_t space_data,
__private const int block_height,
......
#include <common.h>
-__kernel void winograd_transform_2x2(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void winograd_transform_2x2(GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input,
__write_only image2d_t output,
__private const int in_height,
......@@ -116,8 +115,7 @@ __kernel void winograd_transform_2x2(
}
}
-__kernel void winograd_inverse_transform_2x2(
-    UNIFORM_WORK_GROUP_SIZE_PARAMS_IN_DIM_2
+__kernel void winograd_inverse_transform_2x2(GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input,
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
......
......@@ -23,8 +23,10 @@ void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height;
+  auto runtime = OpenCLRuntime::Global();
+  const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
+                           static_cast<uint32_t>(batch_height_pixels)};
  if (kernel_.get() == nullptr) {
-    auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise");
......@@ -32,19 +34,27 @@ void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DCWISE_TYPE=", type_));
+    if (runtime->IsNonUniformWorkgroupsSupported()) {
+      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
+    }
kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options);
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
+  if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
+    if (!runtime->IsNonUniformWorkgroupsSupported()) {
+      kernel_.setArg(idx++, gws[0]);
+      kernel_.setArg(idx++, gws[1]);
+    }
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<float>(coeff_));
kernel_.setArg(idx++, *(output->opencl_image()));
+    input_shape_ = input->shape();
+  }
-  const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
-                           static_cast<uint32_t>(batch_height_pixels)};
-  const std::vector<uint32_t> lws = {64, 16, 1};
+  const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 1};
std::stringstream ss;
ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
......
......@@ -27,6 +27,10 @@ void FCWXKernel(cl::Kernel *kernel,
auto runtime = OpenCLRuntime::Global();
if (kernel->get() == nullptr) {
+    const index_t batch = output->dim(0);
+    const index_t output_size = output->dim(3);
+    const index_t output_blocks = RoundUpDiv4(output_size);
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected");
......@@ -55,28 +59,47 @@ void FCWXKernel(cl::Kernel *kernel,
default:
LOG(FATAL) << "Unknown activation type: " << activation;
}
+    if (runtime->gpu_type() != GPUType::QUALCOMM_ADRENO) {
+      built_options.emplace("-DNON_QUALCOMM_ADRENO");
+    }
+    if (runtime->IsNonUniformWorkgroupsSupported()) {
+      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
+    }
*kernel =
runtime->BuildKernel("fully_connected", kernel_name, built_options);
-    built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
-    const index_t batch = output->dim(0);
-    const index_t output_size = output->dim(3);
-    const index_t output_blocks = RoundUpDiv4(output_size);
-    const uint32_t wave_size =
-        static_cast<uint32_t>(runtime->GetKernelWaveSize(*kernel));
-    *gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * output_blocks)};
-    const uint32_t kwg_size =
-        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
-    const uint32_t inter_local_blks = kwg_size / ((*gws)[0] * (*gws)[1]);
-    *lws = {(*gws)[0], (*gws)[1], inter_local_blks};
+    if (runtime->gpu_type() == GPUType::QUALCOMM_ADRENO) {
+      const uint32_t wave_size =
+          static_cast<uint32_t>(runtime->GetKernelWaveSize(*kernel));
+      *gws = {4, (wave_size / 4), static_cast<uint32_t>(batch * output_blocks)};
+      const uint32_t kwg_size =
+          static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
+      const uint32_t inter_local_blks = kwg_size / ((*gws)[0] * (*gws)[1]);
+      *lws = {(*gws)[0], (*gws)[1], inter_local_blks};
+    } else {
+      *gws = {4, 8, static_cast<uint32_t>(batch * output_blocks)};
+      const uint32_t kwg_size =
+          static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
+      const uint32_t inter_local_blks = kwg_size / ((*gws)[0] * (*gws)[1]);
+      *lws = {(*gws)[0], (*gws)[1], inter_local_blks};
+    }
}
+  if (!IsVecEqual(*prev_input_shape, input->shape())) {
+    const index_t batch = output->dim(0);
+    const index_t output_blocks = RoundUpDiv4(output->dim(3));
+    (*gws)[2] = static_cast<uint32_t>(batch * output_blocks);
+
    uint32_t idx = 0;
+    if (!runtime->IsNonUniformWorkgroupsSupported()) {
+      kernel->setArg(idx++, (*gws)[0]);
+      kernel->setArg(idx++, (*gws)[1]);
+      kernel->setArg(idx++, (*gws)[2]);
+    }
kernel->setArg(idx++, *(input->opencl_image()));
kernel->setArg(idx++, *(weight->opencl_image()));
if (bias != nullptr) {
......@@ -91,15 +114,25 @@ void FCWXKernel(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<int>(output_blocks));
kernel->setArg(idx++, relux_max_limit);
-  (*gws)[2] = static_cast<uint32_t>(batch * output_blocks);
+    *prev_input_shape = input->shape();
+  }
cl::Event event;
-  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
-      *kernel, cl::NullRange, cl::NDRange((*gws)[0], (*gws)[1], (*gws)[2]),
-      cl::NDRange((*lws)[0], (*lws)[1], (*lws)[2]), nullptr, &event);
-  MACE_CHECK_CL_SUCCESS(error);
+  cl_int error;
+  if (runtime->IsNonUniformWorkgroupsSupported()) {
+    error = runtime->command_queue().enqueueNDRangeKernel(
+        *kernel, cl::NullRange, cl::NDRange((*gws)[0], (*gws)[1], (*gws)[2]),
+        cl::NDRange((*lws)[0], (*lws)[1], (*lws)[2]), nullptr, &event);
+  } else {
+    std::vector<uint32_t> roundup_gws(lws->size());
+    for (size_t i = 0; i < lws->size(); ++i) {
+      roundup_gws[i] = RoundUp((*gws)[i], (*lws)[i]);
+    }
+    error = runtime->command_queue().enqueueNDRangeKernel(
+        *kernel, cl::NullRange,
+        cl::NDRange(roundup_gws[0], roundup_gws[1], roundup_gws[2]),
+        cl::NDRange((*lws)[0], (*lws)[1], (*lws)[2]), nullptr, &event);
+  }
+  MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
......@@ -125,8 +158,8 @@ void FCWTXKernel(cl::Kernel *kernel,
StatsFuture *future) {
MACE_CHECK_NOTNULL(gws);
MACE_CHECK_NOTNULL(lws);
+  auto runtime = OpenCLRuntime::Global();
  if (kernel->get() == nullptr) {
-    auto runtime = OpenCLRuntime::Global();
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("fully_connected");
......@@ -136,6 +169,9 @@ void FCWTXKernel(cl::Kernel *kernel,
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
+    if (runtime->IsNonUniformWorkgroupsSupported()) {
+      built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
+    }
switch (activation) {
case NOOP:
break;
......@@ -157,10 +193,23 @@ void FCWTXKernel(cl::Kernel *kernel,
*kernel =
runtime->BuildKernel("fully_connected", kernel_name, built_options);
-    *lws = {16, 64, 1};
+    uint32_t kwg_size =
+        static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(*kernel));
+    *lws = {16, kwg_size/16, 1};
}
+  if (!IsVecEqual(*prev_input_shape, input->shape())) {
+    const index_t batch = output->dim(0);
+    const index_t output_blocks = RoundUpDiv4(output->dim(3));
+    *gws = {
+        static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
+    };
+
    uint32_t idx = 0;
+    if (!runtime->IsNonUniformWorkgroupsSupported()) {
+      kernel->setArg(idx++, (*gws)[0]);
+      kernel->setArg(idx++, (*gws)[1]);
+    }
kernel->setArg(idx++, *(input->opencl_image()));
kernel->setArg(idx++, *(weight->opencl_image()));
if (bias != nullptr) {
......@@ -173,12 +222,6 @@ void FCWTXKernel(cl::Kernel *kernel,
  // FIXME handle flexible data type: half not supported
kernel->setArg(idx++, relux_max_limit);
-  const index_t batch = output->dim(0);
-  const index_t output_blocks = RoundUpDiv4(output->dim(3));
-  *gws = {
-      static_cast<uint32_t>(batch), static_cast<uint32_t>(output_blocks),
-  };
+    *prev_input_shape = input->shape();
+  }
......
......@@ -57,7 +57,6 @@ void SimpleValidTest() {
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
......
......@@ -225,7 +225,7 @@ void TestWXFormat(const index_t batch,
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<DeviceType::OPENCL, T>(&net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_WIDTH);
-  BufferToImage<DeviceType::OPENCL, float>(&net, "Bias", "BiasImage",
+  BufferToImage<DeviceType::OPENCL, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FC", "FullyConnectedTest")
......@@ -236,7 +236,7 @@ void TestWXFormat(const index_t batch,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
-  // Run on opencl
+  // Run
net.RunOp(DeviceType::OPENCL);
ImageToBuffer<DeviceType::OPENCL, float>(&net, "OutputImage", "OPENCLOutput",
......
#!/bin/bash
set -x
Usage() {
echo "Usage: bash tools/benchmark.sh target_soc model_output_dir option_args"
}
......