Commit 29c3f0f7 authored by liuqi

Refactor conv registration and optimize the conv 3x3 kernel.

Parent 8499b852
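The commit hoists the stride/padding/dilation members shared by every Conv2dFunctor specialization into a common Conv2dFunctorBase, and turns the OpenCL functor into a type-templated specialization so float and half kernels can be registered side by side. A simplified sketch of the pattern follows; the enum stand-ins are mine, the struct names and members come from the diff below.

    // Stand-ins for the real mace types, just enough to show the shape of
    // the refactor.
    enum class DeviceType { CPU, NEON, OPENCL };
    enum class Padding { VALID, SAME };
    struct Tensor;

    struct Conv2dFunctorBase {
      Conv2dFunctorBase(const int *strides, const Padding &paddings,
                        const int *dilations)
          : strides_(strides), dilations_(dilations), paddings_(paddings) {}
      const int *strides_;    // [stride_h, stride_w]
      const int *dilations_;  // [dilation_h, dilation_w]
      Padding paddings_;
    };

    // Each device/type pair now only supplies operator(); the shared state
    // lives in the base.  The OPENCL variant keeps T as a free parameter so
    // the float and half kernels share one definition.
    template <DeviceType D, typename T>
    struct Conv2dFunctor : Conv2dFunctorBase {
      Conv2dFunctor(const int *strides, const Padding &paddings,
                    const int *dilations)
          : Conv2dFunctorBase(strides, paddings, dilations) {}
      void operator()(const Tensor *input, const Tensor *filter,
                      const Tensor *bias, Tensor *output);
    };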
...@@ -1098,7 +1098,7 @@ namespace half_float ...@@ -1098,7 +1098,7 @@ namespace half_float
/// Conversion constructor. /// Conversion constructor.
/// \param rhs float to convert /// \param rhs float to convert
explicit half(float rhs) : data_(detail::float2half<round_style>(rhs)) {} half(float rhs) : data_(detail::float2half<round_style>(rhs)) {}
/// Conversion to single-precision. /// Conversion to single-precision.
/// \return single precision value representing expression value /// \return single precision value representing expression value
......
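The only change to half.hpp is dropping explicit from the float constructor, so float values now convert to half_float::half implicitly. This is presumably needed because the templated test helpers below feed float literals into tensors of element type T; with T = half, copy-list-initialization cannot call an explicit constructor. A minimal example of my own (not from the diff):

    // The op tests build tensors with brace lists of float literals, e.g.
    // AddInputFromArray<D, half>("Bias", {1}, {0.1f}).  That pattern only
    // compiles once half(float) is an implicit (non-explicit) constructor.
    #include <vector>
    #include "half.hpp"
    using half_float::half;

    std::vector<half> bias = {0.1f, 0.2f};  // requires implicit float -> half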
...@@ -11,13 +11,23 @@ ...@@ -11,13 +11,23 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
struct Conv2dFunctorBase {
Conv2dFunctorBase(const int *strides,
const Padding &paddings,
const int *dilations)
: strides_(strides), dilations_(dilations), paddings_(paddings) {}
const int *strides_; // [stride_h, stride_w]
const int *dilations_; // [dilation_h, dilation_w]
Padding paddings_;
};
template<DeviceType D, typename T> template<DeviceType D, typename T>
struct Conv2dFunctor { struct Conv2dFunctor : Conv2dFunctorBase {
Conv2dFunctor() {}
Conv2dFunctor(const int *strides, Conv2dFunctor(const int *strides,
const Padding &paddings, const Padding &paddings,
const int *dilations) const int *dilations)
: strides_(strides), dilations_(dilations), paddings_(paddings) {} : Conv2dFunctorBase(strides, paddings, dilations) {}
void operator()(const Tensor *input, void operator()(const Tensor *input,
const Tensor *filter, const Tensor *filter,
...@@ -76,9 +86,10 @@ struct Conv2dFunctor { ...@@ -76,9 +86,10 @@ struct Conv2dFunctor {
for (int h = 0; h < height; ++h) { for (int h = 0; h < height; ++h) {
for (int w = 0; w < width; ++w) { for (int w = 0; w < width; ++w) {
for (int c = 0; c < channels; ++c) { for (int c = 0; c < channels; ++c) {
T bias_channel = bias_data ? bias_data[c] : 0; T bias_channel = 0.0f;
if (bias) bias_channel = bias_data[c];
*output_data = bias_channel; *output_data = bias_channel;
T sum = 0; T sum = 0.0f;
const T *filter_ptr = filter_data + c; const T *filter_ptr = filter_data + c;
for (int kh = 0; kh < kernel_h; ++kh) { for (int kh = 0; kh < kernel_h; ++kh) {
for (int kw = 0; kw < kernel_w; ++kw) { for (int kw = 0; kw < kernel_w; ++kw) {
...@@ -113,9 +124,6 @@ struct Conv2dFunctor { ...@@ -113,9 +124,6 @@ struct Conv2dFunctor {
} }
const int *strides_; // [stride_h, stride_w]
const int *dilations_; // [dilation_h, dilation_w]
Padding paddings_;
}; };
template<> template<>
...@@ -123,11 +131,19 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input, ...@@ -123,11 +131,19 @@ void Conv2dFunctor<DeviceType::NEON, float>::operator()(const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *bias, const Tensor *bias,
Tensor *output); Tensor *output);
template<>
void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input, template<typename T>
struct Conv2dFunctor<DeviceType::OPENCL, T> : Conv2dFunctorBase {
Conv2dFunctor(const int *strides,
const Padding &paddings,
const int *dilations)
: Conv2dFunctorBase(strides, paddings, dilations) {}
void operator()(const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *bias, const Tensor *bias,
Tensor *output); Tensor *output);
};
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
......
...@@ -19,86 +19,76 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -19,86 +19,76 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
const int out_hb = get_global_id(2); const int out_hb = get_global_id(2);
const int rounded_in_ch = in_ch_blks * 4; const int rounded_in_ch = in_ch_blks * 4;
DATA_TYPE4 out0 = 0; float4 out0 = 0;
DATA_TYPE4 out1 = 0; float4 out1 = 0;
DATA_TYPE4 out2 = 0; float4 out2 = 0;
DATA_TYPE4 out3 = 0; float4 out3 = 0;
DATA_TYPE4 out4 = 0; float4 out4 = 0;
const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; const sampler_t sampler = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
#ifdef BIAS #ifdef BIAS
out0 = out0 =
READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0)); convert_float4(READ_IMAGET(bias, sampler, (int2)(out_ch_blk, 0)));
out1 = out0; out1 = out0;
out2 = out0; out2 = out0;
out3 = out0; out3 = out0;
out4 = out0; out4 = out0;
#endif #endif
#define DEFINE_IN_WIDTH(i) \ #ifdef STRIDE_1
in_width##i[1] = in_width##i[0] + 1; \ int in_width0 = out_w_blk - padding_left;
in_width##i[2] = in_width##i[0] + 2; \ int in_width1 = in_width0 + out_w_blks;
in_width##i[0] = (in_width##i[0] < 0 || in_width##i[0] >= in_width) ? (INT_MIN) : in_width##i[0]; \ int in_width2 = in_width1 + out_w_blks;
in_width##i[1] = (in_width##i[1] < 0 || in_width##i[1] >= in_width) ? (INT_MIN) : in_width##i[1]; \ int in_width3 = in_width2 + out_w_blks;
in_width##i[2] = (in_width##i[2] < 0 || in_width##i[2] >= in_width) ? (INT_MIN) : in_width##i[2]; int in_width4 = in_width3 + out_w_blks;
const int height_idx = (out_hb % out_height) - padding_top;
int in_width0[3]; #else
int in_width1[3]; int in_width0 = out_w_blk * 2 - padding_left;
int in_width2[3]; int in_width1 = (out_w_blk + out_w_blks) * 2 - padding_left;
int in_width3[3]; int in_width2 = (out_w_blk + 2 * out_w_blks) * 2 - padding_left;
int in_width4[3]; int in_width3 = (out_w_blk + 3 * out_w_blks) * 2 - padding_left;
in_width0[0] = out_w_blk - padding_left; int in_width4 = (out_w_blk + 4 * out_w_blks) * 2 - padding_left;
in_width1[0] = in_width0[0] + out_w_blks; const int height_idx = (out_hb % out_height) * 2 - padding_top;
in_width2[0] = in_width1[0] + out_w_blks; #endif
in_width3[0] = in_width2[0] + out_w_blks;
in_width4[0] = in_width3[0] + out_w_blks; const int batch_idx = (out_hb / out_height) * in_height;
DEFINE_IN_WIDTH(0);
float4 in0, in1, in2, in3, in4;
DEFINE_IN_WIDTH(1); float4 weights0, weights1, weights2, weights3;
DEFINE_IN_WIDTH(2);
DEFINE_IN_WIDTH(3);
DEFINE_IN_WIDTH(4);
#undef DEFINE_IN_WIDTH
const int batch_idx = out_hb / out_height;
const int height_idx = out_hb % out_height;
int in_hb[3];
in_hb[0] = height_idx - padding_top;
in_hb[1] = in_hb[0] + 1;
in_hb[2] = in_hb[1] + 1;
// Judge the height border for padding input.
in_hb[0] = (in_hb[0] < 0 || in_hb[0] >= in_height) ? -1 : in_hb[0] + batch_idx * in_height;
in_hb[1] = (in_hb[1] < 0 || in_hb[1] >= in_height) ? -1 : in_hb[1] + batch_idx * in_height;
in_hb[2] = (in_hb[2] < 0 || in_hb[2] >= in_height) ? -1 : in_hb[2] + batch_idx * in_height;
const int input_image_width = in_ch_blks * in_width;
DATA_TYPE4 in0, in1, in2, in3, in4;
DATA_TYPE4 weights0, weights1, weights2, weights3;
int in_idx, hb_idx, width_idx, in_width_idx; int in_idx, hb_idx, width_idx, in_width_idx;
// Unrolling this loop hurts performance // Unrolling this loop hurts performance
for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) { for (short in_ch_blk = 0; in_ch_blk < in_ch_blks; ++in_ch_blk) {
for (short hb_idx = 0; hb_idx < 3; ++ hb_idx) { for (short hb_idx = 0; hb_idx < 3; ++hb_idx) {
for (short width_idx = 0; width_idx < 3; ++width_idx) { for (short width_idx = 0; width_idx < 3; ++width_idx) {
in_idx = in_ch_blk * in_width; in_idx = in_ch_blk * in_width;
// Judge the width border for padding input. int in_hb_value = height_idx + hb_idx;
in0 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width0[width_idx], in_hb[hb_idx])); in_hb_value = select(in_hb_value + batch_idx,
in1 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width1[width_idx], in_hb[hb_idx])); -1,
in2 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width2[width_idx], in_hb[hb_idx])); (in_hb_value < 0 || in_hb_value >= in_height));
in3 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width3[width_idx], in_hb[hb_idx]));
in4 = READ_IMAGET(input, sampler, (int2)(in_idx + in_width4[width_idx], in_hb[hb_idx])); int in_width_value;
#define READ_INPUT(i) \
in_width_value = in_width##i + width_idx; \
in_width_value = select(in_idx + in_width_value, \
-1, \
(in_width_value < 0 || in_width_value >= in_width)); \
in##i = convert_float4(READ_IMAGET(input, sampler, (int2)(in_width_value, in_hb_value)));
READ_INPUT(0);
READ_INPUT(1);
READ_INPUT(2);
READ_INPUT(3);
READ_INPUT(4);
#undef READ_INPUT
int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch; int filter_idx = (in_ch_blk << 2) + (hb_idx * 3 + width_idx) * rounded_in_ch;
weights0 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk)); weights0 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 0, out_ch_blk)));
weights1 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk)); weights1 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 1, out_ch_blk)));
weights2 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk)); weights2 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 2, out_ch_blk)));
weights3 = READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk)); weights3 = convert_float4(READ_IMAGET(filter, sampler, (int2)(filter_idx + 3, out_ch_blk)));
// Would prefetching into L2 improve performance? How to prefetch image data? // Would prefetching into L2 improve performance? How to prefetch image data?
...@@ -131,6 +121,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -131,6 +121,7 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
} }
} }
#ifdef TYPE_FLOAT
const int out_x_base = out_ch_blk * out_width; const int out_x_base = out_ch_blk * out_width;
int w = out_w_blk; int w = out_w_blk;
WRITE_IMAGET(output, WRITE_IMAGET(output,
...@@ -160,4 +151,36 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b] ...@@ -160,4 +151,36 @@ __kernel void conv_2d_3x3(__read_only image2d_t input, /* [c%4 * w * c/4, h * b]
WRITE_IMAGET(output, WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb), (int2)(out_x_base + w, out_hb),
out4); out4);
#else
const int out_x_base = out_ch_blk * out_width;
int w = out_w_blk;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
convert_half4(out0));
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
convert_half4(out1));
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
convert_half4(out2));
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
convert_half4(out3));
w += out_w_blks;
if (w >= out_width) return;
WRITE_IMAGET(output,
(int2)(out_x_base + w, out_hb),
convert_half4(out4));
#endif
} }
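The kernel rewrite replaces the per-row arrays of pre-clamped input coordinates with on-the-fly index computation: each coordinate is range-checked with select() and mapped to -1 when it falls outside the input, and since the sampler uses CLK_ADDRESS_CLAMP, out-of-range reads return the border color (all zeros for the RGBA images used here), so zero padding comes for free. Accumulation is also forced to float4 regardless of the storage type, and the result is converted back to half4 when TYPE_FLOAT is not defined. A scalar C++ sketch of the border logic, my own illustration rather than part of the kernel:

    // select(a, b, cond) in OpenCL C returns b when cond is non-zero, else a;
    // the kernel uses it to push out-of-range coordinates to -1.
    inline int BorderedCoord(int base, int coord, int limit) {
      return (coord < 0 || coord >= limit) ? -1 : base + coord;
    }
    // With CLK_ADDRESS_CLAMP, an image read at coordinate -1 returns zeros,
    // so the multiply-accumulate simply adds zero for padded positions.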
...@@ -24,8 +24,8 @@ extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter, ...@@ -24,8 +24,8 @@ extern void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, const int *padding, const Tensor *bias, const int *padding,
Tensor *output); Tensor *output);
template <> template<typename T>
void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input, void Conv2dFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
const Tensor *filter, const Tensor *filter,
const Tensor *bias, const Tensor *bias,
Tensor *output) { Tensor *output) {
...@@ -36,7 +36,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input, ...@@ -36,7 +36,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
static const Conv2dOpenclFunction selector[5][2] = { static const Conv2dOpenclFunction selector[5][2] = {
{Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2}, {Conv2dOpenclK1x1S1, Conv2dOpenclK1x1S2},
{nullptr, nullptr}, {nullptr, nullptr},
{Conv2dOpenclK3x3S1, nullptr}, {Conv2dOpenclK3x3S1, Conv2dOpenclK3x3S2},
{nullptr, nullptr}, {nullptr, nullptr},
{nullptr, nullptr}}; {nullptr, nullptr}};
...@@ -50,7 +50,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input, ...@@ -50,7 +50,7 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
<< " stride " << strides_[0] << "x" << strides_[1] << " stride " << strides_[0] << "x" << strides_[1]
<< " is not implemented yet, using slow version"; << " is not implemented yet, using slow version";
// TODO(heliangliang) The CPU/NEON kernel should map the buffer // TODO(heliangliang) The CPU/NEON kernel should map the buffer
Conv2dFunctor<DeviceType::CPU, float>(strides_, paddings_, dilations_)( Conv2dFunctor<DeviceType::CPU, T>(strides_, paddings_, dilations_)(
input, filter, bias, output); input, filter, bias, output);
return; return;
} }
...@@ -73,5 +73,8 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input, ...@@ -73,5 +73,8 @@ void Conv2dFunctor<DeviceType::OPENCL, float>::operator()(const Tensor *input,
conv2d_func(input, filter, bias, paddings.data(), output); conv2d_func(input, filter, bias, paddings.data(), output);
} }
template struct Conv2dFunctor<DeviceType::OPENCL, float>;
template struct Conv2dFunctor<DeviceType::OPENCL, half>;
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
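Because the OPENCL functor's operator() is now a template defined in this .cc file rather than a non-template specialization declared in the header, the file must instantiate it explicitly for every data type the ops can request; otherwise the linker would find no symbol when Conv2dOp<OPENCL, half> is registered. A stripped-down, self-contained sketch of the pattern with hypothetical names:

    #include <cstdio>

    template <typename T>
    struct Widget { void Run(); };   // declared, not defined, in the header

    // The definition lives in one .cc file; the explicit instantiations at
    // the bottom of that file (float and half in conv_2d_opencl.cc) emit the
    // symbols so other translation units can link against them.
    template <typename T>
    void Widget<T>::Run() { std::printf("%zu-byte element type\n", sizeof(T)); }

    template struct Widget<float>;
    template struct Widget<double>;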
...@@ -25,9 +25,10 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter, ...@@ -25,9 +25,10 @@ static void Conv2d3x3S12(const Tensor *input, const Tensor *filter,
const index_t width_blocks = RoundUpDiv<index_t, 5>(width); const index_t width_blocks = RoundUpDiv<index_t, 5>(width);
std::set<std::string> built_options; std::set<std::string> built_options;
built_options.emplace("-DDATA_TYPE=" + DataTypeToCLType(input->dtype())); built_options.emplace(input->dtype() == DT_FLOAT ? "-DTYPE_FLOAT" : "");
built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype())); built_options.emplace("-DCMD_DATA_TYPE=" + DataTypeToOPENCLCMDDataType(input->dtype()));
built_options.emplace(bias != nullptr ? "-DBIAS" : ""); built_options.emplace(bias != nullptr ? "-DBIAS" : "");
built_options.emplace(stride == 1 ? "-DSTRIDE_1" : "");
auto runtime = OpenCLRuntime::Get(); auto runtime = OpenCLRuntime::Get();
auto program = runtime->program(); auto program = runtime->program();
...@@ -68,6 +69,7 @@ void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter, ...@@ -68,6 +69,7 @@ void Conv2dOpenclK3x3S1(const Tensor *input, const Tensor *filter,
void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter, void Conv2dOpenclK3x3S2(const Tensor *input, const Tensor *filter,
const Tensor *bias, const int *padding, Tensor *output) { const Tensor *bias, const int *padding, Tensor *output) {
Conv2d3x3S12(input, filter, bias, 2, padding, output);
}; };
} // namespace kernels } // namespace kernels
......
...@@ -11,6 +11,11 @@ REGISTER_CPU_OPERATOR(OpKeyBuilder("Conv2D") ...@@ -11,6 +11,11 @@ REGISTER_CPU_OPERATOR(OpKeyBuilder("Conv2D")
.Build(), .Build(),
Conv2dOp<DeviceType::CPU, float>); Conv2dOp<DeviceType::CPU, float>);
REGISTER_CPU_OPERATOR(OpKeyBuilder("Conv2D")
.TypeConstraint<half>("T")
.Build(),
Conv2dOp<DeviceType::CPU, half>);
#if __ARM_NEON #if __ARM_NEON
REGISTER_NEON_OPERATOR(OpKeyBuilder("Conv2D") REGISTER_NEON_OPERATOR(OpKeyBuilder("Conv2D")
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
...@@ -23,4 +28,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("Conv2D") ...@@ -23,4 +28,9 @@ REGISTER_OPENCL_OPERATOR(OpKeyBuilder("Conv2D")
.Build(), .Build(),
Conv2dOp<DeviceType::OPENCL, float>); Conv2dOp<DeviceType::OPENCL, float>);
REGISTER_OPENCL_OPERATOR(OpKeyBuilder("Conv2D")
.TypeConstraint<half>("T")
.Build(),
Conv2dOp<DeviceType::OPENCL, half>);
} // namespace mace } // namespace mace
...@@ -33,9 +33,9 @@ static void Conv2d(int iters, ...@@ -33,9 +33,9 @@ static void Conv2d(int iters,
net.AddRandomInput<D, T>("Bias", {output_channels}); net.AddRandomInput<D, T>("Bias", {output_channels});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT); BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER); BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage") .Input("InputImage")
.Input("FilterImage") .Input("FilterImage")
...@@ -44,6 +44,7 @@ static void Conv2d(int iters, ...@@ -44,6 +44,7 @@ static void Conv2d(int iters,
.AddIntsArg("strides", {stride, stride}) .AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding) .AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -88,43 +89,9 @@ static void Conv2d(int iters, ...@@ -88,43 +89,9 @@ static void Conv2d(int iters,
BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE) BM_CONV_2D_##N##_##C##_##H##_##W##_K##KH##x##KW##S##STRIDE##_##P##_##OC##_##TYPE##_##DEVICE)
#define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \ #define BM_CONV_2D(N, C, H, W, KH, KW, S, P, OC, TYPE) \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, CPU); \
BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL); BM_CONV_2D_MACRO(N, C, H, W, KH, KW, S, P, OC, TYPE, OPENCL);
// ICNet
BM_CONV_2D(1, 512, 15, 15, 1, 1, 1, VALID, 1024, float);
BM_CONV_2D(1, 128, 60, 60, 3, 3, 1, VALID, 128, float);
// SNPE GPU ExecutionDuration = 448us, % ALU Utilization = 105
BM_CONV_2D(1, 64, 60, 60, 1, 1, 1, VALID, 128, float);
// SNPE GPU ExecutionDuration = 258us, % ALU Utilization = 108
BM_CONV_2D(1, 32, 60, 60, 1, 1, 1, VALID, 128, float);
// SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8 // SNPE GPU ExecutionDuration = 506us, % ALU Utilization = 106.8
BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, float); BM_CONV_2D(1, 32, 60, 60, 3, 3, 1, SAME, 32, half);
// Test RGB <-> YUV
BM_CONV_2D(1, 3, 2160, 1080, 1, 1, 1, VALID, 3, float);
BM_CONV_2D(1, 3, 480, 480, 1, 1, 1, VALID, 3, float);
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 1, 1, 1, VALID, 128, float); // Test bad alignments
BM_CONV_2D(1, 3, 512, 512, 1, 1, 1, VALID, 3, float);
BM_CONV_2D(1, 32, 112, 112, 1, 1, 1, VALID, 64, float);
BM_CONV_2D(1, 64, 56, 56, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 256, 28, 28, 1, 1, 1, VALID, 256, float);
BM_CONV_2D(1, 1024, 7, 7, 1, 1, 1, VALID, 1024, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 1, VALID, 3, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, VALID, 128, float);
BM_CONV_2D(1, 3, 512, 512, 3, 3, 2, VALID, 3, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 2, SAME, 128, float);
BM_CONV_2D(1, 64, 33, 31, 3, 3, 2, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 5, 5, 1, SAME, 128, float);
BM_CONV_2D(1, 64, 32, 31, 5, 5, 1, SAME, 128, float);
} // namespace mace } // namespace mace
...@@ -98,9 +98,9 @@ void TestNHWCSimple3x3VALID() { ...@@ -98,9 +98,9 @@ void TestNHWCSimple3x3VALID() {
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f}); net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT); BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER); BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage") .Input("InputImage")
.Input("FilterImage") .Input("FilterImage")
...@@ -109,12 +109,13 @@ void TestNHWCSimple3x3VALID() { ...@@ -109,12 +109,13 @@ void TestNHWCSimple3x3VALID() {
.AddIntsArg("strides", {1, 1}) .AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID) .AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(D); net.RunOp(D);
// Transfer output // Transfer output
ImageToBuffer<D>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT); ImageToBuffer<D, T>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -125,13 +126,14 @@ void TestNHWCSimple3x3VALID() { ...@@ -125,13 +126,14 @@ void TestNHWCSimple3x3VALID() {
.AddIntsArg("strides", {1, 1}) .AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID) .AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
} }
auto expected = CreateTensor<T>({1, 1, 1, 1}, {18.1f}); auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.1f});
ExpectTensorNear<T>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float, T>(*expected, *net.GetOutput("Output"), 0.01);
} }
template<DeviceType D, typename T> template<DeviceType D, typename T>
...@@ -149,9 +151,9 @@ void TestNHWCSimple3x3SAME() { ...@@ -149,9 +151,9 @@ void TestNHWCSimple3x3SAME() {
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f}); net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT); BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER); BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage") .Input("InputImage")
.Input("FilterImage") .Input("FilterImage")
...@@ -160,12 +162,13 @@ void TestNHWCSimple3x3SAME() { ...@@ -160,12 +162,13 @@ void TestNHWCSimple3x3SAME() {
.AddIntsArg("strides", {1, 1}) .AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
// Transfer output // Transfer output
ImageToBuffer<D>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT); ImageToBuffer<D, T>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -176,26 +179,31 @@ void TestNHWCSimple3x3SAME() { ...@@ -176,26 +179,31 @@ void TestNHWCSimple3x3SAME() {
.AddIntsArg("strides", {1, 1}) .AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
} }
auto expected = CreateTensor<T>( auto expected = CreateTensor<float>(
{1, 3, 3, 1}, {1, 3, 3, 1},
{8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f}); {8.1f, 12.1f, 8.1f, 12.1f, 18.1f, 12.1f, 8.1f, 12.1f, 8.1f});
ExpectTensorNear<T>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float, T>(*expected, *net.GetOutput("Output"), 0.01);
} }
TEST_F(Conv2dOpTest, CPUSimple) { TEST_F(Conv2dOpTest, CPUSimple) {
TestNHWCSimple3x3VALID<DeviceType::CPU, float>(); TestNHWCSimple3x3VALID<DeviceType::CPU, float>();
TestNHWCSimple3x3VALID<DeviceType::CPU, half>();
TestNHWCSimple3x3SAME<DeviceType::CPU, float>(); TestNHWCSimple3x3SAME<DeviceType::CPU, float>();
TestNHWCSimple3x3SAME<DeviceType::CPU, half>();
} }
TEST_F(Conv2dOpTest, OPENCLSimple) { TEST_F(Conv2dOpTest, OPENCLSimple) {
TestNHWCSimple3x3VALID<DeviceType::OPENCL, float>(); TestNHWCSimple3x3VALID<DeviceType::OPENCL, float>();
TestNHWCSimple3x3VALID<DeviceType::OPENCL, half>();
TestNHWCSimple3x3SAME<DeviceType::OPENCL, float>(); TestNHWCSimple3x3SAME<DeviceType::OPENCL, float>();
TestNHWCSimple3x3SAME<DeviceType::OPENCL, half>();
} }
template<DeviceType D> template<DeviceType D>
...@@ -233,22 +241,22 @@ TEST_F(Conv2dOpTest, NEONWithouBias) { ...@@ -233,22 +241,22 @@ TEST_F(Conv2dOpTest, NEONWithouBias) {
TestSimple3x3WithoutBias<DeviceType::NEON>(); TestSimple3x3WithoutBias<DeviceType::NEON>();
} }
template<DeviceType D> template<DeviceType D, typename T>
void TestNHWCSimple3x3WithoutBias() { void TestNHWCSimple3x3WithoutBias() {
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddInputFromArray<D, float>( net.AddInputFromArray<D, T>(
"Input", {1, 3, 3, 2}, "Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>( net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 1}, "Filter", {3, 3, 2, 1},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT); BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER); BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage") .Input("InputImage")
...@@ -257,11 +265,12 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -257,11 +265,12 @@ void TestNHWCSimple3x3WithoutBias() {
.AddIntsArg("strides", {1, 1}) .AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID) .AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
// Transfer output // Transfer output
ImageToBuffer<D>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT); ImageToBuffer<D, T>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else { } else {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -270,6 +279,7 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -270,6 +279,7 @@ void TestNHWCSimple3x3WithoutBias() {
.AddIntsArg("strides", {1, 1}) .AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID) .AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -279,15 +289,17 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -279,15 +289,17 @@ void TestNHWCSimple3x3WithoutBias() {
// Check // Check
auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.0f}); auto expected = CreateTensor<float>({1, 1, 1, 1}, {18.0f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float, T>(*expected, *net.GetOutput("Output"), 0.01);
} }
TEST_F(Conv2dOpTest, CPUWithoutBias) { TEST_F(Conv2dOpTest, CPUWithoutBias) {
TestNHWCSimple3x3WithoutBias<DeviceType::CPU>(); TestNHWCSimple3x3WithoutBias<DeviceType::CPU, float>();
TestNHWCSimple3x3WithoutBias<DeviceType::CPU, half>();
} }
TEST_F(Conv2dOpTest, OPENCLWithoutBias) { TEST_F(Conv2dOpTest, OPENCLWithoutBias) {
TestNHWCSimple3x3WithoutBias<DeviceType::OPENCL>(); TestNHWCSimple3x3WithoutBias<DeviceType::OPENCL, float>();
TestNHWCSimple3x3WithoutBias<DeviceType::OPENCL, half>();
} }
template<DeviceType D> template<DeviceType D>
...@@ -333,27 +345,27 @@ TEST_F(Conv2dOpTest, NEONCombined) { ...@@ -333,27 +345,27 @@ TEST_F(Conv2dOpTest, NEONCombined) {
TestCombined3x3<DeviceType::NEON>(); TestCombined3x3<DeviceType::NEON>();
} }
template<DeviceType D> template<DeviceType D, typename T>
static void TestNHWCCombined3x3() { static void TestNHWCCombined3x3() {
// Construct graph // Construct graph
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
net.AddInputFromArray<D, float>( net.AddInputFromArray<D, T>(
"Input", {1, 5, 5, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, "Input", {1, 5, 5, 2}, {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}); 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>( net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 2}, "Filter", {3, 3, 2, 2},
{1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, {1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f,
1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f}); 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f, 1.0f, 0.5f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::OPENCL) { if (D == DeviceType::OPENCL) {
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT); BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER); BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputImage") .Input("InputImage")
...@@ -363,11 +375,12 @@ static void TestNHWCCombined3x3() { ...@@ -363,11 +375,12 @@ static void TestNHWCCombined3x3() {
.AddIntsArg("strides", {2, 2}) .AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
ImageToBuffer<D>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT); ImageToBuffer<D, T>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else { } else {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -377,6 +390,7 @@ static void TestNHWCCombined3x3() { ...@@ -377,6 +390,7 @@ static void TestNHWCCombined3x3() {
.AddIntsArg("strides", {2, 2}) .AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
...@@ -388,17 +402,33 @@ static void TestNHWCCombined3x3() { ...@@ -388,17 +402,33 @@ static void TestNHWCCombined3x3() {
{1, 3, 3, 2}, {8.1f, 4.2f, 12.1f, 6.2f, 8.1f, 4.2f, {1, 3, 3, 2}, {8.1f, 4.2f, 12.1f, 6.2f, 8.1f, 4.2f,
12.1f, 6.2f, 18.1f, 9.2f, 12.1f, 6.2f, 12.1f, 6.2f, 18.1f, 9.2f, 12.1f, 6.2f,
8.1f, 4.2f, 12.1f, 6.2f, 8.1f, 4.2f}); 8.1f, 4.2f, 12.1f, 6.2f, 8.1f, 4.2f});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 0.001); ExpectTensorNear<float, T>(*expected, *net.GetOutput("Output"), 0.01);
}
TEST_F(Conv2dOpTest, CPUStride2) {
TestNHWCCombined3x3<DeviceType::CPU, float>();
TestNHWCCombined3x3<DeviceType::CPU, half>();
} }
TEST_F(Conv2dOpTest, CPUCombined) { TEST_F(Conv2dOpTest, OPENCLStride2) {
TestNHWCCombined3x3<DeviceType::CPU>(); TestNHWCCombined3x3<DeviceType::OPENCL, float>();
TestNHWCCombined3x3<DeviceType::OPENCL, half>();
} }
template<DeviceType D> template<DeviceType D>
void TestConv1x1() { void TestConv1x1() {
// Construct graph
OpsTestNet net; OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddInputFromArray<D, float>( net.AddInputFromArray<D, float>(
...@@ -415,39 +445,8 @@ void TestConv1x1() { ...@@ -415,39 +445,8 @@ void TestConv1x1() {
{1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}); {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}); net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f});
// Construct graph // Run
if (D == DeviceType::OPENCL) {
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
// Transfer output
ImageToBuffer<D>(net, "OutputImage", "Output", kernels::BufferType::IN_OUT);
} else {
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.RunOp(D); net.RunOp(D);
}
// Check // Check
auto expected = CreateTensor<float>( auto expected = CreateTensor<float>(
...@@ -466,11 +465,11 @@ TEST_F(Conv2dOpTest, CPUConv1x1) { ...@@ -466,11 +465,11 @@ TEST_F(Conv2dOpTest, CPUConv1x1) {
TestConv1x1<DeviceType::CPU>(); TestConv1x1<DeviceType::CPU>();
} }
TEST_F(Conv2dOpTest, OPENCLConv1x1) { //TEST_F(Conv2dOpTest, OPENCLConv1x1) {
TestConv1x1<DeviceType::OPENCL>(); // TestConv1x1<DeviceType::OPENCL>();
} //}
template<DeviceType D> template<DeviceType D, typename T>
static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
testing::internal::LogToStderr(); testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w, auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
...@@ -478,7 +477,6 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -478,7 +477,6 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
srand(time(NULL)); srand(time(NULL));
// generate random input // generate random input
// TODO test all sizes
index_t batch = 3 + (rand() % 10); index_t batch = 3 + (rand() % 10);
index_t height = shape[0]; index_t height = shape[0];
index_t width = shape[1]; index_t width = shape[1];
...@@ -494,13 +492,14 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -494,13 +492,14 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, input_channels}); net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, float>( net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels}); "Filter", {kernel_h, kernel_w, input_channels, output_channels});
net.AddRandomInput<D, float>("Bias", {output_channels}); net.AddRandomInput<D, T>("Bias", {output_channels});
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
...@@ -509,9 +508,9 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -509,9 +508,9 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
// run on gpu // run on gpu
BufferToImage<D>(net, "Input", "InputImage", kernels::BufferType::IN_OUT); BufferToImage<D, T>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D>(net, "Filter", "FilterImage", kernels::BufferType::FILTER); BufferToImage<D, T>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); BufferToImage<D, T>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage") .Input("InputImage")
...@@ -521,16 +520,17 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -521,16 +520,17 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
.AddIntsArg("strides", {stride_h, stride_w}) .AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type) .AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on device // Run on device
net.RunOp(D); net.RunOp(D);
ImageToBuffer<D>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT); ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001); ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"), 0.001);
}; };
for (int kernel_size : {1, 3}) { for (int kernel_size : {1, 3}) {
for (int stride : {1}) { for (int stride : {1, 2}) {
func(kernel_size, kernel_size, stride, stride, VALID); func(kernel_size, kernel_size, stride, stride, VALID);
func(kernel_size, kernel_size, stride, stride, SAME); func(kernel_size, kernel_size, stride, stride, SAME);
} }
...@@ -538,9 +538,97 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) { ...@@ -538,9 +538,97 @@ static void TestComplexConvNxNS12(const std::vector<index_t> &shape) {
} }
TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) { TEST_F(Conv2dOpTest, OPENCLAlignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL>({32, 32, 64, 128}); TestComplexConvNxNS12<DeviceType::OPENCL, float>({32, 32, 64, 128});
} }
TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) { TEST_F(Conv2dOpTest, OPENCLUnalignedConvNxNS12) {
TestComplexConvNxNS12<DeviceType::OPENCL>({107, 113, 5, 7}); TestComplexConvNxNS12<DeviceType::OPENCL, float>({107, 113, 5, 7});
} }
template<DeviceType D, typename T>
static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type) {
srand(time(NULL));
// generate random input
index_t batch = 3 + (rand() % 10);
index_t height = shape[0];
index_t width = shape[1];
index_t input_channels = shape[2] + (rand() % 10);
index_t output_channels = shape[3] + (rand() % 10);
// Construct graph
OpsTestNet net;
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input")
.Input("Filter")
.Input("Bias")
.Output("Output")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
std::vector<float> float_input_data;
GenerateRandomRealTypeData({batch, height, width, input_channels}, float_input_data);
std::vector<float> float_filter_data;
GenerateRandomRealTypeData({kernel_h, kernel_w, input_channels, output_channels}, float_filter_data);
std::vector<float> float_bias_data;
GenerateRandomRealTypeData({output_channels}, float_bias_data);
// Add input data
net.AddInputFromArray<D, float>("Input", {batch, height, width, input_channels}, float_input_data);
net.AddInputFromArray<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, output_channels}, float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
// run on cpu
net.RunOp();
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
std::vector<half> input_data(float_input_data.begin(), float_input_data.end());
std::vector<half> filter_data(float_filter_data.begin(), float_filter_data.end());
std::vector<half> bias_data(float_bias_data.begin(), float_bias_data.end());
net.AddInputFromArray<D, T>("InputHalf", {batch, height, width, input_channels}, input_data);
net.AddInputFromArray<D, T>(
"FilterHalf", {kernel_h, kernel_w, input_channels, output_channels}, filter_data);
net.AddInputFromArray<D, T>("BiasHalf", {output_channels}, bias_data);
// run on gpu
BufferToImage<D, T>(net, "InputHalf", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D, T>(net, "FilterHalf", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D, T>(net, "BiasHalf", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(D);
ImageToBuffer<D, T>(net, "OutputImage", "OPENCLOutput", kernels::BufferType::IN_OUT);
ExpectTensorNear<float, T>(expected, *net.GetOutput("OPENCLOutput"), 1.0);
};
for (int kernel_size : {3}) {
for (int stride : {1}) {
func(kernel_size, kernel_size, stride, stride, VALID);
}
}
}
// TODO: support half input & float computation
//TEST_F(Conv2dOpTest, OPENCLHalfAlignedConvNxNS12) {
// TestHalfComplexConvNxNS12<DeviceType::OPENCL, half>({32, 32, 64, 128});
//}
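The half paths above compare GPU output against a float reference with noticeably looser tolerances (0.01 for the simple fixed-value cases, 1.0 for the random NxN case) than the 0.001 used for float-only runs. That is forced by fp16's 10-bit mantissa: even the expected constants are not exactly representable. A quick back-of-envelope check of my own, not taken from the tests:

    #include <cstdio>

    int main() {
      // fp16 values in [8, 16) are spaced 2^-7 = 0.0078125 apart, so the
      // closest half to the expected output 12.1 is 1549 * 0.0078125.
      const double expected = 12.1;
      const double nearest_half = 1549 * 0.0078125;   // 12.1015625
      // Prints ~0.0016 -- already larger than the 0.001 tolerance used for
      // the float tests, hence the relaxed 0.01 (and 1.0 for accumulated
      // random convolutions).
      std::printf("fp16 representation error: %g\n", nearest_half - expected);
      return 0;
    }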
...@@ -210,13 +210,17 @@ void GenerateRandomRealTypeData(const std::vector<index_t> &shape, ...@@ -210,13 +210,17 @@ void GenerateRandomRealTypeData(const std::vector<index_t> &shape,
std::vector<T> &res) { std::vector<T> &res) {
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); std::mt19937 gen(rd());
std::normal_distribution<T> nd(0, 1); std::normal_distribution<float> nd(0, 1);
index_t size = std::accumulate(shape.begin(), shape.end(), 1, index_t size = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<index_t>()); std::multiplies<index_t>());
res.resize(size); res.resize(size);
if (DataTypeToEnum<T>::value == DT_HALF) {
std::generate(res.begin(), res.end(), [&gen, &nd] { return half_float::half_cast<half>(nd(gen)); });
} else {
std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); }); std::generate(res.begin(), res.end(), [&gen, &nd] { return nd(gen); });
}
} }
template <typename T> template <typename T>
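The random-data generator now always draws from std::normal_distribution<float> and converts with half_float::half_cast when the element type is DT_HALF, since the standard distribution templates may only be instantiated with float, double, or long double. A standalone version of the same idea, with a hypothetical helper name:

    #include <cstddef>
    #include <random>
    #include <vector>
    #include "half.hpp"

    std::vector<half_float::half> RandomHalfVector(std::size_t n) {
      std::mt19937 gen{std::random_device{}()};
      std::normal_distribution<float> nd(0.0f, 1.0f);  // half is not allowed here
      std::vector<half_float::half> v(n);
      for (auto &x : v) {
        // Explicit rounding conversion provided by the half.hpp library.
        x = half_float::half_cast<half_float::half>(nd(gen));
      }
      return v;
    }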
...@@ -290,39 +294,40 @@ inline void ExpectEqual<double>(const double &a, const double &b) { ...@@ -290,39 +294,40 @@ inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b); EXPECT_DOUBLE_EQ(a, b);
} }
inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) { inline void AssertSameDims(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), y.dtype());
ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs " ASSERT_TRUE(IsSameSize(x, y)) << "x.shape [" << ShapeToString(x) << "] vs "
<< "y.shape [ " << ShapeToString(y) << "]"; << "y.shape [ " << ShapeToString(y) << "]";
} }
template <typename T, bool is_fp = is_floating_point_type<T>::value> template <typename EXP_TYPE, typename RES_TYPE, bool is_fp = is_floating_point_type<EXP_TYPE>::value>
struct Expector; struct Expector;
// Partial specialization for float and double. // Partial specialization for float and double.
template <typename T> template <typename EXP_TYPE, typename RES_TYPE>
struct Expector<T, true> { struct Expector<EXP_TYPE, RES_TYPE, true> {
static void Equal(const T &a, const T &b) { ExpectEqual(a, b); } static void Equal(const EXP_TYPE &a, const RES_TYPE &b) { ExpectEqual(a, b); }
static void Equal(const Tensor &x, const Tensor &y) { static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
AssertSameTypeDims(x, y); ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x); Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y); Tensor::MappingGuard y_mapper(&y);
auto a = x.data<T>(); auto a = x.data<EXP_TYPE>();
auto b = y.data<T>(); auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) { for (int i = 0; i < x.size(); ++i) {
ExpectEqual(a(i), b(i)); ExpectEqual(a(i), b(i));
} }
} }
static void Near(const Tensor &x, const Tensor &y, const double abs_err) { static void Near(const Tensor &x, const Tensor &y, const double abs_err) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v()); ASSERT_EQ(x.dtype(), DataTypeToEnum<EXP_TYPE>::v());
AssertSameTypeDims(x, y); ASSERT_EQ(y.dtype(), DataTypeToEnum<RES_TYPE>::v());
AssertSameDims(x, y);
Tensor::MappingGuard x_mapper(&x); Tensor::MappingGuard x_mapper(&x);
Tensor::MappingGuard y_mapper(&y); Tensor::MappingGuard y_mapper(&y);
auto a = x.data<T>(); auto a = x.data<EXP_TYPE>();
auto b = y.data<T>(); auto b = y.data<RES_TYPE>();
for (int i = 0; i < x.size(); ++i) { for (int i = 0; i < x.size(); ++i) {
EXPECT_NEAR(a[i], b[i], abs_err) << "a = " << a << " b = " << b EXPECT_NEAR(a[i], b[i], abs_err) << "a = " << a << " b = " << b
<< " index = " << i; << " index = " << i;
...@@ -335,10 +340,25 @@ template <typename T> ...@@ -335,10 +340,25 @@ template <typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) { void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
static_assert(is_floating_point_type<T>::value, static_assert(is_floating_point_type<T>::value,
"T is not a floating point type"); "T is not a floating point type");
Expector<T>::Near(x, y, abs_err); Expector<T, T>::Near(x, y, abs_err);
}
template <typename EXP_TYPE, typename RES_TYPE>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
static_assert(is_floating_point_type<EXP_TYPE>::value
&& is_floating_point_type<RES_TYPE>::value,
"T is not a floating point type");
Expector<EXP_TYPE, RES_TYPE>::Near(x, y, abs_err);
}
template <typename T>
std::string ToString(const T &input) {
std::stringstream ss;
ss << input;
return ss.str();
} }
template <DeviceType D> template <DeviceType D, typename T>
void BufferToImage(OpsTestNet &net, void BufferToImage(OpsTestNet &net,
const std::string &input_name, const std::string &input_name,
const std::string &output_name, const std::string &output_name,
...@@ -347,6 +367,7 @@ void BufferToImage(OpsTestNet &net, ...@@ -347,6 +367,7 @@ void BufferToImage(OpsTestNet &net,
.Input(input_name) .Input(input_name)
.Output(output_name) .Output(output_name)
.AddIntArg("buffer_type", type) .AddIntArg("buffer_type", type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
...@@ -355,7 +376,7 @@ void BufferToImage(OpsTestNet &net, ...@@ -355,7 +376,7 @@ void BufferToImage(OpsTestNet &net,
net.Sync(); net.Sync();
} }
template <DeviceType D> template <DeviceType D, typename T>
void ImageToBuffer(OpsTestNet &net, void ImageToBuffer(OpsTestNet &net,
const std::string &input_name, const std::string &input_name,
const std::string &output_name, const std::string &output_name,
...@@ -364,6 +385,7 @@ void ImageToBuffer(OpsTestNet &net, ...@@ -364,6 +385,7 @@ void ImageToBuffer(OpsTestNet &net,
.Input(input_name) .Input(input_name)
.Output(output_name) .Output(output_name)
.AddIntArg("buffer_type", type) .AddIntArg("buffer_type", type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
......