Commit d548b461 authored by Liangliang He

Merge branch 'conv-pad' into 'master'

Fix arbitrary pad and PReLU bug

See merge request !259
@@ -46,8 +46,7 @@ void DoActivation(const T *input_ptr,
 T *output_ptr,
 const index_t size,
 const ActivationType type,
-const float relux_max_limit,
-const float prelu_alpha) {
+const float relux_max_limit) {
 MACE_CHECK(DataTypeToEnum<T>::value != DataType::DT_HALF);
 switch (type) {
@@ -66,17 +65,6 @@ void DoActivation(const T *input_ptr,
 static_cast<T>(relux_max_limit));
 }
 break;
-case PRELU:
-#pragma omp parallel for
-for (index_t i = 0; i < size; ++i) {
-T in = input_ptr[i];
-if (in < 0) {
-output_ptr[i] = in * prelu_alpha;
-} else {
-output_ptr[i] = in;
-}
-}
-break;
 case TANH:
 #pragma omp parallel for
 for (index_t i = 0; i < size; ++i) {
@@ -95,45 +83,70 @@ void DoActivation(const T *input_ptr,
 }
 }
+template <typename T>
+void PReLUActivation(const T *input_ptr,
+const index_t size,
+const index_t input_chan,
+const T *alpha_ptr,
+T *output_ptr) {
+#pragma omp parallel for
+for (index_t i = 0; i < size; ++i) {
+const index_t chan_idx = i % input_chan;
+T in = input_ptr[i];
+if (in < 0) {
+output_ptr[i] = in * alpha_ptr[chan_idx];
+} else {
+output_ptr[i] = in;
+}
+}
+}
 template <DeviceType D, typename T>
 class ActivationFunctor {
 public:
-ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha)
+ActivationFunctor(ActivationType type, T relux_max_limit)
 : activation_(type),
-relux_max_limit_(relux_max_limit),
-prelu_alpha_(prelu_alpha) {}
+relux_max_limit_(relux_max_limit){}
-void operator()(const Tensor *input, Tensor *output, StatsFuture *future) {
+void operator()(const Tensor *input,
+const Tensor *alpha,
+Tensor *output,
+StatsFuture *future) {
 const T *input_ptr = input->data<T>();
 T *output_ptr = output->mutable_data<T>();
-DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_,
-prelu_alpha_);
+if (activation_ == PRELU) {
+const T *alpha_ptr = alpha == nullptr ? nullptr : alpha->data<T>();
+PReLUActivation(input_ptr, output->size(), input->dim(3), alpha_ptr, output_ptr);
+} else {
+DoActivation(input_ptr, output_ptr, output->size(), activation_, relux_max_limit_);
+}
 }
 private:
 ActivationType activation_;
 T relux_max_limit_;
-T prelu_alpha_;
 };
 template <>
 void ActivationFunctor<DeviceType::NEON, float>::operator()(
-const Tensor *input, Tensor *output, StatsFuture *future);
+const Tensor *input, const Tensor *alpha, Tensor *output, StatsFuture *future);
 template <typename T>
 class ActivationFunctor<DeviceType::OPENCL, T> {
 public:
-ActivationFunctor(ActivationType type, T relux_max_limit, T prelu_alpha)
+ActivationFunctor(ActivationType type, T relux_max_limit)
 : activation_(type),
-relux_max_limit_(relux_max_limit),
-prelu_alpha_(prelu_alpha) {}
+relux_max_limit_(relux_max_limit){}
-void operator()(const Tensor *input, Tensor *output, StatsFuture *future);
+void operator()(const Tensor *input,
+const Tensor *alpha,
+Tensor *output,
+StatsFuture *future);
 private:
 ActivationType activation_;
 T relux_max_limit_;
-T prelu_alpha_;
 cl::Kernel kernel_;
 std::string tuning_key_prefix_;
 };
......
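For reference, the per-channel PReLU introduced above computes f(x) = x for x >= 0 and f(x) = alpha_c * x otherwise, where c is the element's channel; because the buffers are NHWC-contiguous, element i belongs to channel i % channels, which is exactly the chan_idx used by PReLUActivation. A minimal standalone sketch of the same arithmetic (not part of this commit; the function name PreluNHWC is made up for illustration):

#include <cstddef>
#include <vector>

// Per-channel PReLU over an NHWC-contiguous buffer: channels are the
// innermost dimension, so element i maps to channel i % channels.
std::vector<float> PreluNHWC(const std::vector<float> &in,
                             const std::vector<float> &alpha,  // one slope per channel
                             std::size_t channels) {
  std::vector<float> out(in.size());
  for (std::size_t i = 0; i < in.size(); ++i) {
    const float x = in[i];
    out[i] = x >= 0.0f ? x : x * alpha[i % channels];
  }
  return out;
}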
@@ -21,27 +21,23 @@ namespace kernels {
 struct BatchNormFunctorBase {
 BatchNormFunctorBase(bool folded_constant,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : folded_constant_(folded_constant),
 activation_(activation),
-relux_max_limit_(relux_max_limit),
-prelu_alpha_(prelu_alpha) {}
+relux_max_limit_(relux_max_limit){}
 const bool folded_constant_;
 const ActivationType activation_;
 const float relux_max_limit_;
-const float prelu_alpha_;
 };
 template <DeviceType D, typename T>
 struct BatchNormFunctor : BatchNormFunctorBase {
 BatchNormFunctor(const bool folded_constant,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : BatchNormFunctorBase(
-folded_constant, activation, relux_max_limit, prelu_alpha) {}
+folded_constant, activation, relux_max_limit) {}
 void operator()(const Tensor *input,
 const Tensor *scale,
@@ -132,7 +128,7 @@ struct BatchNormFunctor : BatchNormFunctorBase {
 }
 }
 DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
-relux_max_limit_, prelu_alpha_);
+relux_max_limit_);
 }
 };
@@ -150,10 +146,9 @@ template <typename T>
 struct BatchNormFunctor<DeviceType::OPENCL, T> : BatchNormFunctorBase {
 BatchNormFunctor(const bool folded_constant,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : BatchNormFunctorBase(
-folded_constant, activation, relux_max_limit, prelu_alpha) {}
+folded_constant, activation, relux_max_limit) {}
 void operator()(const Tensor *input,
 const Tensor *scale,
 const Tensor *offset,
......
@@ -182,15 +182,13 @@ struct Conv2dFunctorBase {
 const std::vector<int> &paddings,
 const int *dilations,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : strides_(strides),
 padding_type_(padding_type),
 paddings_(paddings),
 dilations_(dilations),
 activation_(activation),
-relux_max_limit_(relux_max_limit),
-prelu_alpha_(prelu_alpha) {}
+relux_max_limit_(relux_max_limit){}
 const int *strides_; // [stride_h, stride_w]
 const Padding padding_type_;
@@ -198,7 +196,6 @@ struct Conv2dFunctorBase {
 const int *dilations_; // [dilation_h, dilation_w]
 const ActivationType activation_;
 const float relux_max_limit_;
-const float prelu_alpha_;
 };
 template <DeviceType D, typename T>
@@ -208,15 +205,13 @@ struct Conv2dFunctor : Conv2dFunctorBase {
 const std::vector<int> &paddings,
 const int *dilations,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : Conv2dFunctorBase(strides,
 padding_type,
 paddings,
 dilations,
 activation,
-relux_max_limit,
-prelu_alpha) {}
+relux_max_limit) {}
 void operator()(const Tensor *input, // NHWC
 const Tensor *filter, // HWOI
@@ -229,11 +224,14 @@ struct Conv2dFunctor : Conv2dFunctorBase {
 std::vector<index_t> output_shape(4);
 std::vector<int> paddings(2);
-kernels::CalcNHWCPaddingAndOutputSize(
-input->shape().data(), filter->shape().data(), dilations_, strides_,
-padding_type_, output_shape.data(), paddings.data());
-if (!paddings_.empty()) {
+if (paddings_.empty()) {
+kernels::CalcNHWCPaddingAndOutputSize(
+input->shape().data(), filter->shape().data(), dilations_, strides_,
+padding_type_, output_shape.data(), paddings.data());
+} else {
 paddings = paddings_;
+CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(),
+dilations_, strides_, RoundType::FLOOR, output_shape.data());
 }
 output->Resize(output_shape);
@@ -619,7 +617,7 @@ struct Conv2dFunctor : Conv2dFunctorBase {
 }
 }
 DoActivation(output_data, output_data, output->NumElements(), activation_,
-relux_max_limit_, prelu_alpha_);
+relux_max_limit_);
 }
 };
@@ -637,15 +635,13 @@ struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
 const std::vector<int> &paddings,
 const int *dilations,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : Conv2dFunctorBase(strides,
 padding_type,
 paddings,
 dilations,
 activation,
-relux_max_limit,
-prelu_alpha) {}
+relux_max_limit) {}
 void operator()(const Tensor *input,
 const Tensor *filter,
......
@@ -135,6 +135,44 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
 output_shape[3] = output_channels;
 }
+void CalcOutputSize(const index_t *input_shape, // NHWC
+const index_t *filter_shape, // HWOI
+const int *padding_size,
+const int *dilations,
+const int *strides,
+const RoundType round_type,
+index_t *output_shape) {
+MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
+"Invalid dilations, must >= 1");
+MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
+(dilations[1] == 1 || strides[1] == 1),
+"If dilations > 1, strides should be 1");
+MACE_CHECK_NOTNULL(output_shape);
+MACE_CHECK_NOTNULL(padding_size);
+/*
+* Convlution arithmetic:
+* o = floor((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
+* Pooling arithmetic:
+* o = ceil((i + 2 * p - k - (k - 1) * (d - 1)) / s) + 1
+* For details, see https://arxiv.org/pdf/1603.07285.pdf or
+* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
+*/
+output_shape[0] = input_shape[0];
+if (round_type == FLOOR) {
+output_shape[1] = static_cast<index_t>(std::floor(1.0 * (input_shape[1] + padding_size[0]
+- filter_shape[0] - (filter_shape[0] - 1) * (dilations[0] - 1)) / strides[0]) + 1);
+output_shape[2] = static_cast<index_t>(std::floor(1.0 * (input_shape[2] + padding_size[1]
+- filter_shape[1] - (filter_shape[1] - 1) * (dilations[1] - 1)) / strides[1]) + 1);
+} else {
+output_shape[1] = static_cast<index_t>(std::ceil(1.0 * (input_shape[1] + padding_size[0]
+- filter_shape[0] - (filter_shape[0] - 1) * (dilations[0] - 1)) / strides[0]) + 1);
+output_shape[2] = static_cast<index_t>(std::ceil(1.0 * (input_shape[2] + padding_size[1]
+- filter_shape[1] - (filter_shape[1] - 1) * (dilations[1] - 1)) / strides[1]) + 1);
+}
+output_shape[3] = filter_shape[2];
+}
 void CalPaddingSize(const index_t *input_shape, // NCHW
 const index_t *filter_shape, // OIHW
 const int *dilations,
......
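As a worked example of the output-size arithmetic documented inside CalcOutputSize: the code consumes padding_size[] directly (i.e. as the total padding along a dimension), while the comment writes it as 2 * p. A small standalone sketch of the same formula (an illustration, not MACE code; OutputDim is a made-up name), showing why convolution uses FLOOR and pooling uses CEIL:

#include <cmath>
#include <cstdio>

// o = round((i + pad_total - k - (k - 1) * (d - 1)) / s) + 1,
// with round = floor for convolution and ceil for pooling.
int OutputDim(int in, int k, int pad_total, int dilation, int stride, bool use_floor) {
  const double o = 1.0 * (in + pad_total - k - (k - 1) * (dilation - 1)) / stride;
  return static_cast<int>(use_floor ? std::floor(o) : std::ceil(o)) + 1;
}

int main() {
  // 32-wide input, 3x3 filter, total padding 2, dilation 1, stride 2:
  std::printf("conv (FLOOR): %d\n", OutputDim(32, 3, 2, 1, 2, true));   // 16
  std::printf("pool (CEIL):  %d\n", OutputDim(32, 3, 2, 1, 2, false));  // 17
  return 0;
}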
@@ -15,6 +15,11 @@ enum Padding {
 FULL = 2, // Pads with one less than the filter size on both sides
 };
+enum RoundType{
+FLOOR = 0,
+CEIL = 1,
+};
 namespace kernels {
 void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
@@ -33,6 +38,14 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NCHW
 index_t *output_shape,
 int *padding_size);
+void CalcOutputSize(const index_t *input_shape, // NHWC
+const index_t *filter_shape, // HWOI
+const int *padding_size,
+const int *dilations,
+const int *strides,
+const RoundType round_type,
+index_t *output_shape);
 void CalPaddingSize(const index_t *input_shape, // NCHW
 const index_t *filter_shape, // OIHW
 const int *dilations,
......
@@ -241,15 +241,13 @@ struct DepthwiseConv2dFunctorBase {
 const std::vector<int> &paddings,
 const int *dilations,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : strides_(strides),
 padding_type_(padding_type),
 paddings_(paddings),
 dilations_(dilations),
 activation_(activation),
-relux_max_limit_(relux_max_limit),
-prelu_alpha_(prelu_alpha) {}
+relux_max_limit_(relux_max_limit){}
 const int *strides_; // [stride_h, stride_w]
 const Padding padding_type_;
@@ -257,7 +255,6 @@ struct DepthwiseConv2dFunctorBase {
 const int *dilations_; // [dilation_h, dilation_w]
 const ActivationType activation_;
 const float relux_max_limit_;
-const float prelu_alpha_;
 };
 template <DeviceType D, typename T>
@@ -267,15 +264,13 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
 const std::vector<int> &paddings,
 const int *dilations,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : DepthwiseConv2dFunctorBase(strides,
 padding_type,
 paddings,
 dilations,
 activation,
-relux_max_limit,
-prelu_alpha) {}
+relux_max_limit) {}
 void operator()(const Tensor *input, // NHWC
 const Tensor *filter, // HWIM
@@ -295,11 +290,14 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
 std::vector<index_t> output_shape(4);
 std::vector<int> paddings(2);
-kernels::CalcNHWCPaddingAndOutputSize(
-input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
-padding_type_, output_shape.data(), paddings.data());
-if (!paddings_.empty()) {
+if (paddings_.empty()) {
+kernels::CalcNHWCPaddingAndOutputSize(
+input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
+padding_type_, output_shape.data(), paddings.data());
+} else {
 paddings = paddings_;
+CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(),
+dilations_, strides_, RoundType::FLOOR, output_shape.data());
 }
 auto input_shape = fake_filter_shape;
 output->Resize(output_shape);
@@ -405,7 +403,7 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
 output_ptr = output->mutable_data<T>();
 DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
-relux_max_limit_, prelu_alpha_);
+relux_max_limit_);
 }
 };
@@ -425,15 +423,13 @@ struct DepthwiseConv2dFunctor<DeviceType::OPENCL, T>
 const std::vector<int> &paddings,
 const int *dilations,
 const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : DepthwiseConv2dFunctorBase(strides,
 padding_type,
 paddings,
 dilations,
 activation,
-relux_max_limit,
-prelu_alpha) {}
+relux_max_limit) {}
 void operator()(const Tensor *input,
 const Tensor *filter,
......
@@ -15,23 +15,19 @@ namespace kernels {
 struct FullyConnectedBase {
 FullyConnectedBase(const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha)
+const float relux_max_limit)
 : activation_(activation),
-relux_max_limit_(relux_max_limit),
-prelu_alpha_(prelu_alpha) {}
+relux_max_limit_(relux_max_limit){}
 const ActivationType activation_;
 const float relux_max_limit_;
-const float prelu_alpha_;
 };
 template<DeviceType D, typename T>
 struct FullyConnectedFunctor : FullyConnectedBase {
 FullyConnectedFunctor(const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha) :
-FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {}
+const float relux_max_limit) :
+FullyConnectedBase(activation, relux_max_limit) {}
 void operator()(const Tensor *input,
 const Tensor *weight,
@@ -70,16 +66,15 @@ struct FullyConnectedFunctor : FullyConnectedBase {
 }
 DoActivation(output_ptr, output_ptr, output->NumElements(), activation_,
-relux_max_limit_, prelu_alpha_);
+relux_max_limit_);
 }
 };
 template<typename T>
 struct FullyConnectedFunctor<DeviceType::OPENCL, T> : FullyConnectedBase {
 FullyConnectedFunctor(const ActivationType activation,
-const float relux_max_limit,
-const float prelu_alpha) :
-FullyConnectedBase(activation, relux_max_limit, prelu_alpha) {}
+const float relux_max_limit) :
+FullyConnectedBase(activation, relux_max_limit) {}
 void operator()(const Tensor *input,
 const Tensor *weight,
......
@@ -14,6 +14,7 @@ namespace kernels {
 template <typename T>
 void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
+const Tensor *alpha,
 Tensor *output,
 StatsFuture *future) {
 const index_t batch = input->dim(0);
@@ -60,8 +61,10 @@ void ActivationFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 runtime->BuildKernel("activation", kernel_name, built_options);
 int idx = 0;
 kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(input->buffer())));
+if (activation_ == PRELU) {
+kernel_.setArg(idx++, *(static_cast<const cl::Image2D *>(alpha->buffer())));
+}
 kernel_.setArg(idx++, static_cast<float>(relux_max_limit_));
-kernel_.setArg(idx++, static_cast<float>(prelu_alpha_));
 kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
 }
......
@@ -50,9 +50,6 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 case RELUX:
 built_options.emplace("-DUSE_RELUX");
 break;
-case PRELU:
-built_options.emplace("-DUSE_PRELU");
-break;
 case TANH:
 built_options.emplace("-DUSE_TANH");
 break;
@@ -79,7 +76,6 @@ void BatchNormFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 }
 kernel_.setArg(idx++, *(static_cast<cl::Image2D *>(output->buffer())));
 kernel_.setArg(idx++, relux_max_limit_);
-kernel_.setArg(idx++, prelu_alpha_);
 }
 const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......
 #include <common.h>
 __kernel void activation(__read_only image2d_t input,
+#ifdef USE_PRELU
+__read_only image2d_t alpha,
+#endif
 __private const float relux_max_limit,
-__private const float prelu_alpha,
 __write_only image2d_t output) {
 const int ch_blk = get_global_id(0);
 const int w = get_global_id(1);
@@ -11,7 +13,12 @@ __kernel void activation(__read_only image2d_t input,
 const int pos = mad24(ch_blk, width, w);
 DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
-DATA_TYPE4 out = do_activation(in, relux_max_limit, prelu_alpha);
+#ifdef USE_PRELU
+DATA_TYPE4 prelu_alpha = READ_IMAGET(alpha, SAMPLER, (int2)(ch_blk, 0));
+DATA_TYPE4 out = do_activation(in, prelu_alpha, relux_max_limit);
+#else
+DATA_TYPE4 out = do_activation(in, relux_max_limit);
+#endif
 WRITE_IMAGET(output, (int2)(pos, hb), out);
 }
@@ -9,8 +9,7 @@ __kernel void batch_norm(__read_only image2d_t input,
 __private const float epsilon,
 #endif
 __write_only image2d_t output,
-__private const float relux_max_limit,
-__private const float prelu_alpha) {
+__private const float relux_max_limit) {
 const int ch_blk = get_global_id(0);
 const int w = get_global_id(1);
 const int hb = get_global_id(2);
@@ -35,8 +34,8 @@ __kernel void batch_norm(__read_only image2d_t input,
 DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(pos, hb));
 DATA_TYPE4 out = mad(in, bn_scale, bn_offset);
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-out = do_activation(out, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+out = do_activation(out, relux_max_limit);
 #endif
 WRITE_IMAGET(output, (int2)(pos, hb), out);
......
@@ -22,8 +22,10 @@ __constant sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP |
 inline DATA_TYPE4 do_activation(DATA_TYPE4 in,
-__private const float relux_max_limit,
-__private const float prelu_alpha) {
+#ifdef USE_PRELU
+DATA_TYPE4 prelu_alpha,
+#endif
+__private const float relux_max_limit) {
 DATA_TYPE4 out;
 #ifdef USE_RELU
 out = fmax(in, 0);
......
@@ -7,7 +7,6 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
 #endif
 __write_only image2d_t output,
 __private const float relux_max_limit,
-__private const float prelu_alpha,
 __private const int in_height,
 __private const int in_width,
 __private const int in_ch_blks,
@@ -112,11 +111,11 @@ __kernel void conv_2d(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
 }
 }
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-out0 = do_activation(out0, relux_max_limit, prelu_alpha);
-out1 = do_activation(out1, relux_max_limit, prelu_alpha);
-out2 = do_activation(out2, relux_max_limit, prelu_alpha);
-out3 = do_activation(out3, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+out0 = do_activation(out0, relux_max_limit);
+out1 = do_activation(out1, relux_max_limit);
+out2 = do_activation(out2, relux_max_limit);
+out3 = do_activation(out3, relux_max_limit);
 #endif
 const int out_x_base = mul24(out_ch_blk, out_width);
......
@@ -7,12 +7,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 #endif
 __write_only image2d_t output,
 __private const float relux_max_limit,
-__private const float prelu_alpha,
 __private const int in_height,
 __private const int in_width,
 __private const int in_ch_blks,
 __private const int height,
-__private const int width) {
+__private const int width,
+__private const int stride) {
 const int out_ch_blk = get_global_id(0);
 const int out_w_blk = get_global_id(1);
 const int out_w_blks = get_global_size(1);
@@ -31,19 +31,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 #endif
 int4 w;
-#if STRIDE == 1
-w.x = out_w_blk;
-w.y = w.x + out_w_blks;
-w.z = w.y + out_w_blks;
-w.w = w.z + out_w_blks;
-int out_hb_idx = (out_hb % height);
-#elif STRIDE == 2
-w.x = out_w_blk << 1;
-w.y = (out_w_blk + out_w_blks) << 1;
-w.z = (out_w_blk + (out_w_blks << 1)) << 1;
-w.w = (out_w_blk + (out_w_blks << 1) + out_w_blks) << 1;
-int out_hb_idx = (out_hb % height) << 1;
-#endif
+int in_width_stride = mul24(out_w_blks, stride);
+w.x = mul24(out_w_blk, stride);
+w.y = w.x + in_width_stride;
+w.z = w.y + in_width_stride;
+w.w = w.z + in_width_stride;
+int out_hb_idx = mul24((out_hb % height), stride);
 w.x = select(w.x, INT_MIN, w.x >= in_width);
 w.y = select(w.y, INT_MIN, w.y >= in_width);
@@ -92,11 +85,11 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 filter_x_base += 4;
 }
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-out0 = do_activation(out0, relux_max_limit, prelu_alpha);
-out1 = do_activation(out1, relux_max_limit, prelu_alpha);
-out2 = do_activation(out2, relux_max_limit, prelu_alpha);
-out3 = do_activation(out3, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+out0 = do_activation(out0, relux_max_limit);
+out1 = do_activation(out1, relux_max_limit);
+out2 = do_activation(out2, relux_max_limit);
+out3 = do_activation(out3, relux_max_limit);
 #endif
 const int out_x_base = mul24(out_ch_blk, width);
......
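The conv_2d_1x1 change above (and the conv_2d_3x3 change below) replaces the compile-time STRIDE == 1 / STRIDE == 2 branches with a runtime stride kernel argument, which is what lets the host-side selector drop its stride restriction later in this merge request. A small sketch of the equivalent index arithmetic (illustration only, not MACE code; InputCols is a made-up name): each work-item still handles four output columns, reading input columns starting at out_w_blk * stride and stepping by out_w_blks * stride, which reduces to the old special cases when stride is 1 or 2.

#include <array>

// Input column indices for the four outputs computed by one work-item.
std::array<int, 4> InputCols(int out_w_blk, int out_w_blks, int stride) {
  const int step = out_w_blks * stride;   // was out_w_blks (stride 1) or out_w_blks << 1 (stride 2)
  int x = out_w_blk * stride;             // was out_w_blk or out_w_blk << 1
  std::array<int, 4> cols{};
  for (int i = 0; i < 4; ++i, x += step) cols[i] = x;
  return cols;
}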
@@ -7,12 +7,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 #endif
 __write_only image2d_t output,
 __private const float relux_max_limit,
-__private const float prelu_alpha,
 __private const int in_height,
 __private const int in_width,
 __private const int in_ch_blks,
 __private const int out_height,
 __private const int out_width,
+__private const int stride,
 __private const int padding_top,
 __private const int padding_left,
 __private const int dilation_h,
@@ -38,21 +38,13 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 DATA_TYPE4 out4 = 0;
 #endif
-#if STRIDE == 1
-int in_width0 = out_w_blk - padding_left;
-int in_width1 = in_width0 + out_w_blks;
-int in_width2 = in_width1 + out_w_blks;
-int in_width3 = in_width2 + out_w_blks;
-int in_width4 = in_width3 + out_w_blks;
-const int height_idx = (out_hb % out_height) - padding_top;
-#elif STRIDE == 2
-int in_width0 = (out_w_blk << 1) - padding_left;
-int in_width1 = ((out_w_blk + out_w_blks) << 1) - padding_left;
-int in_width2 = ((out_w_blk + (out_w_blks << 1)) << 1) - padding_left;
-int in_width3 = ((out_w_blk + (out_w_blks << 1) + out_w_blks) << 1) - padding_left;
-int in_width4 = ((out_w_blk + (out_w_blks << 2)) << 1) - padding_left;
-const int height_idx = ((out_hb % out_height) << 1) - padding_top;
-#endif
+int in_width_stride = mul24(out_w_blks, stride);
+int in_width0 = mad24(out_w_blk, stride, -padding_left);
+int in_width1 = in_width0 + in_width_stride;
+int in_width2 = in_width1 + in_width_stride;
+int in_width3 = in_width2 + in_width_stride;
+int in_width4 = in_width3 + in_width_stride;
+const int height_idx = mad24((out_hb % out_height), stride, -padding_top);
 const int batch_idx = mul24((out_hb / out_height), in_height);
 const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch;
@@ -127,12 +119,12 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
 }
 }
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-out0 = do_activation(out0, relux_max_limit, prelu_alpha);
-out1 = do_activation(out1, relux_max_limit, prelu_alpha);
-out2 = do_activation(out2, relux_max_limit, prelu_alpha);
-out3 = do_activation(out3, relux_max_limit, prelu_alpha);
-out4 = do_activation(out4, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+out0 = do_activation(out0, relux_max_limit);
+out1 = do_activation(out1, relux_max_limit);
+out2 = do_activation(out2, relux_max_limit);
+out3 = do_activation(out3, relux_max_limit);
+out4 = do_activation(out4, relux_max_limit);
 #endif
 const int out_x_base = mul24(out_ch_blk, out_width);
......
@@ -8,7 +8,6 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
 #endif
 __write_only image2d_t output,
 __private const float relux_max_limit,
-__private const float prelu_alpha,
 __private const short in_height,
 __private const short in_width,
 __private const short in_ch_blks,
@@ -103,11 +102,11 @@ __kernel void depthwise_conv2d(__read_only image2d_t input, /* [c%4 * w * c/4, h
 in_hb_idx += dilation_h;
 }
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-out0 = do_activation(out0, relux_max_limit, prelu_alpha);
-out1 = do_activation(out1, relux_max_limit, prelu_alpha);
-out2 = do_activation(out2, relux_max_limit, prelu_alpha);
-out3 = do_activation(out3, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+out0 = do_activation(out0, relux_max_limit);
+out1 = do_activation(out1, relux_max_limit);
+out2 = do_activation(out2, relux_max_limit);
+out3 = do_activation(out3, relux_max_limit);
 #endif
 const short out_x_base = mul24(out_ch_blk, out_width);
@@ -134,7 +133,6 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4
 #endif
 __write_only image2d_t output,
 __private const DATA_TYPE relux_max_limit,
-__private const DATA_TYPE prelu_alpha,
 __private const short in_height,
 __private const short in_width,
 __private const short in_ch_blks,
@@ -220,11 +218,11 @@ __kernel void depthwise_conv2d_s1(__read_only image2d_t input, /* [c%4 * w * c/4
 in_hb_idx += 1;
 }
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-out0 = do_activation(out0, relux_max_limit, prelu_alpha);
-out1 = do_activation(out1, relux_max_limit, prelu_alpha);
-out2 = do_activation(out2, relux_max_limit, prelu_alpha);
-out3 = do_activation(out3, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+out0 = do_activation(out0, relux_max_limit);
+out1 = do_activation(out1, relux_max_limit);
+out2 = do_activation(out2, relux_max_limit);
+out3 = do_activation(out3, relux_max_limit);
 #endif
 const short out_x_base = mul24(out_ch_blk, out_width);
......
@@ -10,8 +10,7 @@ __kernel void fully_connected(__read_only image2d_t input,
 __private const int input_height,
 __private const int input_width,
 __private const int input_channel,
-__private const float relux_max_limit,
-__private const float prelu_alpha) {
+__private const float relux_max_limit) {
 const int batch_idx = get_global_id(0);
 const int out_blk_idx = get_global_id(1);
 const int input_chan_blk = (input_channel + 3) >> 2;
@@ -51,8 +50,8 @@ __kernel void fully_connected(__read_only image2d_t input,
 input_coord.y++;
 }
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-result = do_activation(result, relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+result = do_activation(result, relux_max_limit);
 #endif
 WRITE_IMAGET(output, (int2)(out_blk_idx, batch_idx), result);
 }
@@ -115,8 +115,7 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
 __private const int out_width,
 __private const int round_hw,
 __private const int round_w,
-__private const float relux_max_limit,
-__private const float prelu_alpha) {
+__private const float relux_max_limit) {
 const int width_idx = get_global_id(0);
 const int height_idx = get_global_id(1);
 const int out_channel = get_global_size(1);
@@ -183,11 +182,11 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
 #endif
-#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_PRELU) || defined(USE_TANH) || defined(USE_SIGMOID)
-in0[0] = do_activation(in0[0], relux_max_limit, prelu_alpha);
-in0[1] = do_activation(in0[1], relux_max_limit, prelu_alpha);
-in1[0] = do_activation(in1[0], relux_max_limit, prelu_alpha);
-in1[1] = do_activation(in1[1], relux_max_limit, prelu_alpha);
+#if defined(USE_RELU) || defined(USE_RELUX) || defined(USE_TANH) || defined(USE_SIGMOID)
+in0[0] = do_activation(in0[0], relux_max_limit);
+in0[1] = do_activation(in0[1], relux_max_limit);
+in1[0] = do_activation(in1[0], relux_max_limit);
+in1[1] = do_activation(in1[1], relux_max_limit);
 #endif
 WRITE_IMAGET(output, (int2)(coord_x, coord_y), in0[0]);
@@ -205,6 +204,4 @@ __kernel void winograd_inverse_transform_2x2(__read_only image2d_t input,
 WRITE_IMAGET(output, (int2)(coord_x + 1, coord_y + 1), in1[1]);
 }
 }
@@ -17,7 +17,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future);
@@ -31,7 +30,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future);
@@ -45,7 +43,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future);
@@ -60,7 +57,7 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 cl::Kernel *kernel,
 const Tensor *input, const Tensor *filter, const Tensor *bias, const int stride,
 const int *padding, const int *dilations, const ActivationType activation,
-const float relux_max_limit, const float prelu_alpha, const DataType dt,
+const float relux_max_limit, const DataType dt,
 Tensor *output, StatsFuture *future);
 // Selection matrix: kernel_size x stride_size
 static const Conv2dOpenclFunction selector[5] =
@@ -69,7 +66,6 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 index_t kernel_h = filter->dim(0);
 index_t kernel_w = filter->dim(1);
 if (!input->is_image() || strides_[0] != strides_[1] ||
-((kernel_h == 1 || kernel_h == 3) && strides_[0] > 2) ||
 (dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) {
 LOG(WARNING) << "OpenCL conv2d kernel with "
 << "filter" << kernel_h << "x" << kernel_w << ","
@@ -82,11 +78,14 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 std::vector<index_t> output_shape(4);
 std::vector<int> paddings(2);
-kernels::CalcNHWCPaddingAndOutputSize(
-input->shape().data(), filter->shape().data(), dilations_, strides_,
-padding_type_, output_shape.data(), paddings.data());
-if (!paddings_.empty()) {
+if (paddings_.empty()) {
+kernels::CalcNHWCPaddingAndOutputSize(
+input->shape().data(), filter->shape().data(), dilations_, strides_,
+padding_type_, output_shape.data(), paddings.data());
+} else {
 paddings = paddings_;
+CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(),
+dilations_, strides_, RoundType::FLOOR, output_shape.data());
 }
 std::vector<size_t> output_image_shape;
@@ -94,16 +93,13 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
 output->ResizeImage(output_shape, output_image_shape);
 if (kernel_h == kernel_w && kernel_h <= 5 &&
-selector[kernel_h - 1] != nullptr &&
-0 < strides_[0] && strides_[0] < 3 ) {
+selector[kernel_h - 1] != nullptr) {
 auto conv2d_func = selector[kernel_h - 1];
 conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_,
-relux_max_limit_, prelu_alpha_, DataTypeToEnum<T>::value,
-output, future);
+relux_max_limit_, DataTypeToEnum<T>::value, output, future);
 } else {
 Conv2dOpencl(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_,
-activation_, relux_max_limit_, prelu_alpha_,
-DataTypeToEnum<T>::value, output, future);
+activation_, relux_max_limit_, DataTypeToEnum<T>::value, output, future);
 }
 }
......
@@ -19,7 +19,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future) {
@@ -44,7 +43,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
 built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
 built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
 built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
-built_options.emplace(MakeString("-DSTRIDE=", stride));
 if (bias != nullptr) {
 built_options.emplace("-DBIAS");
 }
@@ -57,9 +55,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
 case RELUX:
 built_options.emplace("-DUSE_RELUX");
 break;
-case PRELU:
-built_options.emplace("-DUSE_PRELU");
-break;
 case TANH:
 built_options.emplace("-DUSE_TANH");
 break;
@@ -87,12 +82,12 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
 *(static_cast<const cl::Image2D *>(output->buffer())));
 // FIXME handle flexable data type: half not supported
 kernel->setArg(idx++, relux_max_limit);
-kernel->setArg(idx++, prelu_alpha);
 kernel->setArg(idx++, static_cast<int>(input_height));
 kernel->setArg(idx++, static_cast<int>(input_width));
 kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
 kernel->setArg(idx++, static_cast<int>(height));
 kernel->setArg(idx++, static_cast<int>(width));
+kernel->setArg(idx++, stride);
 }
 const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......
@@ -21,7 +21,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future) {
@@ -42,7 +41,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
 built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
 built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
 built_options.emplace(bias != nullptr ? "-DBIAS" : "");
-built_options.emplace(MakeString("-DSTRIDE=", stride));
 switch (activation) {
 case NOOP:
 break;
@@ -52,9 +50,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
 case RELUX:
 built_options.emplace("-DUSE_RELUX");
 break;
-case PRELU:
-built_options.emplace("-DUSE_PRELU");
-break;
 case TANH:
 built_options.emplace("-DUSE_TANH");
 break;
@@ -81,12 +76,12 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
 kernel->setArg(idx++,
 *(static_cast<const cl::Image2D *>(output->buffer())));
 kernel->setArg(idx++, relux_max_limit);
-kernel->setArg(idx++, prelu_alpha);
 kernel->setArg(idx++, static_cast<int>(input->dim(1)));
 kernel->setArg(idx++, static_cast<int>(input->dim(2)));
 kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
 kernel->setArg(idx++, static_cast<int>(height));
 kernel->setArg(idx++, static_cast<int>(width));
+kernel->setArg(idx++, stride);
 kernel->setArg(idx++, padding[0] / 2);
 kernel->setArg(idx++, padding[1] / 2);
 kernel->setArg(idx++, dilations[0]);
......
@@ -21,7 +21,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future) {
@@ -42,7 +41,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
 built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
 built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
 built_options.emplace(bias != nullptr ? "-DBIAS" : "");
-built_options.emplace(MakeString("-DSTRIDE=", stride));
 switch (activation) {
 case NOOP:
 break;
@@ -52,9 +50,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
 case RELUX:
 built_options.emplace("-DUSE_RELUX");
 break;
-case PRELU:
-built_options.emplace("-DUSE_PRELU");
-break;
 case TANH:
 built_options.emplace("-DUSE_TANH");
 break;
@@ -81,7 +76,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
 kernel->setArg(idx++,
 *(static_cast<const cl::Image2D *>(output->buffer())));
 kernel->setArg(idx++, relux_max_limit);
-kernel->setArg(idx++, prelu_alpha);
 kernel->setArg(idx++, static_cast<uint32_t>(input->dim(1)));
 kernel->setArg(idx++, static_cast<uint32_t>(input->dim(2)));
 kernel->setArg(idx++, static_cast<uint32_t>(input_channel_blocks));
......
@@ -20,7 +20,6 @@ void DepthwiseConv2d(cl::Kernel *kernel,
 const int *dilations,
 const ActivationType activation,
 const float relux_max_limit,
-const float prelu_alpha,
 const DataType dt,
 Tensor *output,
 StatsFuture *future) {
@@ -69,9 +68,6 @@ void DepthwiseConv2d(cl::Kernel *kernel,
 case RELUX:
 built_options.emplace("-DUSE_RELUX");
 break;
-case PRELU:
-built_options.emplace("-DUSE_PRELU");
-break;
 case TANH:
 built_options.emplace("-DUSE_TANH");
 break;
@@ -96,7 +92,6 @@ void DepthwiseConv2d(cl::Kernel *kernel,
 kernel->setArg(
 idx++, *(static_cast<const cl::Image2D *>(output->buffer())));
 kernel->setArg(idx++, relux_max_limit);
-kernel->setArg(idx++, prelu_alpha);
 kernel->setArg(idx++, static_cast<short>(input_height));
 kernel->setArg(idx++, static_cast<short>(input_width));
 kernel->setArg(idx++, static_cast<short>(input_channel_blocks));
@@ -140,8 +135,8 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
 << " is not implemented yet, using slow version";
 // TODO(heliangliang) The CPU/NEON kernel should map the buffer
 DepthwiseConv2dFunctor<DeviceType::CPU, float>(
-strides_, padding_type_, paddings_, dilations_, activation_, relux_max_limit_,
-prelu_alpha_)(input, filter, bias, output, future);
+strides_, padding_type_, paddings_, dilations_, activation_,
+relux_max_limit_)(input, filter, bias, output, future);
 return;
 }
@@ -154,11 +149,14 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
 std::vector<index_t> output_shape(4);
 std::vector<int> paddings(2);
-kernels::CalcNHWCPaddingAndOutputSize(
-input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
-padding_type_, output_shape.data(), paddings.data());
-if (!paddings_.empty()) {
+if (paddings_.empty()) {
+kernels::CalcNHWCPaddingAndOutputSize(
+input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
+padding_type_, output_shape.data(), paddings.data());
+} else {
 paddings = paddings_;
+CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(),
+dilations_, strides_, RoundType::FLOOR, output_shape.data());
 }
 std::vector<size_t> output_image_shape;
@@ -166,7 +164,7 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
 output->ResizeImage(output_shape, output_image_shape);
 DepthwiseConv2d(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_,
-activation_, relux_max_limit_, prelu_alpha_,
+activation_, relux_max_limit_,
 DataTypeToEnum<T>::value, output, future);
 }
......
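The explicit-padding branch above hands the user-supplied paddings straight to CalcOutputSize with FLOOR rounding instead of re-deriving them from the padding type. The sketch below shows the standard output-extent arithmetic this presumably corresponds to; OutputExtent, its parameters, and the reading of pad_total as the combined padding of both sides are illustrative assumptions, not the actual MACE signature.

    #include <cstdint>
    #include <cassert>

    // Sketch only: output extent of one spatial dimension with explicit padding.
    // FLOOR rounding is used for convolution-style ops, CEIL for pooling-style ops.
    int64_t OutputExtent(int64_t input, int64_t kernel, int64_t pad_total,
                         int dilation, int stride, bool ceil_mode) {
      const int64_t dilated_kernel = (kernel - 1) * dilation + 1;  // effective filter size
      const int64_t span = input + pad_total - dilated_kernel;
      assert(stride > 0 && span >= 0);
      return (ceil_mode ? (span + stride - 1) / stride : span / stride) + 1;
    }

    // Example: input 32, kernel 3, pad_total 2, dilation 1, stride 2
    //   FLOOR: (32 + 2 - 3) / 2 + 1 = 16    CEIL: (32 + 2 - 3 + 1) / 2 + 1 = 17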
...@@ -48,9 +48,6 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -48,9 +48,6 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
case RELUX: case RELUX:
built_options.emplace("-DUSE_RELUX"); built_options.emplace("-DUSE_RELUX");
break; break;
case PRELU:
built_options.emplace("-DUSE_PRELU");
break;
case TANH: case TANH:
built_options.emplace("-DUSE_TANH"); built_options.emplace("-DUSE_TANH");
break; break;
...@@ -78,7 +75,6 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -78,7 +75,6 @@ void FullyConnectedFunctor<DeviceType::OPENCL, T>::operator()(
kernel_.setArg(idx++, static_cast<int>(input->dim(3))); kernel_.setArg(idx++, static_cast<int>(input->dim(3)));
// FIXME handle flexible data type: half not supported // FIXME handle flexible data type: half not supported
kernel_.setArg(idx++, relux_max_limit_); kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
} }
const uint32_t gws[2] = { const uint32_t gws[2] = {
......
...@@ -24,12 +24,14 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -24,12 +24,14 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
}; };
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input->shape().data(), filter_shape.data(), kernels::CalcNHWCPaddingAndOutputSize(
dilations_, strides_, this->padding_type_, input->shape().data(), filter_shape.data(), dilations_, strides_,
output_shape.data(), paddings.data()); padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::CEIL, output_shape.data());
} }
std::vector<size_t> output_image_shape; std::vector<size_t> output_image_shape;
......
...@@ -18,11 +18,14 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i ...@@ -18,11 +18,14 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {3, 3, input_tensor->dim(3), 1}; std::vector<index_t> filter_shape = {3, 3, input_tensor->dim(3), 1};
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input_tensor->shape().data(), filter_shape.data(), dilations_.data(), kernels::CalcNHWCPaddingAndOutputSize(
strides_.data(), padding_type_, output_shape.data(), paddings.data()); input_tensor->shape().data(), filter_shape.data(), dilations_.data(), strides_.data(),
if (!paddings_.empty()) { padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(),
dilations_.data(), strides_.data(), RoundType::FLOOR, output_shape.data());
} }
const index_t round_h = (output_shape[1] + 1) / 2; const index_t round_h = (output_shape[1] + 1) / 2;
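round_h and round_w above are ceiling divisions of the output extents by 2, i.e. the number of 2x2 output tiles the transform produces; reading this as a Winograd F(2x2, 3x3) tiling is an inference from the 3x3 filter_shape used here, not something the patch states.

    round_h = (output_h + 1) / 2          // integer ceil(output_h / 2)
    e.g. output_h = 7  ->  round_h = 4    // the last tile row is only partially used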
...@@ -126,7 +129,6 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te ...@@ -126,7 +129,6 @@ void WinogradInverseTransformFunctor<DeviceType::OPENCL, T>::operator()(const Te
kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w)); kernel_.setArg(idx++, static_cast<uint32_t>(round_h * round_w));
kernel_.setArg(idx++, static_cast<uint32_t>(round_w)); kernel_.setArg(idx++, static_cast<uint32_t>(round_w));
kernel_.setArg(idx++, relux_max_limit_); kernel_.setArg(idx++, relux_max_limit_);
kernel_.setArg(idx++, prelu_alpha_);
} }
const uint32_t gws[2] = {static_cast<uint32_t>(input_tensor->dim(2)), const uint32_t gws[2] = {static_cast<uint32_t>(input_tensor->dim(2)),
......
...@@ -65,12 +65,14 @@ struct PoolingFunctor : PoolingFunctorBase { ...@@ -65,12 +65,14 @@ struct PoolingFunctor : PoolingFunctorBase {
}; };
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input_tensor->shape().data(), filter_shape.data(), kernels::CalcNHWCPaddingAndOutputSize(
dilations_, strides_, this->padding_type_, input_tensor->shape().data(), filter_shape.data(), dilations_, strides_,
output_shape.data(), paddings.data()); padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::CEIL, output_shape.data());
} }
output_tensor->Resize(output_shape); output_tensor->Resize(output_shape);
......
...@@ -58,21 +58,18 @@ struct WinogradInverseTransformFunctorBase { ...@@ -58,21 +58,18 @@ struct WinogradInverseTransformFunctorBase {
const int height, const int height,
const int width, const int width,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit, const float relux_max_limit)
const float prelu_alpha)
: batch_(batch), : batch_(batch),
height_(height), height_(height),
width_(width), width_(width),
activation_(activation), activation_(activation),
relux_max_limit_(relux_max_limit), relux_max_limit_(relux_max_limit) {}
prelu_alpha_(prelu_alpha) {}
const int batch_; const int batch_;
const int height_; const int height_;
const int width_; const int width_;
const ActivationType activation_; const ActivationType activation_;
const float relux_max_limit_; const float relux_max_limit_;
const float prelu_alpha_;
}; };
template<DeviceType D, typename T> template<DeviceType D, typename T>
...@@ -81,9 +78,8 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase { ...@@ -81,9 +78,8 @@ struct WinogradInverseTransformFunctor : WinogradInverseTransformFunctorBase {
const int height, const int height,
const int width, const int width,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit, const float relux_max_limit)
const float prelu_alpha) : WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit) {}
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input, void operator()(const Tensor *input,
const Tensor *bias, const Tensor *bias,
...@@ -100,9 +96,8 @@ struct WinogradInverseTransformFunctor<DeviceType::OPENCL, T> : WinogradInverseT ...@@ -100,9 +96,8 @@ struct WinogradInverseTransformFunctor<DeviceType::OPENCL, T> : WinogradInverseT
const int height, const int height,
const int width, const int width,
const ActivationType activation, const ActivationType activation,
const float relux_max_limit, const float relux_max_limit)
const float prelu_alpha) : WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit) {}
: WinogradInverseTransformFunctorBase(batch, height, width, activation, relux_max_limit, prelu_alpha) {}
void operator()(const Tensor *input, void operator()(const Tensor *input,
const Tensor *bias, const Tensor *bias,
......
...@@ -18,15 +18,15 @@ class ActivationOp : public Operator<D, T> { ...@@ -18,15 +18,15 @@ class ActivationOp : public Operator<D, T> {
functor_(kernels::StringToActivationType( functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->inputs_[0]; const Tensor *input_tensor = this->Input(0);
const Tensor *alpha_tensor = this->InputSize() >= 2 ? this->Input(1) : nullptr;
Tensor *output_tensor = this->outputs_[0]; Tensor *output_tensor = this->outputs_[0];
output_tensor->ResizeLike(input_tensor); output_tensor->ResizeLike(input_tensor);
functor_(input_tensor, output_tensor, future); functor_(input_tensor, alpha_tensor, output_tensor, future);
return true; return true;
} }
......
...@@ -213,17 +213,22 @@ void TestSimplePrelu() { ...@@ -213,17 +213,22 @@ void TestSimplePrelu() {
// Add input data // Add input data
net.AddInputFromArray<D, float>( net.AddInputFromArray<D, float>(
"Input", {2, 2, 2, 2}, "Input", {2, 2, 2, 2},
{-7, 7, -6, 6, -5, 5, -4, 4, -3, 3, -2, 2, -1, 1, 0, 0}); {-7, 7, -6, 6, -5, -5, -4, -4, -3, 3, -2, 2, -1, -1, 0, 0});
net.AddInputFromArray<D, float>(
"Alpha", {2},
{2.0, 3.0});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D, float>(net, "Input", "InputImage", BufferToImage<D, float>(net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(net, "Alpha", "AlphaImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("Activation", "PreluTest") OpDefBuilder("Activation", "PreluTest")
.Input("InputImage") .Input("InputImage")
.Input("AlphaImage")
.Output("OutputImage") .Output("OutputImage")
.AddStringArg("activation", "PRELU") .AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -235,9 +240,9 @@ void TestSimplePrelu() { ...@@ -235,9 +240,9 @@ void TestSimplePrelu() {
} else { } else {
OpDefBuilder("Activation", "PreluTest") OpDefBuilder("Activation", "PreluTest")
.Input("Input") .Input("Input")
.Input("Alpha")
.Output("Output") .Output("Output")
.AddStringArg("activation", "PRELU") .AddStringArg("activation", "PRELU")
.AddFloatArg("alpha", 2.0)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -245,7 +250,7 @@ void TestSimplePrelu() { ...@@ -245,7 +250,7 @@ void TestSimplePrelu() {
} }
auto expected = CreateTensor<float>( auto expected = CreateTensor<float>(
{2, 2, 2, 2}, {-14, 7, -12, 6, -10, 5, -8, 4, -6, 3, -4, 2, -2, 1, 0, 0}); {2, 2, 2, 2}, {-14, 7, -12, 6, -10, -15, -8, -12, -6, 3, -4, 2, -2, -3, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
} }
......
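The expected tensor follows from broadcasting the 2-element Alpha {2.0, 3.0} over the innermost (channel) dimension of the {2, 2, 2, 2} input: non-negative entries pass through, and each negative entry is scaled by the alpha of its channel. For example, the input pair (-5, -5) maps to (2.0 * -5, 3.0 * -5) = (-10, -15), and (-4, -4) maps to (-8, -12), matching the expected values above.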
...@@ -16,7 +16,7 @@ class BatchNormOp : public Operator<D, T> { ...@@ -16,7 +16,7 @@ class BatchNormOp : public Operator<D, T> {
public: public:
BatchNormOp(const OperatorDef &operator_def, Workspace *ws) BatchNormOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws), : Operator<D, T>(operator_def, ws),
functor_(false, kernels::ActivationType::NOOP, 0.0f, 0.0f) { functor_(false, kernels::ActivationType::NOOP, 0.0f) {
epsilon_ = OperatorBase::GetSingleArgument<float>("epsilon", epsilon_ = OperatorBase::GetSingleArgument<float>("epsilon",
static_cast<float>(1e-4)); static_cast<float>(1e-4));
} }
......
...@@ -23,7 +23,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -23,7 +23,6 @@ class Conv2dOp : public ConvPool2dOpBase<D, T> {
this->paddings_, this->paddings_,
this->dilations_.data(), this->dilations_.data(),
kernels::ActivationType::NOOP, kernels::ActivationType::NOOP,
0.0f,
0.0f) {} 0.0f) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
......
...@@ -342,7 +342,7 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { TestConv1x1<DeviceType::CPU>(); } ...@@ -342,7 +342,7 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { TestConv1x1<DeviceType::CPU>(); }
TEST_F(Conv2dOpTest, OPENCLConv1x1) { TestConv1x1<DeviceType::OPENCL>(); } TEST_F(Conv2dOpTest, OPENCLConv1x1) { TestConv1x1<DeviceType::OPENCL>(); }
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { static void TestComplexConvNxNS12(const std::vector<index_t> &shape, const int stride) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) { Padding type) {
...@@ -405,20 +405,31 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -405,20 +405,31 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001); ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
}; };
for (int kernel_size : {1, 3}) { for (int kernel_size : {1, 3, 7}) {
for (int stride : {1, 2}) { func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, VALID); func(kernel_size, kernel_size, stride, stride, SAME);
func(kernel_size, kernel_size, stride, stride, SAME);
}
} }
} }
TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) { TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 32, 64}); TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 16, 16, 32},
1);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 16, 16, 32},
2);
} }
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) { TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({107, 113, 5, 7}); TestComplexConvNxNS12<DeviceType::OPENCL, float>({17, 113, 5, 7},
1);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({17, 113, 5, 7},
2);
}
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS34) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({31, 113, 13, 17},
3);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 13, 17},
4);
} }
template<DeviceType D> template<DeviceType D>
...@@ -650,3 +661,81 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedDilation4) { ...@@ -650,3 +661,81 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedDilation4) {
4); 4);
} }
template<DeviceType D, typename T>
static void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, const std::vector<int> &paddings) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w) {
srand(time(NULL));
// generate random input
index_t batch = 1;
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2];
index_t output_channels = shape[3];
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// run on gpu
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(D);
ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
};
for (int kernel_size : {3, 5}) {
for (int stride : {2, 3}) {
func(kernel_size, kernel_size, stride, stride);
}
}
}
TEST_F(Conv2dOpTest, OPENCLAlignedPad1) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({32, 32, 32, 64},
{1, 1});
}
TEST_F(Conv2dOpTest, OPENCLAlignedPad2) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({128, 128, 16, 16},
{2, 2});
}
TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) {
TestArbitraryPadConvNxN<DeviceType::OPENCL, float>({107, 113, 5, 7},
{4, 4});
}
...@@ -26,8 +26,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -26,8 +26,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -19,8 +19,7 @@ class FoldedBatchNormOp : public Operator<D, T> { ...@@ -19,8 +19,7 @@ class FoldedBatchNormOp : public Operator<D, T> {
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -19,8 +19,7 @@ class FullyConnectedOp : public Operator<D, T> { ...@@ -19,8 +19,7 @@ class FullyConnectedOp : public Operator<D, T> {
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -25,8 +25,7 @@ class FusedConv2dOp : public ConvPool2dOpBase<D, T> { ...@@ -25,8 +25,7 @@ class FusedConv2dOp : public ConvPool2dOpBase<D, T> {
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
......
...@@ -24,8 +24,7 @@ class WinogradInverseTransformOp : public Operator<D, T> { ...@@ -24,8 +24,7 @@ class WinogradInverseTransformOp : public Operator<D, T> {
kernels::StringToActivationType( kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation", OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")), "NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f), OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
OperatorBase::GetSingleArgument<float>("alpha", 0.0f)) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT); const Tensor *input_tensor = this->Input(INPUT);
......
...@@ -5,10 +5,14 @@ import functools ...@@ -5,10 +5,14 @@ import functools
import argparse import argparse
import sys import sys
import six import six
import os.path
FLAGS = None FLAGS = None
def main(unused_args): def main(unused_args):
if not os.path.isfile(FLAGS.input):
print 'input model file does not exist'
return -1
net = caffe_pb2.NetParameter() net = caffe_pb2.NetParameter()
with open(FLAGS.input) as f: with open(FLAGS.input) as f:
google.protobuf.text_format.Merge(str(f.read()), net) google.protobuf.text_format.Merge(str(f.read()), net)
......