Commit 4deb3c69 authored by liuqi

Support SAME padding for winograd convolution.

Parent 3d1bf3eb
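For context: SAME padding for the 3x3, stride-1 convolutions this Winograd path targets keeps the output the same spatial size as the input, so kernel_size - 1 rows and columns of zeros are added in total and split between the two sides. The sketch below only illustrates that arithmetic with a hypothetical helper name; the real padding calculation lives in MACE's convolution utilities.

```cpp
// Illustrative only: how the SAME padding passed to the transform kernel is
// derived for a k x k, stride-1 convolution. Helper name is hypothetical.
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> SamePaddingForStride1(int64_t kernel_h, int64_t kernel_w) {
  // Total zero rows/columns so that output size == input size at stride 1,
  // e.g. {2, 2} for a 3x3 filter.
  return {kernel_h - 1, kernel_w - 1};
}

// The kernel only needs the top/left offsets, which is why the host code
// in this commit passes paddings[0] / 2 and paddings[1] / 2 as arguments.
```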
......@@ -31,32 +31,32 @@ __kernel void winograd_transform_2x2(__read_only image2d_t input,
DATA_TYPE4 tv2[4];
DATA_TYPE4 tv3[4];
int y = nh_idx;
int y = select(nh_idx, -1, height_idx < 0 || height_idx >= in_height);
#pragma unroll
for (short i = 0; i < 4; ++i) {
int x = width_idx + i;
x = select(wc_idx + i, -1, x >= in_width);
x = select(wc_idx + i, -1, x < 0 || x >= in_width);
input0[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y));
}
y = select(nh_idx + 1, -1, height_idx + 1 >= in_height);
y = select(nh_idx + 1, -1, height_idx + 1 < 0 || height_idx + 1 >= in_height);
#pragma unroll
for (short i = 0; i < 4; ++i) {
int x = width_idx + i;
x = select(wc_idx + i, -1, x >= in_width);
x = select(wc_idx + i, -1, x < 0 || x >= in_width);
input1[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y));
}
y = select(nh_idx + 2, -1, height_idx + 2 >= in_height);
y = select(nh_idx + 2, -1, height_idx + 2 < 0 || height_idx + 2 >= in_height);
#pragma unroll
for (short i = 0; i < 4; ++i) {
int x = width_idx + i;
x = select(wc_idx + i, -1, x >= in_width);
x = select(wc_idx + i, -1, x < 0 || x >= in_width);
input2[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y));
}
y = select(nh_idx + 3, -1, height_idx + 3 >= in_height);
y = select(nh_idx + 3, -1, height_idx + 3 < 0 || height_idx + 3 >= in_height);
#pragma unroll
for (short i = 0; i < 4; ++i) {
int x = width_idx + i;
x = select(wc_idx + i, -1, x >= in_width);
x = select(wc_idx + i, -1, x < 0 || x >= in_width);
input3[i] = READ_IMAGET(input, SAMPLER, (int2)(x, y));
}
......
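The kernel change above generalizes the border guard: with SAME padding the sampled height and width indices can now be negative (top/left zero padding), not only past the far edge, so every coordinate that falls outside [0, size) is forced to -1 before the image read. Assuming SAMPLER is a clamp-to-border sampler, reading at -1 returns zeros, which is exactly the padded value. A minimal host-side sketch of the same guard (plain C++, not the kernel code):

```cpp
// Minimal sketch mirroring the select() guard used in the kernel above.
inline int GuardCoord(int coord, int size) {
  // Equivalent to: select(coord, -1, coord < 0 || coord >= size)
  return (coord < 0 || coord >= size) ? -1 : coord;
}
// Example: with one row of top padding, the first tile samples height -1;
// GuardCoord keeps it at -1 and a clamp-to-border image read yields zeros.
```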
......@@ -6,6 +6,7 @@
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
......@@ -35,32 +36,66 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
auto runtime = OpenCLRuntime::Global();
auto b2f_kernel = runtime->BuildKernel("winograd_transform",
auto wino_kernel = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
built_options);
uint32_t idx = 0;
b2f_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
b2f_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(3)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
wino_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
wino_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(1)));
wino_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(2)));
wino_kernel.setArg(idx++, static_cast<uint32_t>(input_tensor->dim(3)));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[0] / 2));
wino_kernel.setArg(idx++, static_cast<uint32_t>(paddings[1] / 2));
const size_t gws[2] = {static_cast<size_t>(out_width),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(3)))};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(b2f_kernel);
const std::vector<uint32_t> lws = {128, 8};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(wino_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(2, 0);
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
return {{local_ws[0], local_ws[1]},
{local_ws[1], local_ws[0]},
{kwg_size / 4, 4},
{kwg_size / 8, 8},
{kwg_size / 16, 16},
{kwg_size / 32, 32},
{kwg_size / 64, 64},
{kwg_size / 128, 128},
{kwg_size / 256, 256},
{kwg_size / 512, 512},
{kwg_size, 1},
{1, kwg_size}
};
};
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
b2f_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(lws[0], lws[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
wino_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "winograd_transform_kernel_"
<< input_tensor->dim(0) << "_"
<< input_tensor->dim(1) << "_"
<< input_tensor->dim(2) << "_"
<< input_tensor->dim(3);
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
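The host-side change above replaces the fixed {128, 8} local work size with MACE's Tuner: params_generator enumerates candidate local work-group sizes (bounded by the kernel's maximum work-group size), func launches the kernel with one candidate, and TuneOrRun either times every candidate and keeps the fastest (tuning mode) or simply runs with the cached/default parameters. A simplified sketch of that pattern, under those assumptions (this is not the actual Tuner implementation):

```cpp
// Simplified sketch of the tune-or-run pattern assumed above; the real
// mace::Tuner keys results per kernel and measures with OpenCLProfilingTimer.
#include <cstdint>
#include <functional>
#include <limits>
#include <string>
#include <vector>

using Params = std::vector<uint32_t>;

Params TuneOrRunSketch(const std::string &key,
                       const Params &default_params,
                       const std::function<std::vector<Params>()> &generator,
                       const std::function<double(const Params &)> &timed_run,
                       bool tuning_enabled) {
  if (!tuning_enabled) {  // production path: run with default/cached params
    timed_run(default_params);
    return default_params;
  }
  Params best = default_params;
  double best_time = std::numeric_limits<double>::max();
  for (const Params &candidate : generator()) {  // tuning path: try each LWS
    const double t = timed_run(candidate);
    if (t < best_time) { best_time = t; best = candidate; }
  }
  return best;  // the real Tuner persists the winner under `key`
}
```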
......@@ -91,31 +126,65 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
built_options.emplace("-DDIVISIBLE_FOUR");
}
auto runtime = OpenCLRuntime::Global();
auto b2f_kernel = runtime->BuildKernel("winograd_transform",
auto wino_kernel = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
built_options);
const uint32_t round_h = (height_ + 1) / 2;
const uint32_t round_w = (width_ + 1) / 2;
uint32_t idx = 0;
b2f_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
b2f_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[1]));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[2]));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
wino_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
wino_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[1]));
wino_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[2]));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
const size_t gws[2] = {static_cast<size_t>(input_tensor->dim(2)),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(1)))};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(b2f_kernel);
const std::vector<uint32_t> lws = {128, 8};
const uint32_t kwg_size = runtime->GetKernelMaxWorkGroupSize(wino_kernel);
auto params_generator = [&]()->std::vector<std::vector<uint32_t>> {
std::vector<uint32_t> local_ws(2, 0);
local_ws[0] = std::min<uint32_t>(gws[0], kwg_size);
local_ws[1] = std::min<uint32_t>(gws[1], kwg_size / local_ws[0]);
return {{local_ws[0], local_ws[1]},
{local_ws[1], local_ws[0]},
{kwg_size / 4, 4},
{kwg_size / 8, 8},
{kwg_size / 16, 16},
{kwg_size / 32, 32},
{kwg_size / 64, 64},
{kwg_size / 128, 128},
{kwg_size / 256, 256},
{kwg_size / 512, 512},
{kwg_size, 1},
{1, kwg_size}
};
};
cl::Event event;
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
b2f_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(lws[0], lws[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
auto func = [&](const std::vector<uint32_t>& params)->cl_int {
cl_int error = runtime->command_queue().enqueueNDRangeKernel(
wino_kernel, cl::NullRange,
cl::NDRange(gws[0], gws[1]),
cl::NDRange(params[0], params[1]),
nullptr, &event);
MACE_CHECK(error == CL_SUCCESS) << "Error code: " << error;
return error;
};
std::stringstream ss;
ss << "winograd_inverse_transform_kernel_"
<< input_tensor->dim(0) << "_"
<< input_tensor->dim(1) << "_"
<< input_tensor->dim(2) << "_"
<< input_tensor->dim(3);
OpenCLProfilingTimer timer(&event);
Tuner<uint32_t>::Get()->template TuneOrRun<cl_int>(ss.str(),
lws,
params_generator,
func,
&timer);
if (future != nullptr) {
future->wait_fn = [runtime, event](CallStats *stats) {
......
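In the inverse transform, round_h = (height_ + 1) / 2 and round_w = (width_ + 1) / 2 are the number of output tiles per column and row, which suggests 2x2 output tiles (the Winograd F(2x2, 3x3) variant). Each output pixel (h, w) then belongs to tile (h / 2) * round_w + (w / 2) with offset (h % 2, w % 2) inside the tile. A hypothetical sketch of that index math:

```cpp
// Hypothetical illustration of the tile indexing implied by round_h/round_w
// above, assuming 2x2 output tiles per image (Winograd F(2x2, 3x3)).
#include <cstdint>

struct TileCoord {
  int64_t tile_index;  // index into the round_h * round_w tiles of one image
  int64_t dy, dx;      // position inside the 2x2 tile
};

TileCoord MapOutputPixelToTile(int64_t h, int64_t w, int64_t width) {
  const int64_t round_w = (width + 1) / 2;
  return {(h / 2) * round_w + (w / 2), h % 2, w % 2};
}
```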
......@@ -9,213 +9,8 @@
namespace mace {
class WinogradTransformOpTest : public OpsTestBase {};
class WinogradConvlutionTest : public OpsTestBase {};
//TEST_F(WinogradTransformOpTest, WinogradInputTransform) {
// srand(time(NULL));
//
// // generate random input
// index_t batch = 7;
// index_t height = 61;
// index_t width = 71;
// index_t channels = 31;
//
// index_t p = batch * ((height - 1) / 2) * ((width - 1) / 2);
//
// const std::string A_file = "/data/local/tmp/test/A";
// const std::string C_file = "/data/local/tmp/test/C";
// const std::vector<index_t> A_shape = {batch, height, width, channels};
// const int A_size = std::accumulate(A_shape.begin(), A_shape.end(), 1, std::multiplies<int>());
// const std::vector<index_t> C_shape = {16, channels, p, 1};
// const int C_size = std::accumulate(C_shape.begin(), C_shape.end(), 1, std::multiplies<int>());
//
// std::vector<float> A_data(A_size, 0.0);
// std::ifstream in_file(A_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(A_data.data()),
// A_size * sizeof(float));
// in_file.close();
// } else {
// VLOG(0) << "open A file failed";
// }
// auto C_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// C_tensor->Resize(C_shape);
// std::vector<float> C_data(C_size, 0.0);
// std::ifstream C_in_file(C_file, std::ios::in | std::ios::binary);
// if (C_in_file.is_open()) {
// C_in_file.read(reinterpret_cast<char *>(C_data.data()),
// C_size * sizeof(float));
// C_in_file.close();
// Tensor::MappingGuard C_mapper(C_tensor.get());
// float *batch_ptr = C_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(C_tensor->size()) ==
// C_data.size());
// memcpy(batch_ptr, C_data.data(), C_data.size() * sizeof(float));
// } else {
// VLOG(0) << "open C file failed";
// }
// // Construct graph
// OpsTestNet net;
// // Add input data
// net.AddInputFromArray<DeviceType::OPENCL, float>(
// "A", A_shape, A_data);
//
// // Run on opencl
// BufferToImage<DeviceType::OPENCL, float>(net, "A", "AImage",
// kernels::BufferType::IN_OUT_CHANNEL);
//
// OpDefBuilder("WinogradTransform", "WinogradTransformTest")
// .Input("AImage")
// .Output("OutputImage")
// .Finalize(net.NewOperatorDef());
//
// // Run on opencl
// net.RunOp(DeviceType::OPENCL);
// net.Sync();
//
// ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
// kernels::BufferType::IN_OUT_HEIGHT);
// ExpectTensorNear<float>(*(C_tensor.get()), *net.GetOutput("OPENCLOutput"), 1e-4);
//}
//
//TEST_F(WinogradTransformOpTest, FilterTransform) {
// srand(time(NULL));
//
// // generate random input
// index_t out_chan = 31;
// index_t in_chan = 31;
// index_t height = 3;
// index_t width = 3;
//
// index_t p = (in_chan + 3) / 4;
//
// const std::string A_file = "/data/local/tmp/test/filter_in";
// const std::string C_file = "/data/local/tmp/test/filter_out";
// const std::vector<index_t> A_shape = {out_chan, in_chan, height, width};
// const int A_size = std::accumulate(A_shape.begin(), A_shape.end(), 1, std::multiplies<int>());
// const std::vector<index_t> C_shape = {16, out_chan, in_chan, 1};
// const int C_size = std::accumulate(C_shape.begin(), C_shape.end(), 1, std::multiplies<int>());
//
// std::vector<float> A_data(A_size, 0.0);
// std::ifstream in_file(A_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(A_data.data()),
// A_size * sizeof(float));
// in_file.close();
// } else {
// VLOG(0) << "open A file failed";
// }
// auto C_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// C_tensor->Resize(C_shape);
// std::vector<float> C_data(C_size, 0.0);
// std::ifstream C_in_file(C_file, std::ios::in | std::ios::binary);
// if (C_in_file.is_open()) {
// C_in_file.read(reinterpret_cast<char *>(C_data.data()),
// C_size * sizeof(float));
// C_in_file.close();
// Tensor::MappingGuard C_mapper(C_tensor.get());
// float *batch_ptr = C_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(C_tensor->size()) ==
// C_data.size());
// memcpy(batch_ptr, C_data.data(), C_data.size() * sizeof(float));
// } else {
// VLOG(0) << "open C file failed";
// }
// // Construct graph
// OpsTestNet net;
// // Add input data
// net.AddInputFromArray<DeviceType::OPENCL, float>(
// "A", A_shape, A_data);
//
// // Run on opencl
//
// OpDefBuilder("BufferToImage", "WinogradFilterTransformTest")
// .Input("A")
// .AddIntArg("buffer_type", kernels::WINOGRAD_FILTER)
// .Output("OutputImage")
// .Finalize(net.NewOperatorDef());
//
// // Run on opencl
// net.RunOp(DeviceType::OPENCL);
//
// ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
// kernels::BufferType::WINOGRAD_FILTER);
// ExpectTensorNear<float>(*(C_tensor.get()), *net.GetOutput("OPENCLOutput"), 1e-4);
//}
//
//
//TEST_F(WinogradTransformOpTest, WinogradInverseTransform) {
// srand(time(NULL));
//
// // generate random input
// index_t n = 7;
// index_t out_height = 59;
// index_t out_width = 69;
// index_t out_chan = 31;
//
// index_t p = n * ((out_height + 1) / 2) * ((out_width + 1) / 2);
//
// const std::string A_file = "/data/local/tmp/test/gemm";
// const std::string C_file = "/data/local/tmp/test/res";
// const std::vector<index_t> A_shape = {16, out_chan, p, 1};
// const int A_size = std::accumulate(A_shape.begin(), A_shape.end(), 1, std::multiplies<int>());
// const std::vector<index_t> C_shape = {n, out_height, out_width, out_chan};
// const int C_size = std::accumulate(C_shape.begin(), C_shape.end(), 1, std::multiplies<int>());
//
// std::vector<float> A_data(A_size, 0.0);
// std::ifstream in_file(A_file, std::ios::in | std::ios::binary);
// if (in_file.is_open()) {
// in_file.read(reinterpret_cast<char *>(A_data.data()),
// A_size * sizeof(float));
// in_file.close();
// } else {
// VLOG(0) << "open A file failed";
// }
// auto C_tensor = unique_ptr<Tensor>(new Tensor(GetDeviceAllocator(DeviceType::OPENCL),
// DataTypeToEnum<float>::v()));
// C_tensor->Resize(C_shape);
// std::vector<float> C_data(C_size, 0.0);
// std::ifstream C_in_file(C_file, std::ios::in | std::ios::binary);
// if (C_in_file.is_open()) {
// C_in_file.read(reinterpret_cast<char *>(C_data.data()),
// C_size * sizeof(float));
// C_in_file.close();
// Tensor::MappingGuard C_mapper(C_tensor.get());
// float *batch_ptr = C_tensor->mutable_data<float>();
// MACE_CHECK(static_cast<size_t>(C_tensor->size()) ==
// C_data.size());
// memcpy(batch_ptr, C_data.data(), C_data.size() * sizeof(float));
// } else {
// VLOG(0) << "open C file failed";
// }
// // Construct graph
// OpsTestNet net;
// // Add input data
// net.AddInputFromArray<DeviceType::OPENCL, float>(
// "A", A_shape, A_data);
//
// // Run on opencl
// BufferToImage<DeviceType::OPENCL, float>(net, "A", "AImage",
// kernels::BufferType::IN_OUT_HEIGHT);
//
// OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
// .Input("AImage")
// .AddIntArg("batch", n)
// .AddIntArg("height", out_height)
// .AddIntArg("width", out_width)
// .Output("OutputImage")
// .Finalize(net.NewOperatorDef());
//
// // Run on opencl
// net.RunOp(DeviceType::OPENCL);
// net.Sync();
//
// ImageToBuffer<DeviceType::OPENCL, float>(net, "OutputImage", "OPENCLOutput",
// kernels::BufferType::IN_OUT_CHANNEL);
// ExpectTensorNear<float>(*(C_tensor.get()), *net.GetOutput("OPENCLOutput"), 1e-4);
//}
void TransposeFilter(const std::vector<float> &input,
const std::vector<index_t> &input_shape,
......@@ -327,8 +122,19 @@ void WinogradConvolution(const index_t batch,
}
TEST_F(WinogradTransformOpTest, Convolution) {
WinogradConvolution<DeviceType::OPENCL, float>(1, 64, 64, 32, 32, Padding::VALID);
TEST_F(WinogradConvlutionTest, AlignedConvolution) {
WinogradConvolution<DeviceType::OPENCL, float>(1, 32, 32, 32, 16, Padding::VALID);
WinogradConvolution<DeviceType::OPENCL, float>(1, 32, 32, 32, 16, Padding::SAME);
}
TEST_F(WinogradConvlutionTest, UnAlignedConvolution) {
WinogradConvolution<DeviceType::OPENCL, float>(1, 61, 67, 31, 37, Padding::VALID);
WinogradConvolution<DeviceType::OPENCL, float>(1, 61, 67, 37, 31, Padding::SAME);
}
TEST_F(WinogradConvlutionTest, BatchConvolution) {
WinogradConvolution<DeviceType::OPENCL, float>(3, 64, 64, 32, 32, Padding::VALID);
WinogradConvolution<DeviceType::OPENCL, float>(5, 61, 67, 37, 31, Padding::SAME);
}
}
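The rewritten tests cover both padding modes on aligned, unaligned, and batched shapes. For a 3x3, stride-1 convolution the expected output spatial size is in_size - 2 under VALID and in_size under SAME (e.g. a 61x67 input yields 59x65 vs 61x67). A one-line sketch of that expectation, illustrative only:

```cpp
// Illustrative only: expected output spatial size for a 3x3, stride-1
// convolution under each padding mode, as exercised by the tests above.
#include <cstdint>

int64_t OutSize3x3Stride1(int64_t in_size, bool same_padding) {
  return same_padding ? in_size : in_size - 2;  // 61 -> 61 (SAME), 59 (VALID)
}
```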
......@@ -54,7 +54,6 @@ static void BMWinogradTransform(
BM_WINOGRAD_TRANSFORM(1, 16, 16, 128, half);
BM_WINOGRAD_TRANSFORM(1, 64, 64, 128, half);
BM_WINOGRAD_TRANSFORM(1, 128, 128, 128, half);
BM_WINOGRAD_TRANSFORM(1, 256, 256, 32, half);
template <DeviceType D, typename T>
static void BMWinogradInverseTransform(
......