Commit bcd624f8 authored by: L liuqi

Finish Winograd convolution algorithm: support fused convolution.

Parent 4deb3c69
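For context, the kernels below implement Winograd's minimal filtering algorithm F(2x2, 3x3): each 2x2 output tile Y is computed from a 4x4 input tile d and a 3x3 filter g as

    Y = A^T \big[ (G g G^T) \odot (B^T d B) \big] A, \qquad \tilde{Y} = \mathrm{act}(Y + b)

where \odot is the elementwise product and A, B, G are the fixed transform matrices. What this commit adds is the fused epilogue \mathrm{act}(Y + b): the bias addition and the activation are folded into the inverse-transform kernel, so a fused convolution no longer needs separate BiasAdd or Activation ops.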
......@@ -77,7 +77,6 @@ extern void Register_Pooling(OperatorRegistry *op_registry);
extern void Register_ResizeBilinear(OperatorRegistry *op_registry);
extern void Register_Softmax(OperatorRegistry *op_registry);
extern void Register_SpaceToBatchND(OperatorRegistry *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry);
extern void Register_GEMM(OperatorRegistry *op_registry);
extern void Register_WinogradTransform(OperatorRegistry *op_registry);
extern void Register_WinogradInverseTransform(OperatorRegistry *op_registry);
......@@ -101,7 +100,6 @@ OperatorRegistry::OperatorRegistry() {
Register_ResizeBilinear(this);
Register_Softmax(this);
Register_SpaceToBatchND(this);
Register_FoldedBatchNorm(this);
Register_GEMM(this);
Register_WinogradTransform(this);
Register_WinogradInverseTransform(this);
......
......@@ -19,7 +19,7 @@ class Registry {
void Register(const SrcType &key, Creator creator) {
VLOG(2) << "Registering: " << key;
std::lock_guard<std::mutex> lock(register_mutex_);
MACE_CHECK(registry_.count(key) == 0, "Key already registered.");
MACE_CHECK(registry_.count(key) == 0, "Key already registered: ", key);
registry_[key] = creator;
}
......
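The Registry change above only adds the offending key to the duplicate-registration error message. For readers unfamiliar with the pattern, here is a minimal, hedged sketch of such a thread-safe key-to-creator registry; the template shape and names are illustrative, not MACE's exact declaration:

#include <cassert>
#include <functional>
#include <map>
#include <mutex>
#include <utility>

// Minimal sketch of a thread-safe key -> creator registry (illustrative;
// mace::Registry's real template parameters differ).
template <typename Key, typename ObjectPtr, typename... Args>
class MiniRegistry {
 public:
  using Creator = std::function<ObjectPtr(Args...)>;

  void Register(const Key &key, Creator creator) {
    std::lock_guard<std::mutex> lock(mutex_);
    // Equivalent of MACE_CHECK(registry_.count(key) == 0,
    //                          "Key already registered: ", key)
    assert(registry_.count(key) == 0 && "Key already registered");
    registry_[key] = std::move(creator);
  }

  ObjectPtr Create(const Key &key, Args... args) const {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = registry_.find(key);
    return it == registry_.end() ? nullptr : it->second(args...);
  }

 private:
  std::map<Key, Creator> registry_;
  mutable std::mutex mutex_;
};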
......@@ -107,17 +107,35 @@ __kernel void winograd_transform_2x2(__read_only image2d_t input,
}
__kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
#ifdef BIAS
__read_only image2d_t bias, /* cout%4 * cout/4 */
#endif
__write_only image2d_t output,
__private const int out_height,
__private const int out_width,
__private const int round_hw,
__private const int round_w) {
__private const int round_w,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha) {
const int width_idx = get_global_id(0);
const int height_idx = get_global_id(1);
const int out_channel = get_global_size(1);
int width = width_idx;
int height = height_idx;
const int batch = width_idx / round_hw;
int t = width_idx % round_hw;
const int out_height_idx = (t / round_w) << 1;
const int out_width_idx = (t % round_w) << 1;
const int out_chan_idx = height_idx;
const int coord_x = mad24(out_chan_idx, out_width, out_width_idx);
const int coord_y = mad24(batch, out_height, out_height_idx);
#ifdef BIAS
DATA_TYPE4 bias_value =
READ_IMAGET(bias, SAMPLER, (int2)(out_chan_idx, 0));
#endif
DATA_TYPE4 in0[4], in1[4], in2[4], in3[4];
#pragma unroll
......@@ -157,13 +175,20 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
in1[0] = in1[0] + in1[1] + in1[2];
in1[1] = in1[1] - in1[2] - in1[3];
const int batch = width_idx / round_hw;
int t = width_idx % round_hw;
const int out_height_idx = (t / round_w) << 1;
const int out_width_idx = (t % round_w) << 1;
const int out_chan_idx = height_idx;
const int coord_x = mad24(out_chan_idx, out_width, out_width_idx);
const int coord_y = mad24(batch, out_height, out_height_idx);
#ifdef BIAS
in0[0] += bias_value;
in0[1] += bias_value;
in1[0] += bias_value;
in1[1] += bias_value;
#endif
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
in0[0] = do_activation(in0[0], relux_max_limit, prelu_alpha);
in0[1] = do_activation(in0[1], relux_max_limit, prelu_alpha);
in1[0] = do_activation(in1[0], relux_max_limit, prelu_alpha);
in1[1] = do_activation(in1[1], relux_max_limit, prelu_alpha);
#endif
WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]);
......
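The tile indexing and the fused tail of the kernel above are easier to follow in scalar form. Below is a hedged C++ reference of the same coordinate mapping and the bias+activation epilogue; do_activation's kernel source is not part of this diff, so the activation bodies here are the standard definitions, stated as an assumption:

#include <algorithm>
#include <cmath>

// Each work-item on axis 0 covers one 2x2 output tile of one image in the
// batch; this mirrors the kernel's width_idx decomposition.
struct TileCoord { int batch, out_h, out_w; };

TileCoord MapTile(int width_idx, int round_hw, int round_w) {
  TileCoord c;
  c.batch = width_idx / round_hw;  // which image in the batch
  int t = width_idx % round_hw;    // tile index inside that image
  c.out_h = (t / round_w) << 1;    // tiles are 2x2, hence the shift
  c.out_w = (t % round_w) << 1;
  return c;
}

// Assumed semantics of do_activation(), one branch per -DUSE_* build option.
enum Activation { NOOP, RELU, RELUX, PRELU, TANH, SIGMOID };

float DoActivation(float x, Activation act,
                   float relux_max_limit, float prelu_alpha) {
  switch (act) {
    case RELU:    return std::max(x, 0.f);
    case RELUX:   return std::min(std::max(x, 0.f), relux_max_limit);
    case PRELU:   return x < 0.f ? prelu_alpha * x : x;
    case TANH:    return std::tanh(x);
    case SIGMOID: return 1.f / (1.f + std::exp(-x));
    default:      return x;  // NOOP
  }
}

// Fused epilogue applied to each of the four tile values, as in the kernel:
// out = do_activation(value + bias).
float FusedTail(float value, float bias, Activation act,
                float max_limit, float alpha) {
  return DoActivation(value + bias, act, max_limit, alpha);
}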
......@@ -109,6 +109,7 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
template<typename T>
void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input_tensor,
const Tensor *bias,
Tensor *output_tensor,
StatsFuture *future) {
std::vector<index_t> output_shape = {batch_, height_, width_, input_tensor->dim(1)};
......@@ -121,10 +122,29 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
built_options.emplace("-Dwinograd_inverse_transform_2x2=" + obfuscated_kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
if ((input_tensor->dim(1) % 4 == 0 || input_tensor->dim(0) == 1) &&
input_tensor->dim(2) % 4 == 0) {
built_options.emplace("-DDIVISIBLE_FOUR");
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
switch (activation_) {
case NOOP:
break;
case RELU:
built_options.emplace("-DUSE_RELU");
break;
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
case SIGMOID:
built_options.emplace("-DUSE_SIGMOID");
break;
default:
LOG(FATAL) << "Unknown activation type: " << activation_;
}
auto runtime = OpenCLRuntime::Global();
auto wino_kernel = runtime->BuildKernel("winograd_transform",
obfuscated_kernel_name,
......@@ -134,11 +154,16 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
const uint32_t round_w = (width_ + 1) / 2;
uint32_t idx = 0;
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(input_tensor->buffer())));
if (bias != nullptr) {
wino_kernel.setArg(idx++, *(static_cast<const cl::Image2D *>(bias->buffer())));
}
wino_kernel.setArg(idx++, *(static_cast<cl::Image2D *>(output_tensor->buffer())));
wino_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[1]));
wino_kernel.setArg(idx++, static_cast<uint32_t>(output_shape[2]));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
wino_kernel.setArg(idx++, static_cast<uint32_t>(round_w));
wino_kernel.setArg(idx++, relux_max_limit_);
wino_kernel.setArg(idx++, prelu_alpha_);
const size_t gws[2] = {static_cast<size_t>(input_tensor->dim(2)),
static_cast<size_t>(RoundUpDiv4(input_tensor->dim(1)))};
......
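The launch geometry follows directly from the host code above. A hedged sketch of the tile bookkeeping (round_h is elided in the hunk, but by symmetry with round_w it is (height_ + 1) / 2; RoundUpDiv4 mirrors the helper used above):

#include <cstddef>
#include <cstdint>

inline int64_t RoundUpDiv4(int64_t v) { return (v + 3) / 4; }

// height/width are the final convolution output dims; tiles_x_batch is
// input_tensor->dim(2), i.e. batch * round_h * round_w tiles, and channels
// is input_tensor->dim(1).
struct LaunchGeometry {
  uint32_t round_h, round_w, round_hw;
  size_t gws[2];
};

LaunchGeometry ComputeGeometry(int height, int width,
                               int64_t tiles_x_batch, int64_t channels) {
  LaunchGeometry g;
  g.round_h = (height + 1) / 2;  // number of 2x2 tile rows
  g.round_w = (width + 1) / 2;   // number of 2x2 tile columns
  g.round_hw = g.round_h * g.round_w;
  // Axis 0: one work-item per (batch, tile). Axis 1: channels are packed
  // four-wide into the image, hence RoundUpDiv4.
  g.gws[0] = static_cast<size_t>(tiles_x_batch);
  g.gws[1] = static_cast<size_t>(RoundUpDiv4(channels));
  return g;
}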
......@@ -8,6 +8,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/conv_pool_2d_util.h"
#include "mace/kernels/activation.h"
namespace mace {
namespace kernels {
......@@ -47,22 +48,37 @@ struct WinogradTransformFunctor<DeviceType::OPENCL, T> : WinogradTransformFuncto
struct WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctorBase(const int batch,
const int height,
const int width)
: batch_(batch), height_(height), width_(width) {}
const int width,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: batch_(batch),
height_(height),
width_(width),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
const int batch_;
const int height_;
const int width_;
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template<DeviceType D, typename T>
struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctor(const int batch,
const int height,
const int width)
: WinogradInverseTransformFunctorBase(batch, height, width) {}
const int width,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
MACE_NOT_IMPLEMENTED;
......@@ -74,10 +90,14 @@ template<typename T>
struct WinogradInverseTransformFunctor<DeviceType::OPENCL, T> : WinogradInverseTransformFunctorBase {
WinogradInverseTransformFunctor(const int batch,
const int height,
const int width)
: WinogradInverseTransformFunctorBase(batch, height, width) {}
const int width,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input,
const Tensor *bias,
Tensor *output,
StatsFuture *future);
};
......
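A hedged usage sketch of the extended functor interface above (shape and activation values are illustrative):

#include "mace/kernels/winograd_transform.h"

// Builds the OpenCL inverse-transform functor with a fused clamped ReLU.
// batch/height/width describe the final convolution output; relux_max_limit
// is only read for RELUX and prelu_alpha only for PRELU.
mace::kernels::WinogradInverseTransformFunctor<mace::DeviceType::OPENCL, float>
    inverse(/*batch=*/1, /*height=*/32, /*width=*/32,
            mace::kernels::RELUX,
            /*relux_max_limit=*/6.0f,
            /*prelu_alpha=*/0.0f);

// bias may be nullptr, in which case the kernel is compiled without -DBIAS:
// inverse(winograd_gemm_output, /*bias=*/nullptr, output, &future);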
......@@ -20,7 +20,7 @@ static void ReluBenchmark(
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "ReluBM")
.Input("InputImage")
......@@ -79,7 +79,7 @@ static void ReluxBenchmark(
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "ReluxBM")
.Input("InputImage")
......@@ -140,7 +140,7 @@ static void PreluBenchmark(
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "PreluBM")
.Input("InputImage")
......@@ -201,7 +201,7 @@ static void TanhBenchmark(
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "TanhBM")
.Input("InputImage")
......@@ -260,7 +260,7 @@ static void SigmoidBenchmark(
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "SigmoidBM")
.Input("InputImage")
......
......@@ -20,7 +20,7 @@ void TestSimpleRelu() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "ReluTest")
.Input("InputImage")
......@@ -33,7 +33,7 @@ void TestSimpleRelu() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "ReluTest")
.Input("Input")
......@@ -70,7 +70,7 @@ void TestUnalignedSimpleRelu() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "ReluTest")
.Input("InputImage")
......@@ -83,7 +83,7 @@ void TestUnalignedSimpleRelu() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "ReluTest")
.Input("Input")
......@@ -125,7 +125,7 @@ void TestSimpleRelux() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "ReluxTest")
.Input("InputImage")
......@@ -139,7 +139,7 @@ void TestSimpleRelux() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "ReluxTest")
.Input("Input")
......@@ -179,7 +179,7 @@ void TestSimpleReluRelux() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "ReluxTest")
.Input("InputImage")
......@@ -193,7 +193,7 @@ void TestSimpleReluRelux() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "ReluxTest")
.Input("Input")
......@@ -237,7 +237,7 @@ void TestSimplePrelu() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "PreluTest")
.Input("InputImage")
......@@ -251,7 +251,7 @@ void TestSimplePrelu() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "PreluTest")
.Input("Input")
......@@ -293,7 +293,7 @@ void TestSimpleTanh() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "TanhTest")
.Input("InputImage")
......@@ -306,7 +306,7 @@ void TestSimpleTanh() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "TanhTest")
.Input("Input")
......@@ -348,7 +348,7 @@ void TestSimpleSigmoid() {
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("Activation", "SigmoidTest")
.Input("InputImage")
......@@ -361,7 +361,7 @@ void TestSimpleSigmoid() {
// Transfer output
ImageToBuffer<D, float>(net, "OutputImage", "Output",
kernels::BufferType::IN_OUT);
kernels::BufferType::IN_OUT_CHANNEL);
} else {
OpDefBuilder("Activation", "SigmoidTest")
.Input("Input")
......
......@@ -96,17 +96,18 @@ static void Conv2d(int iters,
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL);
// ICNet
//BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, half);
BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, half);
//// SNPE GPU ExecutionDuration = 448us, % ALU Utilization = 105
//BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, half);
BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, half);
//// SNPE GPU ExecutionDuration = 258us, % ALU Utilization = 108
//BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, half);
//
//BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, half);
BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, half);
BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, half);
//// SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8
//BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half);
//BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64, half);
//BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256, half);
BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half);
BM_CONV_2D(1, 3, 512, 512, 7, 7, 2, SAME, 64, half);
BM_CONV_2D(1, 512, 64, 64, 1, 1, 1, SAME, 256, half);
BM_CONV_2D(1, 128, 16, 16, 3, 3, 1, VALID, 32, half);
BM_CONV_2D(1, 128, 64, 64, 3, 3, 1, VALID, 32, half);
BM_CONV_2D(1, 128, 128, 128, 3, 3, 1, VALID, 32, half);
......
......@@ -7,10 +7,11 @@
namespace mace {
void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
REGISTER_OPERATOR(op_registry,
OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::CPU, float>);
#if MACE_ENABLE_NEON
......@@ -21,16 +22,18 @@ void Register_FoldedBatchNorm(OperatorRegistry *op_registry) {
FoldedBatchNormOp<DeviceType::NEON, float>);
#endif // MACE_ENABLE_NEON
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
REGISTER_OPERATOR(op_registry,
OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
FoldedBatchNormOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
REGISTER_OPERATOR(op_registry,
OpKeyBuilder("FoldedBatchNorm")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
FoldedBatchNormOp<DeviceType::OPENCL, half>);
}
......
......@@ -11,7 +11,6 @@ namespace mace {
class WinogradConvlutionTest : public OpsTestBase {};
void TransposeFilter(const std::vector<float> &input,
const std::vector<index_t> &input_shape,
std::vector<float> &output) {
......@@ -48,14 +47,18 @@ void WinogradConvolution(const index_t batch,
GenerateRandomRealTypeData<float>(filter_shape, filter_data);
net.AddRandomInput<D, float>("Input", {batch, height, width, in_channels});
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
net.AddRandomInput<D, T>("Bias", {out_channels});
BufferToImage<D, T>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage",
kernels::BufferType::FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", padding)
......@@ -102,6 +105,7 @@ void WinogradConvolution(const index_t batch,
// Inverse transform
OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
.Input("WinoGemm")
.Input("BiasImage")
.AddIntArg("batch", batch)
.AddIntArg("height", output_shape[1])
.AddIntArg("width", output_shape[2])
......@@ -113,7 +117,7 @@ void WinogradConvolution(const index_t batch,
net.Sync();
ImageToBuffer<D, float>(net, "WinoOutputImage", "WinoOutput",
kernels::BufferType::IN_OUT_CHANNEL);
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("WinoOutput"), 1e-1);
} else {
......@@ -121,7 +125,6 @@ void WinogradConvolution(const index_t batch,
}
}
TEST_F(WinogradConvlutionTest, AlignedConvolution) {
WinogradConvolution<DeviceType::OPENCL, float>(1, 32, 32, 32, 16, Padding::VALID);
WinogradConvolution<DeviceType::OPENCL, float>(1, 32, 32, 32, 16, Padding::SAME);
......
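The test above wires the fused pipeline end to end. As a hedged sketch of the inverse-transform stage with the activation fused in (tensor names follow the test; AddStringArg/AddFloatArg are assumed to exist alongside the AddIntArg/AddIntsArg builders used above, and the argument names match the GetSingleArgument calls in winograd_inverse_transform.h below):

// "WinoGemm" is the output of the GEMM stage, whose op def is elided in
// this diff.
OpDefBuilder("WinogradInverseTransform", "WinogradInverseTransformTest")
    .Input("WinoGemm")
    .Input("BiasImage")                  // optional; enables -DBIAS
    .AddIntArg("batch", batch)
    .AddIntArg("height", output_shape[1])
    .AddIntArg("width", output_shape[2])
    .AddStringArg("activation", "RELU")  // parsed by StringToActivationType
    .AddFloatArg("max_limit", 6.0f)      // only read for RELUX
    .AddFloatArg("alpha", 0.1f)          // only read for PRELU
    .Output("WinoOutputImage")
    .Finalize(net.NewOperatorDef());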
......@@ -9,6 +9,7 @@
#include "mace/core/operator.h"
#include "mace/kernels/winograd_transform.h"
#include "mace/kernels/activation.h"
namespace mace {
......@@ -19,13 +20,18 @@ class WinogradInverseTransformOp : public Operator<D, T> {
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("batch", 1),
OperatorBase::GetSingleArgument<int>("height", 0),
OperatorBase::GetSingleArgument<int>("width", 0)) {}
OperatorBase::GetSingleArgument<int>("width", 0),
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
const Tensor *bias = this->InputSize() == 2 ? this->Input(BIAS) : nullptr;
Tensor *output_tensor = this->Output(OUTPUT);
functor_(input_tensor, output_tensor, future);
functor_(input_tensor, bias, output_tensor, future);
return true;
}
......@@ -33,7 +39,7 @@ class WinogradInverseTransformOp : public Operator<D, T> {
kernels::WinogradInverseTransformFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
OP_INPUT_TAGS(INPUT, BIAS);
OP_OUTPUT_TAGS(OUTPUT);
};
......
......@@ -10,15 +10,6 @@ licenses(["notice"]) # Apache 2.0
load("@com_google_protobuf//:protobuf.bzl", "py_proto_library")
py_proto_library(
name = "mace_py",
srcs = ["mace.proto"],
default_runtime = "@com_google_protobuf//:protobuf_python",
protoc = "@com_google_protobuf//:protoc",
srcs_version = "PY2AND3",
deps = ["@com_google_protobuf//:protobuf_python"],
)
py_proto_library(
name = "caffe_py",
srcs = ["caffe.proto"],
......