Commit de992bf8 authored by Liangliang He

Merge branch 'depthwise_conv' into 'master'

Depthwise conv

See merge request !85
......@@ -132,6 +132,8 @@ void ConstructInputWithPadding(const float *input,
const int padded_left = paddings[1] / 2;
output_tensor->Resize(output_shape);
Tensor::MappingGuard padded_input_mapper(output_tensor);
float *output_ptr = output_tensor->mutable_data<float>();
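// Zero-fill the padded buffer so the padding border contributes 0 to the convolution sums.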
memset(output_ptr, 0, output_tensor->size() * sizeof(float));
......
......@@ -20,27 +20,27 @@ struct DepthwiseConv2dFunctor {
const int *dilations)
: strides_(strides), paddings_(paddings), dilations_(dilations) {}
void operator()(const T *input, // NCHW
const index_t *input_shape,
const T *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const T *bias, // c_out
T *output, // NCHW
const index_t *output_shape) {
void operator()(const Tensor *input, // NCHW
const Tensor *filter, // c_out, c_in, kernel_h, kernel_w
const Tensor *bias, // c_out
Tensor *output) {
MACE_CHECK_NOTNULL(input);
MACE_CHECK_NOTNULL(filter);
MACE_CHECK_NOTNULL(bias);
MACE_CHECK_NOTNULL(output);
index_t batch = output_shape[0];
index_t channels = output_shape[1];
index_t height = output_shape[2];
index_t width = output_shape[3];
index_t batch = output->dim(0);
index_t channels = output->dim(1);
index_t height = output->dim(2);
index_t width = output->dim(3);
index_t input_batch = input_shape[0];
index_t input_channels = input_shape[1];
index_t input_height = input_shape[2];
index_t input_width = input_shape[3];
index_t input_batch = input->dim(0);
index_t input_channels = input->dim(1);
index_t input_height = input->dim(2);
index_t input_width = input->dim(3);
index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3];
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
int stride_h = strides_[0];
int stride_w = strides_[1];
......@@ -56,20 +56,29 @@ struct DepthwiseConv2dFunctor {
index_t padded_h_stop = input_height + paddings_[0] - paddings_[0] / 2;
index_t padded_w_stop = input_width + paddings_[1] - paddings_[1] / 2;
index_t kernel_size = filter_shape[1] * kernel_h * kernel_w;
index_t multiplier = channels / input_channels;
index_t kernel_size = kernel_h * kernel_w;
index_t multiplier = filter->dim(0);
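// Depthwise filter layout is (multiplier, in_channels, kh, kw), so channels == input_channels * multiplier.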
Tensor::MappingGuard input_mapper(input);
Tensor::MappingGuard filter_mapper(filter);
Tensor::MappingGuard bias_mapper(bias);
Tensor::MappingGuard output_mapper(output);
const T *input_ptr = input->data<T>();
const T *filter_ptr = filter->data<T>();
const T *bias_ptr = bias->data<T>();
T *output_ptr = output->mutable_data<T>();
#pragma omp parallel for collapse(2)
for (int n = 0; n < batch; ++n) {
for (int c = 0; c < channels; ++c) {
T bias_channel = bias ? bias[c] : 0;
T bias_channel = bias_ptr ? bias_ptr[c] : 0;
for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) {
index_t offset = n * channels * height * width +
c * height * width + h * width + w;
output[offset] = bias_channel;
output_ptr[offset] = bias_channel;
T sum = 0;
const T *filter_ptr = filter + c * kernel_size;
const T *filter_base = filter_ptr + c * kernel_size;
for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) {
int inh = padded_h_start + h * stride_h + dilation_h * kh;
......@@ -79,19 +88,17 @@ struct DepthwiseConv2dFunctor {
MACE_CHECK(inh >= padded_h_start && inh < padded_h_stop &&
inw >= padded_w_start && inw < padded_w_stop,
"Out of range read from input: ", inh, ", ", inw);
// else padding with 0:
// sum += 0;
} else {
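// Depthwise: output channel c reads only input channel c / multiplier.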
index_t input_offset =
n * input_channels * input_height * input_width +
(c / multiplier) * input_height * input_width +
inh * input_width + inw;
sum += input[input_offset] * *filter_ptr;
sum += input_ptr[input_offset] * *filter_base;
}
++filter_ptr;
++filter_base;
}
}
output[offset] += sum;
output_ptr[offset] += sum;
}
}
}
......@@ -105,13 +112,18 @@ struct DepthwiseConv2dFunctor {
template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input,
const index_t *input_shape,
const float *filter,
const index_t *filter_shape,
const float *bias,
float *output,
const index_t *output_shape);
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output);
template <>
void DepthwiseConv2dFunctor<DeviceType::OPENCL, float>::operator()(
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output);
} // namespace kernels
} // namespace mace
......
......@@ -29,9 +29,9 @@ void Conv2dNeonK3x3S1(const float *input, // NCHW
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier =
filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
filter_shape == nullptr ? 0 : filter_shape[0];
int filter_in_channels =
filter_shape == nullptr ? input_channels : filter_shape[1];
filter_shape == nullptr ? input_channels : 1;
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
for (int oc = 0; oc < output_channels; ++oc) {
......@@ -232,9 +232,9 @@ void Conv2dNeonK3x3S2(const float *input, // NCHW
int input_height = input_shape[2];
int input_width = input_shape[3];
int multiplier =
filter_shape == nullptr ? 0 : (filter_shape[0] / input_channels);
filter_shape == nullptr ? 0 : filter_shape[0];
int filter_in_channels =
filter_shape == nullptr ? input_channels : filter_shape[1];
filter_shape == nullptr ? input_channels : 1;
#pragma omp parallel for collapse(2)
for (int b = 0; b < output_batch; ++b) {
......
......@@ -26,13 +26,10 @@ extern void Conv2dNeonK3x3S2(const float *input,
template <>
void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
const float *input, // NCHW
const index_t *input_shape,
const float *filter, // c_out, c_in, kernel_h, kernel_w
const index_t *filter_shape,
const float *bias, // c_out
float *output, // NCHW
const index_t *output_shape) {
const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output) {
typedef void (*Conv2dNeonFunction)(
const float *input, const index_t *input_shape, const float *filter,
const index_t *filter_shape, const float *bias, float *output,
......@@ -45,8 +42,8 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
{nullptr, nullptr},
{nullptr, nullptr}};
// not implemented yet
index_t kernel_h = filter_shape[2];
index_t kernel_w = filter_shape[3];
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
......@@ -56,20 +53,27 @@ void DepthwiseConv2dFunctor<DeviceType::NEON, float>::operator()(
<< " is not implemented yet, using slow version";
DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_,
dilations_)(
input, input_shape, filter, filter_shape, bias, output, output_shape);
input, filter, bias, output);
return;
}
const float *input_ptr = input->data<float>();
const index_t *input_shape = input->shape().data();
const float *filter_ptr = filter->data<float>();
const index_t *filter_shape = filter->shape().data();
const float *bias_ptr = bias->data<float>();
float *output_ptr = output->mutable_data<float>();
const index_t *output_shape = output->shape().data();
// Keep this alive during kernel execution
Tensor padded_input;
if (paddings_[0] > 0 || paddings_[1] > 0) {
ConstructInputWithPadding(input, input_shape, paddings_.data(),
ConstructInputWithPadding(input_ptr, input_shape, paddings_.data(),
&padded_input);
input = padded_input.data<float>();
input_ptr = padded_input.data<float>();
input_shape = padded_input.shape().data();
}
auto conv2d_neon_func = selector[kernel_h - 1][strides_[0] - 1];
conv2d_neon_func(input, input_shape, filter, filter_shape, bias, output,
conv2d_neon_func(input_ptr, input_shape, filter_ptr, filter_shape, bias_ptr, output_ptr,
output_shape);
}
......
inline float4 conv1x3(const float *input_ptr,
const float *filter_ptr) {
float8 input = vload8(0, input_ptr);
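// Build three 4-lane windows shifted by one element each; lane i of the result is the 1x3 row convolution at output position i.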
float4 row0 = convert_float4(input.s0123);
float4 row1 = convert_float4(input.s1234);
float4 row2 = convert_float4(input.s2345);
return (float4)filter_ptr[0] * row0 + (float4)filter_ptr[1] * row1
+ (float4)filter_ptr[2] * row2;
}
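// Four horizontally adjacent 3x3 outputs: accumulate one conv1x3 per filter row.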
inline float4 conv3x3x4(const float *input_ptr,
const float *filter_ptr,
const int row_width) {
float4 res;
res = conv1x3(input_ptr + 0 * row_width, filter_ptr + 0 * 3);
res += conv1x3(input_ptr + 1 * row_width, filter_ptr + 1 * 3);
res += conv1x3(input_ptr + 2 * row_width, filter_ptr + 2 * 3);
return res;
}
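// Scalar 3x3 convolution, used for the ragged tail when fewer than 4 output pixels remain in a row.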
inline float conv3x3(const float *input_ptr,
const float *filter_ptr,
const int row_width) {
float res = input_ptr[0] * filter_ptr[0] + input_ptr[1] * filter_ptr[1] + input_ptr[2] * filter_ptr[2];
input_ptr += row_width;
filter_ptr += 3;
res += input_ptr[0] * filter_ptr[0] + input_ptr[1] * filter_ptr[1] + input_ptr[2] * filter_ptr[2];
input_ptr += row_width;
filter_ptr += 3;
res += input_ptr[0] * filter_ptr[0] + input_ptr[1] * filter_ptr[1] + input_ptr[2] * filter_ptr[2];
return res;
}
void kernel depthwise_conv_3x3_s1(global const float *input, /* n, c, h, w */
global const float *filter, /* m, i, kh, kw */
global const float *bias, /* o */
global float *output, /* n, c, h, w */
private const int in_chan_num,
private const int out_chan_num,
private const int in_height,
private const int in_width,
private const int out_height,
private const int out_width) {
int batch = get_global_id(0);
int out_chan_blk = get_global_id(1);
int out_pixel_blk = get_global_id(2);
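// Each work item computes a block of up to 4 output channels x 4 horizontally adjacent output pixels.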
const int in_pixel = in_height * in_width;
const int out_pixel = out_height * out_width;
const int multiplier = out_chan_num / in_chan_num;
const int round_out_width = (out_width + 3) / 4;
const int out_pixel_height = out_pixel_blk / round_out_width;
const int out_pixel_width = out_pixel_blk % round_out_width;
const int out_chan_begin = out_chan_blk * 4;
const int out_chan_end = min(out_chan_begin + 4, out_chan_num);
const int out_pixel_begin = out_pixel_height * out_width + out_pixel_width * 4;
const int out_pixel_end = min(out_pixel_begin + 4, (out_pixel_height + 1) * out_width);
const int in_pixel_begin = out_pixel_height * in_width + out_pixel_width * 4;
const int in_offset = batch * in_chan_num * in_pixel;
const int out_offset = batch * out_chan_num * out_pixel;
const float *input_base = input + in_offset + in_pixel_begin;
float *output_base = output + out_offset + out_pixel_begin;
int pixels = out_pixel_end - out_pixel_begin;
for (int i = out_chan_begin; i < out_chan_end; ++i) {
float bias_value = bias[i];
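// Depthwise mapping: output channel i reads input channel i / multiplier and has its own 3x3 filter.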
const float *input_ptr = input_base + (i / multiplier) * in_pixel;
const float *filter_ptr = filter + i * 9;
float *output_ptr = output_base + i * out_pixel;
if (pixels < 4) {
for (int out_idx = 0; out_idx < pixels; ++out_idx) {
output_ptr[out_idx] = bias_value;
output_ptr[out_idx] += conv3x3(input_ptr, filter_ptr, in_width);
input_ptr += 1;
}
} else {
float4 res = conv3x3x4(input_ptr, filter_ptr, in_width);
res += (float4)bias_value;
vstore4(res, 0, output_ptr);
}
}
}
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/kernels/depthwise_conv2d.h"
namespace mace {
namespace kernels {
extern void DepthwiseConvOpenclK3x3S1(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
template <>
void DepthwiseConv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output) {
typedef void (*Conv2dOpenclFunction)(const Tensor *input, const Tensor *filter,
const Tensor *bias, Tensor *output);
// Selection matrix: kernel_size x stride_size
static const Conv2dOpenclFunction selector[5][2] = {
{nullptr, nullptr},
{nullptr, nullptr},
{DepthwiseConvOpenclK3x3S1, nullptr},
{nullptr, nullptr},
{nullptr, nullptr}};
index_t kernel_h = filter->shape()[2];
index_t kernel_w = filter->shape()[3];
if (kernel_h != kernel_w || kernel_h > 5 || strides_[0] != strides_[1] ||
strides_[0] > 2 || dilations_[0] != 1 || dilations_[1] != 1 ||
selector[kernel_h - 1][strides_[0] - 1] == nullptr) {
LOG(WARNING) << "OpenCL conv2d kernel with "
<< "filter" << kernel_h << "x" << kernel_w << ","
<< " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version";
// TODO(heliangliang) The CPU/NEON kernel should map the buffer
DepthwiseConv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)(
input, filter, bias, output);
return;
}
auto conv2d_func = selector[kernel_h - 1][strides_[0] - 1];
if (paddings_[0] > 0 || paddings_[1] > 0) {
Tensor padded_input(GetDeviceAllocator(DeviceType::OPENCL), DataTypeToEnum<float>::v());
Tensor::MappingGuard input_mapper(input);
ConstructInputWithPadding(input->data<float>(), input->shape().data(), paddings_.data(),
&padded_input);
conv2d_func(&padded_input, filter, bias, output);
} else {
conv2d_func(input, filter, bias, output);
}
}
} // namespace kernels
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/common.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/conv_2d.h"
namespace mace {
namespace kernels {
extern void DepthwiseConvOpenclK3x3S1(const Tensor *input,
const Tensor *filter,
const Tensor *bias,
Tensor *output) {
const index_t batch = output->dim(0);
const index_t channels = output->dim(1);
const index_t height = output->dim(2);
const index_t width = output->dim(3);
const index_t input_batch = input->dim(0);
const index_t input_channels = input->dim(1);
const index_t input_height = input->dim(2);
const index_t input_width = input->dim(3);
MACE_CHECK(input_batch == batch);
auto runtime = OpenCLRuntime::Get();
auto program = runtime->program();
auto conv_2d = cl::KernelFunctor<cl::Buffer, cl::Buffer, cl::Buffer, cl::Buffer,
int, int, int, int, int, int, int>(program, "depthwise_conv_3x3_s1");
const index_t pixels = height * width;
const index_t channel_blocks = (channels + 3) / 4;
const index_t pixel_blocks = (width + 3) / 4 * height;
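// Launch one work item per (batch, 4-channel block, 4-pixel block); 256 work items per group along the pixel dimension.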
cl_int error;
conv_2d(cl::EnqueueArgs(runtime->command_queue(),
cl::NDRange(static_cast<int>(batch),
static_cast<int>(channel_blocks),
static_cast<int>(pixel_blocks)),
cl::NDRange(1, 1, 256)),
*(static_cast<cl::Buffer *>(input->buffer())),
*(static_cast<cl::Buffer *>(filter->buffer())),
*(static_cast<cl::Buffer *>(bias->buffer())),
*(static_cast<cl::Buffer *>(output->buffer())),
static_cast<int>(input_channels),
static_cast<int>(channels),
static_cast<int>(input_height),
static_cast<int>(input_width),
static_cast<int>(height),
static_cast<int>(width),
error);
MACE_CHECK(error == CL_SUCCESS);
}
} // namespace kernels
} // namespace mace
......@@ -171,7 +171,7 @@ TEST_F(Conv2dOpTest, Conv1x1) {
}
// TODO we need more tests
TEST_F(Conv2dOpTest, IdleConvNxNS12) {
TEST_F(Conv2dOpTest, AlignedConvNxNS12) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
......@@ -222,7 +222,7 @@ TEST_F(Conv2dOpTest, IdleConvNxNS12) {
}
}
TEST_F(Conv2dOpTest, DisgustConvNxNS12) {
TEST_F(Conv2dOpTest, UnalignedConvNxNS12) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
......
......@@ -14,4 +14,7 @@ REGISTER_NEON_OPERATOR(DepthwiseConv2d,
DepthwiseConv2dOp<DeviceType::NEON, float>);
#endif // __ARM_NEON
REGISTER_OPENCL_OPERATOR(DepthwiseConv2d,
DepthwiseConv2dOp<DeviceType::OPENCL, float>);
} // namespace mace
......@@ -26,10 +26,9 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
bool Run() override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
const T *bias_data = nullptr;
const Tensor *bias = nullptr;
if (this->InputSize() >= 3) {
const Tensor *bias = this->Input(BIAS);
bias_data = bias->data<T>();
bias = this->Input(BIAS);
}
Tensor *output = this->Output(OUTPUT);
......@@ -47,9 +46,7 @@ class DepthwiseConv2dOp : public ConvPool2dOpBase<D, T> {
output->Resize(output_shape);
functor_.paddings_ = paddings;
functor_(input->data<T>(), input->shape().data(), filter->data<T>(),
filter_shape.data(), bias_data, output->mutable_data<T>(),
output->shape().data());
functor_(input, filter, bias, output);
return true;
}
......
......@@ -9,10 +9,11 @@ using namespace mace;
class DepthwiseConv2dOpTest : public OpsTestBase {};
TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
template <DeviceType D>
void SimpleValidTest() {
testing::internal::LogToStderr();
// Construct graph
auto &net = test_net();
OpsTestNet net;
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
......@@ -26,15 +27,15 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 2, 3},
{1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
net.AddInputFromArray<DeviceType::CPU, float>(
net.AddInputFromArray<D, float>("Input", {1, 2, 2, 3},
{1, 3, 5, 7, 9, 11, 2, 4, 6, 8, 10, 12});
net.AddInputFromArray<D, float>(
"Filter", {2, 2, 2, 2},
{1.0f, 5.0f, 9.0f, 13.0f, 2.0f, 6.0f, 10.0f, 14.0f, 3.0f, 7.0f, 11.0f,
15.0f, 4.0f, 8.0f, 12.0f, 16.0f});
net.AddInputFromArray<DeviceType::CPU, float>("Bias", {4}, {.1f, .2f, .3f, .4f});
net.AddInputFromArray<D, float>("Bias", {4}, {.1f, .2f, .3f, .4f});
// Run
net.RunOp();
net.RunOp(D);
// Check
auto expected = CreateTensor<float>(
......@@ -42,22 +43,26 @@ TEST_F(DepthwiseConv2dOpTest, Simple_VALID) {
{196.1f, 252.1f, 216.2f, 280.2f, 272.3f, 344.3f, 296.4f, 376.4f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(DepthwiseConv2dOpTest, SimpleCPU) {
SimpleValidTest<DeviceType::CPU>();
}
TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
template <DeviceType D>
void TestNxNS12(const index_t height, const index_t width) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
// generate random input
index_t batch = 2 + rand() % 10;
index_t input_channels = 3 + rand() % 10;
index_t height = 107;
index_t width = 113;
index_t multiplier = 3 + rand() % 10;
index_t batch = 1;
index_t input_channels = 3;
index_t multiplier = 2;
// Construct graph
auto &net = test_net();
OpsTestNet net;
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input")
.Input("Filter")
......@@ -71,19 +76,18 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<DeviceType::CPU, float>("Filter",
{multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<DeviceType::CPU, float>("Bias", {multiplier * input_channels});
// run cpu
net.RunOp();
net.AddRandomInput<D, float>("Input", {batch, input_channels, height, width});
net.AddRandomInput<D, float>("Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, float>("Bias", {multiplier * input_channels});
// Run on device
net.RunOp(D);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run NEON
net.RunOp(DeviceType::NEON);
// Run on CPU
net.RunOp();
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-3);
};
......@@ -93,4 +97,31 @@ TEST_F(DepthwiseConv2dOpTest, ConvNxNS12) {
func(kernel_size, kernel_size, stride, stride, SAME);
}
}
}
TEST_F(DepthwiseConv2dOpTest, NeonSimpleNxNS12) {
TestNxNS12<DeviceType::NEON>(4, 4);
}
TEST_F(DepthwiseConv2dOpTest, OpenCLSimpleNxNS12) {
TestNxNS12<DeviceType::OPENCL>(4, 4);
}
TEST_F(DepthwiseConv2dOpTest, NeonAlignedNxNS12) {
TestNxNS12<DeviceType::NEON>(64, 64);
TestNxNS12<DeviceType::NEON>(128, 128);
}
TEST_F(DepthwiseConv2dOpTest, OpenCLAlignedNxNS12) {
TestNxNS12<DeviceType::OPENCL>(64, 64);
TestNxNS12<DeviceType::OPENCL>(128, 128);
}
TEST_F(DepthwiseConv2dOpTest, NeonUnalignedNxNS12) {
TestNxNS12<DeviceType::NEON>(107, 113);
}
TEST_F(DepthwiseConv2dOpTest, OpenCLUnalignedNxNS12) {
TestNxNS12<DeviceType::OPENCL>(107, 113);
}
......@@ -38,20 +38,22 @@ static void DepthwiseConv2d(int iters,
net.AddIntsArg("dilations", {1, 1});
// Add input data
net.AddRandomInput<DeviceType::CPU, float>("Input", {batch, channels, height, width});
net.AddRandomInput<DeviceType::CPU, float>("Filter",
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h, kernel_w});
net.AddRandomInput<DeviceType::CPU, float>("Bias", {output_channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
#define BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, STRIDE, P, OC, TYPE, \
......@@ -70,7 +72,8 @@ static void DepthwiseConv2d(int iters,
#define BM_DEPTHWISE_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON);
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, NEON); \
BM_DEPTHWISE_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL);
BM_DEPTHWISE_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_DEPTHWISE_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
......