“c907a8deda85f399978ce005cbf4f94d6238673b”上不存在“examples/aishell/asr1/conf/chunk_roformer_bidecoder.yaml”
提交 de12a6df 编写于 作者: L liuqi

Support arbitrary stride for conv 1x1 and 3x3.

上级 7ad87bee
...@@ -229,11 +229,14 @@ struct Conv2dFunctor : Conv2dFunctorBase { ...@@ -229,11 +229,14 @@ struct Conv2dFunctor : Conv2dFunctorBase {
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input->shape().data(), filter->shape().data(), dilations_, strides_, kernels::CalcNHWCPaddingAndOutputSize(
padding_type_, output_shape.data(), paddings.data()); input->shape().data(), filter->shape().data(), dilations_, strides_,
if (!paddings_.empty()) { padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
} }
output->Resize(output_shape); output->Resize(output_shape);
......
...@@ -135,6 +135,44 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC ...@@ -135,6 +135,44 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
output_shape[3] = output_channels; output_shape[3] = output_channels;
} }
/*
 * Computes the output shape of a convolution/pooling window for explicitly
 * given paddings.
 *
 * input_shape   4 extents, NHWC.
 * filter_shape  4 extents, HWOI.
 * padding_size  2 values (height, width). Each value is added to the input
 *               extent exactly once, i.e. it is treated as the total padding
 *               along that axis, not the per-side padding.
 * dilations     2 values (height, width), each must be >= 1.
 * strides       2 values (height, width), each must be >= 1.
 * round_type    FLOOR or CEIL; selects how the strided division is rounded.
 * output_shape  receives 4 extents: batch copied from input_shape[0],
 *               the two computed spatial extents, and filter_shape[2] as the
 *               channel count.
 *
 * With P = padding_size[axis] (total), i = input extent, k = filter extent,
 * d = dilation and s = stride:
 *   FLOOR: o = floor((i + P - k - (k - 1) * (d - 1)) / s) + 1
 *   CEIL:  o = ceil((i + P - k - (k - 1) * (d - 1)) / s) + 1
 * The source comment associates FLOOR with convolution and CEIL with pooling.
 * For details, see https://arxiv.org/pdf/1603.07285.pdf or
 * http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
 * (those references write the padding term as 2 * p with p the per-side
 * padding).
 */
void CalcOutputSize(const index_t *input_shape,   // NHWC
                    const index_t *filter_shape,  // HWOI
                    const int *padding_size,
                    const int *dilations,
                    const int *strides,
                    const RoundType round_type,
                    index_t *output_shape) {
  // Validate every pointer that is dereferenced by the checks below first.
  MACE_CHECK_NOTNULL(dilations);
  MACE_CHECK_NOTNULL(strides);
  MACE_CHECK_NOTNULL(output_shape);
  MACE_CHECK_NOTNULL(padding_size);
  MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
             "Invalid dilations, must >= 1");
  // strides is a divisor below; a zero stride would yield inf/NaN and an
  // undefined conversion to index_t.
  MACE_CHECK(strides[0] > 0 && strides[1] > 0,
             "Invalid strides, must >= 1");
  MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
             (dilations[1] == 1 || strides[1] == 1),
             "If dilations > 1, strides should be 1");

  // Output extent along one spatial axis (0 = height, 1 = width). The input
  // extent lives at index axis + 1 because input_shape is NHWC.
  const auto spatial_extent = [&](const int axis) -> index_t {
    const index_t kernel = filter_shape[axis];
    const index_t span = input_shape[axis + 1] + padding_size[axis] - kernel -
                         (kernel - 1) * (dilations[axis] - 1);
    // Divide in floating point so negative spans round according to
    // round_type rather than truncating toward zero.
    const double steps = 1.0 * span / strides[axis];
    const double rounded =
        (round_type == RoundType::FLOOR) ? std::floor(steps) : std::ceil(steps);
    return static_cast<index_t>(rounded + 1);
  };

  output_shape[0] = input_shape[0];
  output_shape[1] = spatial_extent(0);
  output_shape[2] = spatial_extent(1);
  output_shape[3] = filter_shape[2];
}
void CalPaddingSize(const index_t *input_shape, // NCHW void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW const index_t *filter_shape, // OIHW
const int *dilations, const int *dilations,
......
...@@ -15,6 +15,11 @@ enum Padding { ...@@ -15,6 +15,11 @@ enum Padding {
FULL = 2, // Pads with one less than the filter size on both sides FULL = 2, // Pads with one less than the filter size on both sides
}; };
// Rounding mode used when an output size is derived from a strided division.
enum RoundType {
  FLOOR = 0,  // round the quotient down
  CEIL = 1,   // round the quotient up
};
namespace kernels { namespace kernels {
void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
...@@ -33,6 +38,14 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NCHW ...@@ -33,6 +38,14 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NCHW
index_t *output_shape, index_t *output_shape,
int *padding_size); int *padding_size);
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const int *padding_size,
const int *dilations,
const int *strides,
const RoundType round_type,
index_t *output_shape);
void CalPaddingSize(const index_t *input_shape, // NCHW void CalPaddingSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW const index_t *filter_shape, // OIHW
const int *dilations, const int *dilations,
......
...@@ -295,11 +295,14 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase { ...@@ -295,11 +295,14 @@ struct DepthwiseConv2dFunctor : public DepthwiseConv2dFunctorBase {
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input->shape().data(), fake_filter_shape.data(), dilations_, strides_, kernels::CalcNHWCPaddingAndOutputSize(
padding_type_, output_shape.data(), paddings.data()); input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
if (!paddings_.empty()) { padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
} }
auto input_shape = fake_filter_shape; auto input_shape = fake_filter_shape;
output->Resize(output_shape); output->Resize(output_shape);
......
...@@ -12,7 +12,8 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -12,7 +12,8 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
__private const int in_width, __private const int in_width,
__private const int in_ch_blks, __private const int in_ch_blks,
__private const int height, __private const int height,
__private const int width) { __private const int width,
__private const int stride) {
const int out_ch_blk = get_global_id(0); const int out_ch_blk = get_global_id(0);
const int out_w_blk = get_global_id(1); const int out_w_blk = get_global_id(1);
const int out_w_blks = get_global_size(1); const int out_w_blks = get_global_size(1);
...@@ -31,19 +32,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -31,19 +32,12 @@ __kernel void conv_2d_1x1(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
#endif #endif
int4 w; int4 w;
#if STRIDE == 1 int in_width_stride = mul24(out_w_blks, stride);
w.x = out_w_blk; w.x = mul24(out_w_blk, stride);
w.y = w.x + out_w_blks; w.y = w.x + in_width_stride;
w.z = w.y + out_w_blks; w.z = w.y + in_width_stride;
w.w = w.z + out_w_blks; w.w = w.z + in_width_stride;
int out_hb_idx = (out_hb % height); int out_hb_idx = mul24((out_hb % height), stride);
#elif STRIDE == 2
w.x = out_w_blk << 1;
w.y = (out_w_blk + out_w_blks) << 1;
w.z = (out_w_blk + (out_w_blks << 1)) << 1;
w.w = (out_w_blk + (out_w_blks << 1) + out_w_blks) << 1;
int out_hb_idx = (out_hb % height) << 1;
#endif
w.x = select(w.x, INT_MIN, w.x >= in_width); w.x = select(w.x, INT_MIN, w.x >= in_width);
w.y = select(w.y, INT_MIN, w.y >= in_width); w.y = select(w.y, INT_MIN, w.y >= in_width);
......
...@@ -13,6 +13,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -13,6 +13,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
__private const int in_ch_blks, __private const int in_ch_blks,
__private const int out_height, __private const int out_height,
__private const int out_width, __private const int out_width,
__private const int stride,
__private const int padding_top, __private const int padding_top,
__private const int padding_left, __private const int padding_left,
__private const int dilation_h, __private const int dilation_h,
...@@ -38,21 +39,13 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -38,21 +39,13 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
DATA_TYPE4 out4 = 0; DATA_TYPE4 out4 = 0;
#endif #endif
#if STRIDE == 1 int in_width_stride = mul24(out_w_blks, stride);
int in_width0 = out_w_blk - padding_left; int in_width0 = mad24(out_w_blk, stride, -padding_left);
int in_width1 = in_width0 + out_w_blks; int in_width1 = in_width0 + in_width_stride;
int in_width2 = in_width1 + out_w_blks; int in_width2 = in_width1 + in_width_stride;
int in_width3 = in_width2 + out_w_blks; int in_width3 = in_width2 + in_width_stride;
int in_width4 = in_width3 + out_w_blks; int in_width4 = in_width3 + in_width_stride;
const int height_idx = (out_hb % out_height) - padding_top; const int height_idx = mad24((out_hb % out_height), stride, -padding_top);
#elif STRIDE == 2
int in_width0 = (out_w_blk << 1) - padding_left;
int in_width1 = ((out_w_blk + out_w_blks) << 1) - padding_left;
int in_width2 = ((out_w_blk + (out_w_blks << 1)) << 1) - padding_left;
int in_width3 = ((out_w_blk + (out_w_blks << 1) + out_w_blks) << 1) - padding_left;
int in_width4 = ((out_w_blk + (out_w_blks << 2)) << 1) - padding_left;
const int height_idx = ((out_hb % out_height) << 1) - padding_top;
#endif
const int batch_idx = mul24((out_hb / out_height), in_height); const int batch_idx = mul24((out_hb / out_height), in_height);
const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch; const int rounded_in_ch_x_3 = (rounded_in_ch << 1) + rounded_in_ch;
......
...@@ -69,7 +69,6 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -69,7 +69,6 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
index_t kernel_h = filter->dim(0); index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1); index_t kernel_w = filter->dim(1);
if (!input->is_image() || strides_[0] != strides_[1] || if (!input->is_image() || strides_[0] != strides_[1] ||
((kernel_h == 1 || kernel_h == 3) && strides_[0] > 2) ||
(dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) { (dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) {
LOG(WARNING) << "OpenCL conv2d kernel with " LOG(WARNING) << "OpenCL conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << "," << "filter" << kernel_h << "x" << kernel_w << ","
...@@ -82,11 +81,14 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -82,11 +81,14 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input->shape().data(), filter->shape().data(), dilations_, strides_, kernels::CalcNHWCPaddingAndOutputSize(
padding_type_, output_shape.data(), paddings.data()); input->shape().data(), filter->shape().data(), dilations_, strides_,
if (!paddings_.empty()) { padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), filter->shape().data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
} }
std::vector<size_t> output_image_shape; std::vector<size_t> output_image_shape;
...@@ -94,8 +96,7 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -94,8 +96,7 @@ void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
output->ResizeImage(output_shape, output_image_shape); output->ResizeImage(output_shape, output_image_shape);
if (kernel_h == kernel_w && kernel_h <= 5 && if (kernel_h == kernel_w && kernel_h <= 5 &&
selector[kernel_h - 1] != nullptr && selector[kernel_h - 1] != nullptr) {
0 < strides_[0] && strides_[0] < 3 ) {
auto conv2d_func = selector[kernel_h - 1]; auto conv2d_func = selector[kernel_h - 1];
conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_, conv2d_func(&kernel_, input, filter, bias, strides_[0], paddings.data(), dilations_, activation_,
relux_max_limit_, prelu_alpha_, DataTypeToEnum<T>::value, relux_max_limit_, prelu_alpha_, DataTypeToEnum<T>::value,
......
...@@ -44,7 +44,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, ...@@ -44,7 +44,6 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
built_options.emplace("-Dconv_2d_1x1=" + kernel_name); built_options.emplace("-Dconv_2d_1x1=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DSTRIDE=", stride));
if (bias != nullptr) { if (bias != nullptr) {
built_options.emplace("-DBIAS"); built_options.emplace("-DBIAS");
} }
...@@ -93,6 +92,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel, ...@@ -93,6 +92,7 @@ extern void Conv2dOpenclK1x1(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<int>(input_channel_blocks)); kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int>(height)); kernel->setArg(idx++, static_cast<int>(height));
kernel->setArg(idx++, static_cast<int>(width)); kernel->setArg(idx++, static_cast<int>(width));
kernel->setArg(idx++, stride);
} }
const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks), const uint32_t gws[3] = {static_cast<uint32_t>(channel_blocks),
......
...@@ -42,7 +42,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, ...@@ -42,7 +42,6 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : ""); built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace(MakeString("-DSTRIDE=", stride));
switch (activation) { switch (activation) {
case NOOP: case NOOP:
break; break;
...@@ -87,6 +86,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel, ...@@ -87,6 +86,7 @@ extern void Conv2dOpenclK3x3(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<int>(input_channel_blocks)); kernel->setArg(idx++, static_cast<int>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int>(height)); kernel->setArg(idx++, static_cast<int>(height));
kernel->setArg(idx++, static_cast<int>(width)); kernel->setArg(idx++, static_cast<int>(width));
kernel->setArg(idx++, stride);
kernel->setArg(idx++, padding[0] / 2); kernel->setArg(idx++, padding[0] / 2);
kernel->setArg(idx++, padding[1] / 2); kernel->setArg(idx++, padding[1] / 2);
kernel->setArg(idx++, dilations[0]); kernel->setArg(idx++, dilations[0]);
......
...@@ -42,7 +42,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel, ...@@ -42,7 +42,6 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt)); built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt)); built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(bias != nullptr ? "-DBIAS" : ""); built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace(MakeString("-DSTRIDE=", stride));
switch (activation) { switch (activation) {
case NOOP: case NOOP:
break; break;
......
...@@ -154,11 +154,14 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()( ...@@ -154,11 +154,14 @@ void DepthwiseConv2dFunctor<DeviceType::OPENCL, T>::operator()(
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input->shape().data(), fake_filter_shape.data(), dilations_, strides_, kernels::CalcNHWCPaddingAndOutputSize(
padding_type_, output_shape.data(), paddings.data()); input->shape().data(), fake_filter_shape.data(), dilations_, strides_,
if (!paddings_.empty()) { padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), fake_filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::FLOOR, output_shape.data());
} }
std::vector<size_t> output_image_shape; std::vector<size_t> output_image_shape;
......
...@@ -24,12 +24,14 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input, ...@@ -24,12 +24,14 @@ void PoolingFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
}; };
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input->shape().data(), filter_shape.data(), kernels::CalcNHWCPaddingAndOutputSize(
dilations_, strides_, this->padding_type_, input->shape().data(), filter_shape.data(), dilations_, strides_,
output_shape.data(), paddings.data()); padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::CEIL, output_shape.data());
} }
std::vector<size_t> output_image_shape; std::vector<size_t> output_image_shape;
......
...@@ -18,11 +18,14 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i ...@@ -18,11 +18,14 @@ void WinogradTransformFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *i
std::vector<index_t> output_shape(4); std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {3, 3, input_tensor->dim(3), 1}; std::vector<index_t> filter_shape = {3, 3, input_tensor->dim(3), 1};
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input_tensor->shape().data(), filter_shape.data(), dilations_.data(), kernels::CalcNHWCPaddingAndOutputSize(
strides_.data(), padding_type_, output_shape.data(), paddings.data()); input_tensor->shape().data(), filter_shape.data(), dilations_.data(), strides_.data(),
if (!paddings_.empty()) { padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(),
dilations_.data(), strides_.data(), RoundType::FLOOR, output_shape.data());
} }
const index_t round_h = (output_shape[1] + 1) / 2; const index_t round_h = (output_shape[1] + 1) / 2;
......
...@@ -65,12 +65,14 @@ struct PoolingFunctor : PoolingFunctorBase { ...@@ -65,12 +65,14 @@ struct PoolingFunctor : PoolingFunctorBase {
}; };
std::vector<int> paddings(2); std::vector<int> paddings(2);
kernels::CalcNHWCPaddingAndOutputSize( if (paddings_.empty()) {
input_tensor->shape().data(), filter_shape.data(), kernels::CalcNHWCPaddingAndOutputSize(
dilations_, strides_, this->padding_type_, input_tensor->shape().data(), filter_shape.data(), dilations_, strides_,
output_shape.data(), paddings.data()); padding_type_, output_shape.data(), paddings.data());
if (!paddings_.empty()) { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), filter_shape.data(), paddings_.data(),
dilations_, strides_, RoundType::CEIL, output_shape.data());
} }
output_tensor->Resize(output_shape); output_tensor->Resize(output_shape);
......
...@@ -342,7 +342,7 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { TestConv1x1<DeviceType::CPU>(); } ...@@ -342,7 +342,7 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { TestConv1x1<DeviceType::CPU>(); }
TEST_F(Conv2dOpTest, OPENCLConv1x1) { TestConv1x1<DeviceType::OPENCL>(); } TEST_F(Conv2dOpTest, OPENCLConv1x1) { TestConv1x1<DeviceType::OPENCL>(); }
template <DeviceType D, typename T> template <DeviceType D, typename T>
static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { static void TestComplexConvNxNS12(const std::vector<index_t> &shape, const int stride) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) { Padding type) {
...@@ -405,20 +405,31 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -405,20 +405,31 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001); ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
}; };
for (int kernel_size : {1, 3}) { for (int kernel_size : {1, 3, 7}) {
for (int stride : {1, 2}) { func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, VALID); func(kernel_size, kernel_size, stride, stride, SAME);
func(kernel_size, kernel_size, stride, stride, SAME);
}
} }
} }
TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) { TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 32, 64}); TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 16, 16, 32},
1);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 16, 16, 32},
2);
} }
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) { TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({107, 113, 5, 7}); TestComplexConvNxNS12<DeviceType::OPENCL, float>({17, 113, 5, 7},
1);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({17, 113, 5, 7},
2);
}
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS34) {
TestComplexConvNxNS12<DeviceType::OPENCL, float>({31, 113, 13, 17},
3);
TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 13, 17},
4);
} }
template<DeviceType D> template<DeviceType D>
...@@ -650,3 +661,81 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedDilation4) { ...@@ -650,3 +661,81 @@ TEST_F(Conv2dOpTest, OPENCLUnalignedDilation4) {
4); 4);
} }
template<DeviceType D, typename T>
static void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, const std::vector<int> &paddings) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w) {
srand(time(NULL));
// generate random input
index_t batch = 1;
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2];
index_t output_channels = shape[3];
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// run on gpu
BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::CONV2D_FILTER);
BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntsArg("padding_values", paddings)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(D);
ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT_CHANNEL);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
};
for (int kernel_size : {3, 5}) {
for (int stride : {2, 3}) {
func(kernel_size, kernel_size, stride, stride);
}
}
}
// Explicit padding {1, 1} on a 32x32 input with 32 input / 64 output channels.
TEST_F(Conv2dOpTest, OPENCLAlignedPad1) {
  const std::vector<index_t> shape = {32, 32, 32, 64};
  const std::vector<int> paddings = {1, 1};
  TestArbitraryPadConvNxN<DeviceType::OPENCL, float>(shape, paddings);
}
// Explicit padding {2, 2} on a 128x128 input with 16 input / 16 output channels.
TEST_F(Conv2dOpTest, OPENCLAlignedPad2) {
  const std::vector<index_t> shape = {128, 128, 16, 16};
  const std::vector<int> paddings = {2, 2};
  TestArbitraryPadConvNxN<DeviceType::OPENCL, float>(shape, paddings);
}
// Explicit padding {4, 4} on a 107x113 input with 5 input / 7 output channels.
TEST_F(Conv2dOpTest, OPENCLUnalignedPad4) {
  const std::vector<index_t> shape = {107, 113, 5, 7};
  const std::vector<int> paddings = {4, 4};
  TestArbitraryPadConvNxN<DeviceType::OPENCL, float>(shape, paddings);
}
...@@ -5,10 +5,14 @@ import functools ...@@ -5,10 +5,14 @@ import functools
import argparse import argparse
import sys import sys
import six import six
import os.path
FLAGS = None FLAGS = None
def main(unused_args): def main(unused_args):
if not os.path.isfile(FLAGS.input):
print 'input model file not exist'
return -1
net = caffe_pb2.NetParameter() net = caffe_pb2.NetParameter()
with open(FLAGS.input) as f: with open(FLAGS.input) as f:
google.protobuf.text_format.Merge(str(f.read()), net) google.protobuf.text_format.Merge(str(f.read()), net)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册