提交 d548b461 编写于 作者: L Liangliang He

Merge branch 'conv-pad' into 'master'

Fix arbitrary pad and PReLU bug

See merge request !259
......@@ -46,8 +46,7 @@ void DoActivation(const T *input_ptr,
T *output_ptr,
const index_t size,
const ActivationType type,
const float relux_max_limit,
const float prelu_alpha) {
const float relux_max_limit) {
MACE_CHECK(DataTypeToEnum<T>::value != DataType::DT_HALF);
switch (type) {
......@@ -66,17 +65,6 @@ void DoActivation(const T *input_ptr,
static_cast<T>(relux_max_limit));
}
break;
case PRELU:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
T in = input_ptr[i];
if (in < 0) {
output_ptr[i] = in * prelu_alpha;
} else {
output_ptr[i] = in;
}
}
break;
case TANH:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
......@@ -95,45 +83,70 @@ void DoActivation(const T *input_ptr,
}
}
template <typename T>
void PReLUActivation(const T *input_ptr,
const index_t size,
const index_t input_chan,
const T *alpha_ptr,
T *output_ptr) {
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
const index_t chan_idx = i % input_chan;
T in = input_ptr[i];
if (in < 0) {
output_ptr[i] = in * alpha_ptr[chan_idx];
} else {
output_ptr[i] = in;
}
}
}
template <DeviceType D, typename T>
class ActivationFunctor {
public:
ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha)
ActivationFunctor(ActivationType type, T relux_max_limit)
: activation_(type),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit){}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
void operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_,
prelu_alpha_);
if (activation_ == PRELU) {
const T *alpha_ptr = alpha == nullptr ? nullptr : alpha->data<T>();
PReLUActivation(input_ptr, output->size(), input->dim(3), alpha_ptr, output_ptr);
} else {
DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_);
}
}
private:
ActivationType activation_;
T relux_max_limit_;
T prelu_alpha_;
};
template <>
void ActivationFunctor<DeviceType::NEON, float>::operator()(
const Tensor *input, Tensor *output, StatsFuture *future);
const Tensor *input, const Tensor *alpha, Tensor *output, StatsFuture *future);
template <typename T>
class ActivationFunctor<DeviceType::OPENCL, T> {
public:
ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha)
ActivationFunctor(ActivationType type, T relux_max_limit)
: activation_(type),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit){}
void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
void operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future);
private:
ActivationType activation_;
T relux_max_limit_;
T prelu_alpha_;
cl::Kernel kernel_;
std::string tuning_key_prefix_;
};
......
......@@ -21,27 +21,23 @@ namespace kernels {
struct BatchNormFunctorBase {
BatchNormFunctorBase(bool folded_constant,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: folded_constant_(folded_constant),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit){}
const bool folded_constant_;
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template <DeviceType D, typename T>
struct BatchNormFunctor : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: BatchNormFunctorBase(
folded_constant, activation, relux_max_limit, prelu_alpha) {}
folded_constant, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *scale,
......@@ -132,7 +128,7 @@ struct BatchNormFunctor : BatchNormFunctorBase {
}
}
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
relux_max_limit_);
}
};
......@@ -150,10 +146,9 @@ template <typename T>
struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
BatchNormFunctor(const bool folded_constant,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: BatchNormFunctorBase(
folded_constant, activation, relux_max_limit, prelu_alpha) {}
folded_constant, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *scale,
const Tensor *offset,
......
......@@ -182,15 +182,13 @@ struct Conv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: strides_(strides),
padding_type_(padding_type),
paddings_(paddings),
dilations_(dilations),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit){}
const int *strides_; // [stride_h, stride_w]
const Padding padding_type_;
......@@ -198,7 +196,6 @@ struct Conv2dFunctorBase {
const int *dilations_; // [dilation_h, dilation_w]
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template <DeviceType D, typename T>
......@@ -208,15 +205,13 @@ struct Conv2dFunctor : Conv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: Conv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
relux_max_limit) {}
void operator()(const Tensor *input, // NHWC
const Tensor *filter, // HWOI
......@@ -229,11 +224,14 @@ struct Conv2dFunctor : Conv2dFunctorBase {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
}
output->Resize(output_shape);
......@@ -619,7 +617,7 @@ struct Conv2dFunctor : Conv2dFunctorBase {
}
}
DoActivation(output_data, output_data, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
relux_max_limit_);
}
};
......@@ -637,15 +635,13 @@ struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: Conv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *filter,
......
......@@ -135,6 +135,44 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
output_shape[3] = output_channels;
}
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
(dilations[1] == 1 || strides[1] == 1),
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
/*
* Convlution arithmetic:
* o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
* Pooling arithmetic:
* o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
* For details, see https://arxiv.org/pdf/1603.07285.pdf or
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
output_shape[0] = input_shape[0];
if (round_type == FLOOR) {
output_shape[1] = static_cast<index_t>(std::floor(1.0 * (input_shape[1] + padding_size[0]
- filter_shape[0] - (filter_shape[0] - 1) * (dilations[0] - 1)) / strides[0]) + 1);
output_shape[2] = static_cast<index_t>(std::floor(1.0 * (input_shape[2] + padding_size[1]
- filter_shape[1] - (filter_shape[1] - 1) * (dilations[1] - 1)) / strides[1]) + 1);
} else {
output_shape[1] = static_cast<index_t>(std::ceil(1.0 * (input_shape[1] + padding_size[0]
- filter_shape[0] - (filter_shape[0] - 1) * (dilations[0] - 1)) / strides[0]) + 1);
output_shape[2] = static_cast<index_t>(std::ceil(1.0 * (input_shape[2] + padding_size[1]
- filter_shape[1] - (filter_shape[1] - 1) * (dilations[1] - 1)) / strides[1]) + 1);
}
output_shape[3] = filter_shape[2];
}
void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
......
......@@ -15,6 +15,11 @@ enum Padding {
FULL = 2, // Pads with one less than the filter size on both sides
};
enum RoundType{
FLOOR = 0,
CEIL = 1,
};
namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
......@@ -33,6 +38,14 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NCHW
index_t *output_shape,
int *padding_size);
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape);
void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
......
......@@ -241,15 +241,13 @@ struct DepthwiseConv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: strides_(strides),
padding_type_(padding_type),
paddings_(paddings),
dilations_(dilations),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit){}
const int *strides_; // [stride_h, stride_w]
const Padding padding_type_;
......@@ -257,7 +255,6 @@ struct DepthwiseConv2dFunctorBase {
const int *dilations_; // [dilation_h, dilation_w]
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template <DeviceType D, typename T>
......@@ -267,15 +264,13 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: DepthwiseConv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
relux_max_limit) {}
void operator()(const Tensor *input, // NHWC
const Tensor *filter, // HWIM
......@@ -295,11 +290,14 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
}
auto input_shape = fake_filter_shape;
output->Resize(output_shape);
......@@ -405,7 +403,7 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
output_ptr = output->mutable_data<T>();
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
relux_max_limit_);
}
};
......@@ -425,15 +423,13 @@ struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
const std::vector<int> &paddings,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: DepthwiseConv2dFunctorBase(strides,
padding_type,
paddings,
dilations,
activation,
relux_max_limit,
prelu_alpha) {}
relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *filter,
......
......@@ -15,23 +15,19 @@ namespace kernels {
struct FullyConnectedBase {
FullyConnectedBase(const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit){}
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template<DeviceType D, typename T>
struct FullyConnectedFunctor : FullyConnectedBase {
FullyConnectedFunctor(const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha) :
FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {}
const float relux_max_limit) :
FullyConnectedBase(activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
......@@ -70,16 +66,15 @@ struct FullyConnectedFunctor : FullyConnectedBase {
}
DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
relux_max_limit_, prelu_alpha_);
relux_max_limit_);
}
};
template<typename T>
struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
FullyConnectedFunctor(const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha) :
FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {}
const float relux_max_limit) :
FullyConnectedBase(activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
......
......@@ -14,6 +14,7 @@ namespace kernels {
template <typename T>
void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *alpha,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
......@@ -60,8 +61,10 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
runtime->BuildKernel("activation", kernel_name, built_options);
int idx = 0;
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
if (activation_ == PRELU) {
kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(alpha->buffer())));
}
kernel_.setArg(idx++, static_cast<float>(relux_max_limit_));
kernel_.setArg(idx++, static_cast<float>(prelu_alpha_));
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
}
......
......@@ -50,9 +50,6 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
......@@ -79,7 +76,6 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
}
kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......
#include <common.h>
__kernel void activation(__read_only image2d_t input,
#ifdef USE_PRELU
__read_only image2d_t alpha,
#endif
__private const float relux_max_limit,
__private const float prelu_alpha,
__write_only image2d_t output) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
......@@ -11,7 +13,12 @@ __kernel void activation(__read_only image2d_t input,
const int pos = mad24(ch_blk, width, w);
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = do_activation(in, relux_max_limit, prelu_alpha);
#ifdef USE_PRELU
DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit);
#else
DATA_TYPE4 out = do_activation(in, relux_max_limit);
#endif
WRITE_IMAGET(output, (int2)(pos, hb), out);
}
......@@ -9,8 +9,7 @@ __kernel void batch_norm(__read_only image2d_t input,
__private const float epsilon,
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float prelu_alpha) {
__private const float relux_max_limit) {
const int ch_blk = get_global_id(0);
const int w = get_global_id(1);
const int hb = get_global_id(2);
......@@ -35,8 +34,8 @@ __kernel void batch_norm(__read_only image2d_t input,
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
DATA_TYPE4 out = mad(in, bn_scale, bn_offset);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out = do_activation(out, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out = do_activation(out, relux_max_limit);
#endif
WRITE_IMAGET(output, (int2)(pos, hb), out);
......
......@@ -22,8 +22,10 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
__private const float relux_max_limit,
__private const float prelu_alpha) {
#ifdef USE_PRELU
DATA_TYPE4 prelu_alpha,
#endif
__private const float relux_max_limit) {
DATA_TYPE4 out;
#ifdef USE_RELU
out = fmax(in, 0);
......
......@@ -7,7 +7,6 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float prelu_alpha,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
......@@ -112,11 +111,11 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
}
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
#endif
const int out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -7,12 +7,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float prelu_alpha,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
__private const int height,
__private const int width) {
__private const int width,
__private const int stride) {
const int out_ch_blk = get_global_id(0);
const int out_w_blk = get_global_id(1);
const int out_w_blks = get_global_size(1);
......@@ -31,19 +31,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
#endif
int4 w;
#if STRIDE == 1
w.x = out_w_blk;
w.y = w.x + out_w_blks;
w.z = w.y + out_w_blks;
w.w = w.z + out_w_blks;
int out_hb_idx = (out_hb % height);
#elif STRIDE == 2
w.x = out_w_blk << 1;
w.y = (out_w_blk + out_w_blks) << 1;
w.z = (out_w_blk + (out_w_blks << 1)) << 1;
w.w = (out_w_blk + (out_w_blks << 1) + out_w_blks) << 1;
int out_hb_idx = (out_hb % height) << 1;
#endif
int in_width_stride = mul24(out_w_blks, stride);
w.x = mul24(out_w_blk, stride);
w.y = w.x + in_width_stride;
w.z = w.y + in_width_stride;
w.w = w.z + in_width_stride;
int out_hb_idx = mul24((out_hb % height), stride);
w.x = select(w.x, INT_MIN, w.x >= in_width);
w.y = select(w.y, INT_MIN, w.y >= in_width);
......@@ -92,11 +85,11 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
filter_x_base += 4;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
#endif
const int out_x_base = mul24(out_ch_blk, width);
......
......@@ -7,12 +7,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float prelu_alpha,
__private const int in_height,
__private const int in_width,
__private const int in_ch_blks,
__private const int out_height,
__private const int out_width,
__private const int stride,
__private const int padding_top,
__private const int padding_left,
__private const int dilation_h,
......@@ -38,21 +38,13 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
DATA_TYPE4 out4 = 0;
#endif
#if STRIDE == 1
int in_width0 = out_w_blk - padding_left;
int in_width1 = in_width0 + out_w_blks;
int in_width2 = in_width1 + out_w_blks;
int in_width3 = in_width2 + out_w_blks;
int in_width4 = in_width3 + out_w_blks;
const int height_idx = (out_hb % out_height) - padding_top;
#elif STRIDE == 2
int in_width0 = (out_w_blk << 1) - padding_left;
int in_width1 = ((out_w_blk + out_w_blks) << 1) - padding_left;
int in_width2 = ((out_w_blk + (out_w_blks << 1)) << 1) - padding_left;
int in_width3 = ((out_w_blk + (out_w_blks << 1) + out_w_blks) << 1) - padding_left;
int in_width4 = ((out_w_blk + (out_w_blks << 2)) << 1) - padding_left;
const int height_idx = ((out_hb % out_height) << 1) - padding_top;
#endif
int in_width_stride = mul24(out_w_blks, stride);
int in_width0 = mad24(out_w_blk, stride, -padding_left);
int in_width1 = in_width0 + in_width_stride;
int in_width2 = in_width1 + in_width_stride;
int in_width3 = in_width2 + in_width_stride;
int in_width4 = in_width3 + in_width_stride;
const int height_idx = mad24((out_hb % out_height), stride, -padding_top);
const int batch_idx = mul24((out_hb / out_height), in_height);
const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch;
......@@ -127,12 +119,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
}
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
out4 = do_activation(out4, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
out4 = do_activation(out4, relux_max_limit);
#endif
const int out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -8,7 +8,6 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
#endif
__write_only image2d_t output,
__private const float relux_max_limit,
__private const float prelu_alpha,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
......@@ -103,11 +102,11 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
in_hb_idx += dilation_h;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
#endif
const short out_x_base = mul24(out_ch_blk, out_width);
......@@ -134,7 +133,6 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4
#endif
__write_only image2d_t output,
__private const DATA_TYPE relux_max_limit,
__private const DATA_TYPE prelu_alpha,
__private const short in_height,
__private const short in_width,
__private const short in_ch_blks,
......@@ -220,11 +218,11 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4
in_hb_idx += 1;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit, prelu_alpha);
out1 = do_activation(out1, relux_max_limit, prelu_alpha);
out2 = do_activation(out2, relux_max_limit, prelu_alpha);
out3 = do_activation(out3, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
out0 = do_activation(out0, relux_max_limit);
out1 = do_activation(out1, relux_max_limit);
out2 = do_activation(out2, relux_max_limit);
out3 = do_activation(out3, relux_max_limit);
#endif
const short out_x_base = mul24(out_ch_blk, out_width);
......
......@@ -10,8 +10,7 @@ __kernel void fully_connected(__read_only image2d_t input,
__private const int input_height,
__private const int input_width,
__private const int input_channel,
__private const float relux_max_limit,
__private const float prelu_alpha) {
__private const float relux_max_limit) {
const int batch_idx = get_global_id(0);
const int out_blk_idx = get_global_id(1);
const int input_chan_blk = (input_channel + 3) >> 2;
......@@ -51,8 +50,8 @@ __kernel void fully_connected(__read_only image2d_t input,
input_coord.y++;
}
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
result = do_activation(result, relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
result = do_activation(result, relux_max_limit);
#endif
WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
}
......@@ -115,8 +115,7 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
__private const int out_width,
__private const int round_hw,
__private const int round_w,
__private const float relux_max_limit,
__private const float prelu_alpha) {
__private const float relux_max_limit) {
const int width_idx = get_global_id(0);
const int height_idx = get_global_id(1);
const int out_channel = get_global_size(1);
......@@ -183,11 +182,11 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
#endif
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
in0[0] = do_activation(in0[0], relux_max_limit, prelu_alpha);
in0[1] = do_activation(in0[1], relux_max_limit, prelu_alpha);
in1[0] = do_activation(in1[0], relux_max_limit, prelu_alpha);
in1[1] = do_activation(in1[1], relux_max_limit, prelu_alpha);
#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
in0[0] = do_activation(in0[0], relux_max_limit);
in0[1] = do_activation(in0[1], relux_max_limit);
in1[0] = do_activation(in1[0], relux_max_limit);
in1[1] = do_activation(in1[1], relux_max_limit);
#endif
WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]);
......@@ -205,6 +204,4 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
WRITE_IMAGET(output, (int2)(coord_x + 1, coord_y + 1), in1[1]);
}
}
......@@ -17,7 +17,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
......@@ -31,7 +30,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
......@@ -45,7 +43,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future);
......@@ -60,7 +57,7 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
cl::Kernel *kernel,
const Tensor *input, const Tensor *filter, const Tensor *bias, const int stride,
const int *padding, const int *dilations, const ActivationType activation,
const float relux_max_limit, const float prelu_alpha, const DataType dt,
const float relux_max_limit, const DataType dt,
Tensor *output, StatsFuture *future);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5] =
......@@ -69,7 +66,6 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
if (!input->is_image() || strides_[0] != strides_[1] ||
((kernel_h == 1 || kernel_h == 3) && strides_[0] > 2) ||
(dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) {
LOG(WARNING) << "OpenCL conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
......@@ -82,11 +78,14 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter->shape().data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
}
std::vector<size_t> output_image_shape;
......@@ -94,16 +93,13 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
output->ResizeImage(output_shape, output_image_shape);
if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1] != nullptr &&
0 < strides_[0] && strides_[0] < 3 ) {
selector[kernel_h - 1] != nullptr) {
auto conv2d_func = selector[kernel_h - 1];
conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_,
relux_max_limit_, prelu_alpha_, DataTypeToEnum<T>::value,
output, future);
relux_max_limit_, DataTypeToEnum<T>::value, output, future);
} else {
Conv2dOpencl(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_,
activation_, relux_max_limit_, prelu_alpha_,
DataTypeToEnum<T>::value, output, future);
activation_, relux_max_limit_, DataTypeToEnum<T>::value, output, future);
}
}
......
......@@ -19,7 +19,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
......@@ -44,7 +43,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DSTRIDE=", stride));
if (bias != nullptr) {
built_options.emplace("-DBIAS");
}
......@@ -57,9 +55,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
......@@ -87,12 +82,12 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
*(static_cast<const cl::Image2D *>(output->buffer())));
// FIXME handle flexable data type: half not supported
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<int>(input_height));
kernel->setArg(idx++, static_cast<int>(input_width));
kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int>(height));
kernel->setArg(idx++, static_cast<int>(width));
kernel->setArg(idx++, stride);
}
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......
......@@ -21,7 +21,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
......@@ -42,7 +41,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace(MakeString("-DSTRIDE=", stride));
switch (activation) {
case NOOP:
break;
......@@ -52,9 +50,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
......@@ -81,12 +76,12 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<int>(input->dim(1)));
kernel->setArg(idx++, static_cast<int>(input->dim(2)));
kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int>(height));
kernel->setArg(idx++, static_cast<int>(width));
kernel->setArg(idx++, stride);
kernel->setArg(idx++, padding[0] / 2);
kernel->setArg(idx++, padding[1] / 2);
kernel->setArg(idx++, dilations[0]);
......
......@@ -21,7 +21,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
......@@ -42,7 +41,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace(MakeString("-DSTRIDE=", stride));
switch (activation) {
case NOOP:
break;
......@@ -52,9 +50,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
......@@ -81,7 +76,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
kernel->setArg(idx++,
*(static_cast<const cl::Image2D *>(output->buffer())));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<uint32_t>(input->dim(1)));
kernel->setArg(idx++, static_cast<uint32_t>(input->dim(2)));
kernel->setArg(idx++, static_cast<uint32_t>(input_channel_blocks));
......
......@@ -20,7 +20,6 @@ void DepthwiseConv2d(cl::Kernel *kernel,
const int *dilations,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha,
const DataType dt,
Tensor *output,
StatsFuture *future) {
......@@ -69,9 +68,6 @@ void DepthwiseConv2d(cl::Kernel *kernel,
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
......@@ -96,7 +92,6 @@ void DepthwiseConv2d(cl::Kernel *kernel,
kernel->setArg(
idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
kernel->setArg(idx++, relux_max_limit);
kernel->setArg(idx++, prelu_alpha);
kernel->setArg(idx++, static_cast<short>(input_height));
kernel->setArg(idx++, static_cast<short>(input_width));
kernel->setArg(idx++, static_cast<short>(input_channel_blocks));
......@@ -140,8 +135,8 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
<< " is not implemented yet, using slow version";
// TODO(heliangliang) The CPU/NEON kernel should map the buffer
DepthwiseConv2dFunctor<DeviceType::CPU, float>(
strides_, padding_type_, paddings_, dilations_, activation_, relux_max_limit_,
prelu_alpha_)(input, filter, bias, output, future);
strides_, padding_type_, paddings_, dilations_, activation_,
relux_max_limit_)(input, filter, bias, output, future);
return;
}
......@@ -154,11 +149,14 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
}
std::vector<size_t> output_image_shape;
......@@ -166,7 +164,7 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
output->ResizeImage(output_shape, output_image_shape);
DepthwiseConv2d(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_,
activation_, relux_max_limit_, prelu_alpha_,
activation_, relux_max_limit_,
DataTypeToEnum<T>::value, output, future);
}
......
......@@ -48,9 +48,6 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
case RELUX:
built_options.emplace("-DUSE_RELUX");
break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH:
built_options.emplace("-DUSE_TANH");
break;
......@@ -78,7 +75,6 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
kernel_.setArg(idx++, static_cast<int>(input->dim(3)));
// FIXME handle flexable data type: half not supported
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
}
const uint32_t gws[2] = {
......
......@@ -24,12 +24,14 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
};
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter_shape.data(),
dilations_, strides_, this->padding_type_,
output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter_shape.data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::CEIL, output_shape.data());
}
std::vector<size_t> output_image_shape;
......
......@@ -18,11 +18,14 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {3, 3, input_tensor->dim(3), 1};
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_.data(),
strides_.data(), padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_.data(), strides_.data(),
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(),
dilations_.data(), strides_.data(), RoundType::FLOOR, output_shape.data());
}
const index_t round_h = (output_shape[1] + 1) / 2;
......@@ -126,7 +129,6 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
kernel_.setArg(idx++, static_cast<uint32_t>(round_w));
kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
}
const uint32_t gws[2] = {static_cast<uint32_t>(input_tensor->dim(2)),
......
......@@ -65,12 +65,14 @@ struct PoolingFunctor : PoolingFunctorBase {
};
std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(),
dilations_, strides_, this->padding_type_,
output_shape.data(), paddings.data());
if (!paddings_.empty()) {
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_, strides_,
padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::CEIL, output_shape.data());
}
output_tensor->Resize(output_shape);
......
......@@ -58,21 +58,18 @@ struct WinogradInverseTransformFunctorBase {
const int height,
const int width,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
const float relux_max_limit)
: batch_(batch),
height_(height),
width_(width),
activation_(activation),
relux_max_limit_(relux_max_limit),
prelu_alpha_(prelu_alpha) {}
relux_max_limit_(relux_max_limit) {}
const int batch_;
const int height_;
const int width_;
const ActivationType activation_;
const float relux_max_limit_;
const float prelu_alpha_;
};
template<DeviceType D, typename T>
......@@ -81,9 +78,8 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
const int height,
const int width,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {}
const float relux_max_limit)
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *bias,
......@@ -100,9 +96,8 @@ struct WinogradInverseTransformFunctor<DeviceType::OPENCL, T> : WinogradInverseT
const int height,
const int width,
const ActivationType activation,
const float relux_max_limit,
const float prelu_alpha)
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {}
const float relux_max_limit)
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *bias,
......
......@@ -18,15 +18,15 @@ class ActivationOp : public Operator<D, T> {
functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->inputs_[0];
const Tensor *input_tensor = this->Input(0);
const Tensor *alpha_tensor = this->InputSize() >= 2 ? this->Input(1) : nullptr;
Tensor *output_tensor = this->outputs_[0];
output_tensor->ResizeLike(input_tensor);
functor_(input_tensor, output_tensor, future);
functor_(input_tensor, alpha_tensor, output_tensor, future);
return true;
}
......
......@@ -213,17 +213,22 @@ void TestSimplePrelu() {
// Add input data
net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0});
{-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0});
net.AddInputFromArray<D, float>(
"Alpha", {2},
{2.0, 3.0});
if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(net, "Alpha", "AlphaImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("Activation", "PreluTest")
.Input("InputImage")
.Input("AlphaImage")
.Output("OutputImage")
.AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef());
// Run
......@@ -235,9 +240,9 @@ void TestSimplePrelu() {
} else {
OpDefBuilder("Activation", "PreluTest")
.Input("Input")
.Input("Alpha")
.Output("Output")
.AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef());
// Run
......@@ -245,7 +250,7 @@ void TestSimplePrelu() {
}
auto expected = CreateTensor<float>(
{2, 2, 2, 2}, {-14, 7, -12, 6, -10, 5, -8, 4, -6, 3, -4, 2, -2, 1, 0, 0});
{2, 2, 2, 2}, {-14, 7, -12, 6, -10, -15, -8, -12, -6, 3, -4, 2, -2, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
......
......@@ -16,7 +16,7 @@ class BatchNormOp : public Operator<D, T> {
public:
BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(false, kernels::ActivationType::NOOP, 0.0f, 0.0f) {
functor_(false, kernels::ActivationType::NOOP, 0.0f) {
epsilon_ = OperatorBase::GetSingleArgument<float>("epsilon",
static_cast<float>(1e-4));
}
......
......@@ -23,7 +23,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_,
this->dilations_.data(),
kernels::ActivationType::NOOP,
0.0f,
0.0f) {}
bool Run(StatsFuture *future) override {
......
......@@ -342,7 +342,7 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { TestConv1x1<DeviceType::CPU>(); }
TEST_F(Conv2dOpTest, OPENCLConv1x1) { TestConv1x1<DeviceType::OPENCL>(); }
template <DeviceType D, typename T>
static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
static void TestComplexConvNxNS12(const std::vector<index_t> &shape, const int stride) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
......@@ -405,20 +405,31 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
};
for (int kernel_size : {1, 3}) {
for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME);
}
for (int kernel_size : {1, 3, 7}) {
func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME);
}
}
TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 32, 64});
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 16, 16, 32},
1);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 16, 16, 32},
2);
}
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({107, 113, 5, 7});
TestComplexConvNxNS12<DeviceType::OPENCL, float>({17, 113, 5, 7},
1);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({17, 113, 5, 7},
2);
}
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS34) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({31, 113, 13, 17},
3);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 13, 17},
4);
}
template<DeviceType D>
......@@ -650,3 +661,81 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedDilation4) {
4);
}
template<DeviceType D, typename T>
static void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, const std::vector<int> &paddings) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w) {
srand(time(NULL));
// generate random input
index_t batch = 1;
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2];
index_t output_channels = shape[3];
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// run on gpu
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(D);
ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
};
for (int kernel_size : {3, 5}) {
for (int stride : {2, 3}) {
func(kernel_size, kernel_size, stride, stride);
}
}
}
TEST_F(Conv2dOpTest, OPENCLAlignedPad1) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({32, 32, 32, 64},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLAlignedPad2) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({128, 128, 16, 16},
{2, 2});
}
TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({107, 113, 5, 7},
{4, 4});
}
......@@ -26,8 +26,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -19,8 +19,7 @@ class FoldedBatchNormOp : public Operator<D, T> {
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -19,8 +19,7 @@ class FullyConnectedOp : public Operator<D, T> {
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -25,8 +25,7 @@ class FusedConv2dOp : public ConvPool2dOpBase<D, T> {
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -24,8 +24,7 @@ class WinogradInverseTransformOp : public Operator<D, T> {
kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f),
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
......
......@@ -5,10 +5,14 @@ import functools
import argparse
import sys
import six
import os.path
FLAGS = None
def main(unused_args):
if not os.path.isfile(FLAGS.input):
print 'input model file not exist'
return -1
net = caffe_pb2.NetParameter()
with open(FLAGS.input) as f:
google.protobuf.text_format.Merge(str(f.read()), net)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册