You need to sign in or sign up before continuing.
提交 f147aa67 编写于 作者: 李寅

Merge branch 'gpu-oihw' into 'master'

Change the filter format from HWOI to OIHW.

See merge request !471
......@@ -40,13 +40,19 @@ SerialNet::SerialNet(const std::shared_ptr<const OperatorRegistry> op_registry,
MACE_LATENCY_LOGGER(1, "Constructing SerialNet ", net_def->name());
for (int idx = 0; idx < net_def->op_size(); ++idx) {
const auto &operator_def = net_def->op(idx);
VLOG(3) << "Creating operator " << operator_def.name() << "("
<< operator_def.type() << ")";
OperatorDef temp_def(operator_def);
std::unique_ptr<OperatorBase> op(
op_registry->CreateOperator(temp_def, ws, type, mode));
if (op) {
operators_.emplace_back(std::move(op));
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
operator_def, "device", -1);
if (op_device == type) {
VLOG(3) << "Creating operator " << operator_def.name() << "("
<< operator_def.type() << ")";
OperatorDef temp_def(operator_def);
std::unique_ptr<OperatorBase> op(
op_registry->CreateOperator(temp_def, ws, type, mode));
if (op) {
operators_.emplace_back(std::move(op));
}
}
}
}
......
......@@ -136,7 +136,11 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
// As DSP may have different data output type for each op,
// we stick to the same concept.
for (auto &op : net_def.op()) {
if (!op.mem_id().empty()) {
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1);
if (op_device == device_type && !op.mem_id().empty()) {
const DataType op_dtype = static_cast<DataType>(
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT)));
......@@ -150,20 +154,29 @@ void Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
MACE_CHECK(dtype != DataType::DT_INVALID, "data type is invalid.");
for (auto &mem_block : net_def.mem_arena().mem_block()) {
if (device_type == DeviceType::GPU) {
std::unique_ptr<BufferBase> image_buf(
new Image({mem_block.x(), mem_block.y()}, dtype));
preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(image_buf));
// TODO(liuqi): refactor based on PB
if (mem_block.mem_id() >= 20000) {
std::unique_ptr<BufferBase> image_buf(
new Image({mem_block.x(), mem_block.y()}, dtype));
preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(image_buf));
}
} else {
std::unique_ptr<BufferBase> tensor_buf(
new Buffer(GetDeviceAllocator(device_type), mem_block.x()));
preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(tensor_buf));
if (mem_block.mem_id() < 20000) {
std::unique_ptr<BufferBase> tensor_buf(
new Buffer(GetDeviceAllocator(device_type), mem_block.x()));
preallocated_allocator_.SetBuffer(mem_block.mem_id(),
std::move(tensor_buf));
}
}
}
VLOG(3) << "Preallocate buffer to tensors";
for (auto &op : net_def.op()) {
if (!op.mem_id().empty()) {
// TODO(liuqi): refactor based on PB
const int op_device =
ArgumentHelper::GetSingleArgument<OperatorDef, int>(
op, "device", -1);
if (op_device == device_type && !op.mem_id().empty()) {
auto mem_ids = op.mem_id();
int count = mem_ids.size();
for (int i = 0; i < count; ++i) {
......
......@@ -25,17 +25,15 @@ namespace mace {
namespace kernels {
struct BufferToImageFunctorBase {
explicit BufferToImageFunctorBase(bool i2b)
: i2b_(i2b), kernel_error_(nullptr) {}
bool i2b_;
BufferToImageFunctorBase()
: kernel_error_(nullptr) {}
std::unique_ptr<BufferBase> kernel_error_;
};
template <DeviceType D, typename T>
struct BufferToImageFunctor : BufferToImageFunctorBase {
explicit BufferToImageFunctor(bool i2b = false)
: BufferToImageFunctorBase(i2b) {}
void operator()(Tensor *input,
BufferToImageFunctor() {}
void operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future) {
......@@ -49,9 +47,8 @@ struct BufferToImageFunctor : BufferToImageFunctorBase {
template <typename T>
struct BufferToImageFunctor<DeviceType::GPU, T> : BufferToImageFunctorBase {
explicit BufferToImageFunctor(bool i2b = false)
: BufferToImageFunctorBase(i2b) {}
void operator()(Tensor *input,
BufferToImageFunctor() {}
void operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future);
......
......@@ -21,12 +21,12 @@ namespace mace {
namespace kernels {
void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size) {
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
index_t *output_shape,
int *padding_size) {
MACE_CHECK(dilations[0] > 0 && dilations[1] > 0,
"Invalid dilations, must >= 1");
MACE_CHECK((dilations[0] == 1 || strides[0] == 1) &&
......@@ -85,7 +85,7 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
}
void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const index_t *filter_shape, // OIHW
const int *dilations,
const int *strides,
Padding padding,
......@@ -108,9 +108,9 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
padding_size[1] = 0;
index_t output_height = 0, output_width = 0;
index_t kernel_height = filter_shape[0];
index_t kernel_width = filter_shape[1];
index_t output_channels = filter_shape[2];
index_t output_channels = filter_shape[0];
index_t kernel_height = filter_shape[2];
index_t kernel_width = filter_shape[3];
index_t k_extent_height = (kernel_height - 1) * dilations[0] + 1;
index_t k_extent_width = (kernel_width - 1) * dilations[1] + 1;
......@@ -151,7 +151,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const index_t *filter_shape, // OIHW
const int *padding_size,
const int *dilations,
const int *strides,
......@@ -168,28 +168,28 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
output_shape[0] = input_shape[0];
if (round_type == FLOOR) {
output_shape[1] = static_cast<index_t>(
std::floor(1.0 * (input_shape[1] + padding_size[0] - filter_shape[0] -
(filter_shape[0] - 1) * (dilations[0] - 1)) /
std::floor(1.0 * (input_shape[1] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[2] = static_cast<index_t>(
std::floor(1.0 * (input_shape[2] + padding_size[1] - filter_shape[1] -
(filter_shape[1] - 1) * (dilations[1] - 1)) /
std::floor(1.0 * (input_shape[2] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
} else {
output_shape[1] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[1] + padding_size[0] - filter_shape[0] -
(filter_shape[0] - 1) * (dilations[0] - 1)) /
std::ceil(1.0 * (input_shape[1] + padding_size[0] - filter_shape[2] -
(filter_shape[2] - 1) * (dilations[0] - 1)) /
strides[0]) +
1);
output_shape[2] = static_cast<index_t>(
std::ceil(1.0 * (input_shape[2] + padding_size[1] - filter_shape[1] -
(filter_shape[1] - 1) * (dilations[1] - 1)) /
std::ceil(1.0 * (input_shape[2] + padding_size[1] - filter_shape[3] -
(filter_shape[3] - 1) * (dilations[1] - 1)) /
strides[1]) +
1);
}
output_shape[3] = filter_shape[2];
output_shape[3] = filter_shape[0];
}
void CalcNCHWOutputSize(const index_t *input_shape, // NCHW
......
......@@ -49,7 +49,7 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape,
int *padding_size);
void CalcOutputSize(const index_t *input_shape, // NHWC
const index_t *filter_shape, // HWOI
const index_t *filter_shape, // OIHW
const int *padding_size,
const int *dilations,
const int *strides,
......
......@@ -117,15 +117,14 @@ struct Deconv2dFunctorBase {
const int *strides,
index_t *output_shape,
const int *padding_size,
const bool isNCHW = false,
const bool isOIHW = false) {
const bool isNCHW = false) {
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
MACE_CHECK_NOTNULL(input_shape);
MACE_CHECK_NOTNULL(filter_shape);
MACE_CHECK_NOTNULL(strides);
const index_t output_channel = isOIHW ? filter_shape[0] : filter_shape[2];
const index_t output_channel = filter_shape[0];
const index_t in_height = isNCHW ? input_shape[2] : input_shape[1];
const index_t in_width = isNCHW ? input_shape[3] : input_shape[2];
......@@ -135,8 +134,8 @@ struct Deconv2dFunctorBase {
const index_t extended_input_width =
(in_width - 1) * strides[1] + 1 + padding_size[1];
const index_t filter_h = isOIHW ? filter_shape[2] : filter_shape[0];
const index_t filter_w = isOIHW ? filter_shape[3] : filter_shape[1];
const index_t filter_h = filter_shape[2];
const index_t filter_w = filter_shape[3];
index_t out_height = extended_input_height - filter_h + 1;
index_t out_width = extended_input_width - filter_w + 1;
......@@ -160,8 +159,7 @@ struct Deconv2dFunctorBase {
Padding padding,
const index_t *output_shape,
int *padding_size,
const bool isNCHW = false,
const bool isOIHW = false) {
const bool isNCHW = false) {
MACE_CHECK_NOTNULL(output_shape);
MACE_CHECK_NOTNULL(padding_size);
MACE_CHECK_NOTNULL(input_shape);
......@@ -177,8 +175,8 @@ struct Deconv2dFunctorBase {
const index_t extended_input_height = (in_height - 1) * strides[0] + 1;
const index_t extended_input_width = (in_width - 1) * strides[1] + 1;
const index_t filter_h = isOIHW ? filter_shape[2] : filter_shape[0];
const index_t filter_w = isOIHW ? filter_shape[3] : filter_shape[1];
const index_t filter_h = filter_shape[2];
const index_t filter_w = filter_shape[3];
index_t expected_input_height = 0, expected_input_width = 0;
......@@ -259,7 +257,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
filter->shape().data(),
strides_, padding_type_,
output_shape.data(),
paddings_.data(), true, true);
paddings_.data(), true);
output->Resize(output_shape);
} else {
output_shape_.clear();
......@@ -268,7 +266,7 @@ struct Deconv2dFunctor : Deconv2dFunctorBase {
filter->shape().data(),
strides_,
output_shape_.data(),
paddings_.data(), true, true);
paddings_.data(), true);
output->Resize(output_shape_);
}
index_t batch = output->dim(0);
......
......@@ -32,14 +32,11 @@ namespace mace {
namespace kernels {
struct FullyConnectedBase {
FullyConnectedBase(const int /*BufferType*/ weight_type,
const ActivationType activation,
FullyConnectedBase(const ActivationType activation,
const float relux_max_limit)
: weight_type_(weight_type),
activation_(activation),
: activation_(activation),
relux_max_limit_(relux_max_limit) {}
const int weight_type_;
const ActivationType activation_;
const float relux_max_limit_;
};
......@@ -49,10 +46,9 @@ struct FullyConnectedFunctor;
template <>
struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
FullyConnectedFunctor(const int /*BufferType*/ weight_type,
const ActivationType activation,
FullyConnectedFunctor(const ActivationType activation,
const float relux_max_limit)
: FullyConnectedBase(weight_type, activation, relux_max_limit) {}
: FullyConnectedBase(activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
......@@ -63,7 +59,7 @@ struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
std::vector<index_t> output_shape = {input->dim(0), weight->dim(0), 1, 1};
output->Resize(output_shape);
const index_t N = output->dim(0);
const index_t input_size = weight->dim(1);
const index_t input_size = weight->dim(1) * weight->dim(2) * weight->dim(3);
const index_t output_size = weight->dim(0);
Tensor::MappingGuard guard_input(input);
......@@ -90,10 +86,9 @@ struct FullyConnectedFunctor<DeviceType::CPU, float>: FullyConnectedBase {
#ifdef MACE_ENABLE_OPENCL
template <typename T>
struct FullyConnectedFunctor<DeviceType::GPU, T> : FullyConnectedBase {
FullyConnectedFunctor(const int /*BufferType*/ weight_type,
const ActivationType activation,
FullyConnectedFunctor(const ActivationType activation,
const float relux_max_limit)
: FullyConnectedBase(weight_type, activation, relux_max_limit) {}
: FullyConnectedBase(activation, relux_max_limit) {}
void operator()(const Tensor *input,
const Tensor *weight,
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_IMAGE_TO_BUFFER_H_
#define MACE_KERNELS_IMAGE_TO_BUFFER_H_
#include <memory>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/opencl/helper.h"
namespace mace {
namespace kernels {
// State shared by every ImageToBufferFunctor specialization.
struct ImageToBufferFunctorBase {
  // Owned buffer handle; empty (null) right after construction.
  // NOTE(review): presumably the buffer through which OpenCL kernels report
  // errors (cf. KERNEL_ERROR_PARAMS in the .cl sources) -- confirm against
  // the implementation file.
  std::unique_ptr<BufferBase> kernel_error_;

  // Value-initializes the handle, which leaves the unique_ptr null.
  ImageToBufferFunctorBase() : kernel_error_() {}
};
// Primary ImageToBufferFunctor template, selected for any DeviceType D that
// has no dedicated specialization. Its call operator only expands
// MACE_NOT_IMPLEMENTED, so no conversion is performed here.
template <DeviceType D, typename T>
struct ImageToBufferFunctor : ImageToBufferFunctorBase {
ImageToBufferFunctor() {}
// Same parameter list as the GPU specialization; none of the arguments is
// read by this fallback.
// NOTE(review): MACE_NOT_IMPLEMENTED presumably fails a runtime check --
// confirm against its definition.
void operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future) {
MACE_NOT_IMPLEMENTED;
}
};
// GPU specialization of ImageToBufferFunctor. Only the call operator's
// declaration lives in this header; its definition is not part of this file.
template <typename T>
struct ImageToBufferFunctor<DeviceType::GPU, T> : ImageToBufferFunctorBase {
ImageToBufferFunctor() {}
// NOTE(review): going by the names, `input` is presumably the image-backed
// tensor to read, `type` the BufferType layout to decode, `output` the
// destination buffer tensor and `future` a completion handle -- confirm
// against the OpenCL implementation.
void operator()(const Tensor *input,
const BufferType type,
Tensor *output,
StatsFuture *future);
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_IMAGE_TO_BUFFER_H_
......@@ -21,20 +21,18 @@ namespace kernels {
template <typename T>
void BufferToImageFunctor<DeviceType::GPU, T>::operator()(
Tensor *buffer, const BufferType type, Tensor *image, StatsFuture *future) {
std::vector<size_t> image_shape;
const Tensor *buffer,
const BufferType type,
Tensor *image,
StatsFuture *future) {
if (!i2b_) {
CalImage2DShape(buffer->shape(), type, &image_shape);
if (type == WINOGRAD_FILTER) {
std::vector<index_t> new_shape = CalWinogradShape(buffer->shape(), type);
image->ResizeImage(new_shape, image_shape);
} else {
image->ResizeImage(buffer->shape(), image_shape);
}
std::vector<size_t> image_shape;
CalImage2DShape(buffer->shape(), type, &image_shape);
if (type == WINOGRAD_FILTER) {
std::vector<index_t> new_shape = CalWinogradShape(buffer->shape(), type);
image->ResizeImage(new_shape, image_shape);
} else {
CalImage2DShape(image->shape(), type, &image_shape);
buffer->Resize(image->shape());
image->ResizeImage(buffer->shape(), image_shape);
}
uint32_t gws[2] = {static_cast<uint32_t>(image_shape[0]),
......@@ -42,32 +40,32 @@ void BufferToImageFunctor<DeviceType::GPU, T>::operator()(
std::string kernel_name;
switch (type) {
case CONV2D_FILTER:
kernel_name = i2b_ ? "filter_image_to_buffer" : "filter_buffer_to_image";
kernel_name = "filter_buffer_to_image";
break;
case DW_CONV2D_FILTER:
kernel_name =
i2b_ ? "dw_filter_image_to_buffer" : "dw_filter_buffer_to_image";
kernel_name = "dw_filter_buffer_to_image";
break;
case IN_OUT_CHANNEL:
kernel_name = i2b_ ? "in_out_image_to_buffer" : "in_out_buffer_to_image";
kernel_name = "in_out_buffer_to_image";
break;
case ARGUMENT:
kernel_name = i2b_ ? "arg_image_to_buffer" : "arg_buffer_to_image";
kernel_name = "arg_buffer_to_image";
break;
case IN_OUT_HEIGHT:
case WEIGHT_HEIGHT:
kernel_name = i2b_ ? "in_out_height_image_to_buffer"
: "in_out_height_buffer_to_image";
kernel_name = "in_out_height_buffer_to_image";
break;
case IN_OUT_WIDTH:
case WEIGHT_WIDTH:
MACE_CHECK(!i2b_) << "IN_OUT_WIDTH only support buffer to image now";
kernel_name = "in_out_width_buffer_to_image";
break;
case WEIGHT_HEIGHT:
kernel_name = "weight_height_buffer_to_image";
break;
case WEIGHT_WIDTH:
kernel_name = "weight_width_buffer_to_image";
break;
case WINOGRAD_FILTER:
gws[1] /= 16;
kernel_name = i2b_ ? "winograd_filter_image_to_buffer"
: "winograd_filter_buffer_to_image";
kernel_name = "winograd_filter_buffer_to_image";
break;
}
......@@ -115,24 +113,25 @@ void BufferToImageFunctor<DeviceType::GPU, T>::operator()(
b2f_kernel.setArg(idx++, gws[1]);
}
b2f_kernel.setArg(idx++, *(buffer->opencl_buffer()));
if (!i2b_) {
MACE_CHECK(buffer->buffer_offset() % GetEnumTypeSize(buffer->dtype()) == 0,
"buffer offset not aligned");
b2f_kernel.setArg(idx++,
static_cast<uint32_t>(buffer->buffer_offset() /
GetEnumTypeSize(buffer->dtype())));
}
MACE_CHECK(buffer->buffer_offset() % GetEnumTypeSize(buffer->dtype()) == 0,
"buffer offset not aligned");
b2f_kernel.setArg(idx++,
static_cast<uint32_t>(buffer->buffer_offset() /
GetEnumTypeSize(buffer->dtype())));
if (type == CONV2D_FILTER) {
const index_t inner_size =
buffer->dim(1) * buffer->dim(2) * buffer->dim(3);
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(inner_size));
} else if (type == DW_CONV2D_FILTER || type == WEIGHT_HEIGHT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
} else if (type == ARGUMENT) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
} else if (type == WEIGHT_HEIGHT || type == WEIGHT_WIDTH) {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, 1);
} else {
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
b2f_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
......
......@@ -2,12 +2,12 @@
__kernel void filter_buffer_to_image(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* h, w, oc, ic */
__global const DATA_TYPE *input, /* OIHW */
__private const int input_offset,
__private const int out_channel,
__private const int filter_h,
__private const int filter_w,
__private const int out_channel,
__private const int in_channel,
__private const int inner_size,
__write_only image2d_t output) {
int w = get_global_id(0);
int h = get_global_id(1);
......@@ -24,10 +24,9 @@ __kernel void filter_buffer_to_image(KERNEL_ERROR_PARAMS
const int hw_idx = h % hw_size;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = input_offset
+ ((h_idx * filter_w + w_idx) * out_channel
+ out_channel_idx) * in_channel
+ in_channel_idx;
const int offset = input_offset +
mad24(out_channel_idx, inner_size,
mad24(mad24(in_channel_idx, filter_h, h_idx), filter_w, w_idx));
DATA_TYPE4 values = 0;
if (out_channel_idx < out_channel) {
......@@ -35,16 +34,16 @@ __kernel void filter_buffer_to_image(KERNEL_ERROR_PARAMS
if (size < 4) {
switch (size) {
case 3:
values.z = *(input + offset + 2 * in_channel);
values.z = *(input + offset + 2 * inner_size);
case 2:
values.y = *(input + offset + 1 * in_channel);
values.y = *(input + offset + 1 * inner_size);
case 1:
values.x = *(input + offset);
}
} else {
values.w = *(input + offset + 3 * in_channel);
values.z = *(input + offset + 2 * in_channel);
values.y = *(input + offset + 1 * in_channel);
values.w = *(input + offset + 3 * inner_size);
values.z = *(input + offset + 2 * inner_size);
values.y = *(input + offset + 1 * inner_size);
values.x = *(input + offset);
}
}
......@@ -55,11 +54,11 @@ __kernel void filter_buffer_to_image(KERNEL_ERROR_PARAMS
__kernel void filter_image_to_buffer(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, /* h, w, oc, ic */
__global DATA_TYPE *output, /* OIHW */
__private const int out_channel,
__private const int filter_h,
__private const int filter_w,
__private const int out_channel,
__private const int in_channel,
__private const int inner_size,
__read_only image2d_t input) {
int w = get_global_id(0);
int h = get_global_id(1);
......@@ -76,9 +75,9 @@ __kernel void filter_image_to_buffer(KERNEL_ERROR_PARAMS
const int hw_idx = h % hw_size;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = ((h_idx * filter_w + w_idx) * out_channel
+ out_channel_idx) * in_channel
+ in_channel_idx;
const int offset =
mad24(out_channel_idx, inner_size,
mad24(mad24(in_channel_idx, filter_h, h_idx), filter_w, w_idx));
if (out_channel_idx < out_channel) {
int2 coord = (int2)(w, h);
......@@ -87,28 +86,30 @@ __kernel void filter_image_to_buffer(KERNEL_ERROR_PARAMS
if (size < 4) {
switch (size) {
case 3:
output[offset + 2 * in_channel] = values.z;
output[offset + 2 * inner_size] = values.z;
case 2:
output[offset + 1 * in_channel] = values.y;
output[offset + 1 * inner_size] = values.y;
case 1:
output[offset] = values.x;
}
} else {
output[offset + 3 * in_channel] = values.w;
output[offset + 2 * in_channel] = values.z;
output[offset + 1 * in_channel] = values.y;
output[offset + 3 * inner_size] = values.w;
output[offset + 2 * inner_size] = values.z;
output[offset + 1 * inner_size] = values.y;
output[offset] = values.x;
}
}
}
// TODO(liuqi): Support multiplier > 1
__kernel void dw_filter_buffer_to_image(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* h, w, ic, m */
__global const DATA_TYPE *input, /* MIHW */
__private const int input_offset,
__private const int filter_w,
__private const int in_channel,
__private const int multiplier,
__private const int in_channel,
__private const int filter_h,
__private const int filter_w,
__write_only image2d_t output) { /* ic%4 * kh * kw * m, ic/4 */
const int w = get_global_id(0);
const int h = get_global_id(1);
......@@ -125,35 +126,28 @@ __kernel void dw_filter_buffer_to_image(KERNEL_ERROR_PARAMS
const int h_idx = w / filter_w;
const int w_idx = w % filter_w;
const int offset = input_offset + mad24(mad24(h_idx, filter_w, w_idx),
in_channel, in_channel_idx);
const int offset = input_offset
+ mad24(mad24(in_channel_idx, filter_h, h_idx), filter_w, w_idx);
const int hw_size = mul24(filter_h, filter_w);
const int size = in_channel - in_channel_idx;
if (in_channel_idx < in_channel) {
if (size < 4) {
switch(size) {
case 3:
values.z = *(input + offset + 2);
values.z = *(input + offset + 2 * hw_size);
case 2:
values.y = *(input + offset + 1);
values.y = *(input + offset + 1 * hw_size);
case 1:
values.x = *(input + offset);
}
} else {
values = vload4(0, input + offset);
values.x = *(input + offset);
values.y = *(input + offset + 1 * hw_size);
values.z = *(input + offset + 2 * hw_size);
values.w = *(input + offset + 3 * hw_size);
}
}
} else {
const int in_channel_idx = h << 2;
const int m = w % multiplier;
const int hw_idx = w / multiplier;
const int h_idx = hw_idx / filter_w;
const int w_idx = hw_idx % filter_w;
const int offset = input_offset + mad24(mad24(mad24(h_idx, filter_w, w_idx),
in_channel, in_channel_idx),
multiplier, m);
// TODO support multiplier > 1
}
int2 coord = (int2)(w, h);
......@@ -244,7 +238,7 @@ __kernel void in_out_image_to_buffer(KERNEL_ERROR_PARAMS
__kernel void arg_buffer_to_image(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* nhwc */
__global const DATA_TYPE *input,
__private const int input_offset,
__private const int count,
__write_only image2d_t output) {
......@@ -280,7 +274,7 @@ __kernel void arg_buffer_to_image(KERNEL_ERROR_PARAMS
__kernel void arg_image_to_buffer(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__global DATA_TYPE *output, /* nhwc */
__global DATA_TYPE *output,
__private const int count,
__read_only image2d_t input) {
int w = get_global_id(0);
......@@ -365,11 +359,11 @@ __kernel void in_out_height_image_to_buffer(KERNEL_ERROR_PARAMS
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
#ifndef NON_UNIFORM_WORK_GROUP
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
#endif
const int height_blks = (height + 3) / 4;
const int batch_idx = h / height_blks;
......@@ -393,7 +387,6 @@ __kernel void in_out_height_image_to_buffer(KERNEL_ERROR_PARAMS
output[offset] = values.w;
}
__kernel void in_out_width_buffer_to_image(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__global const DATA_TYPE *input, /* nhwc */
......@@ -405,18 +398,19 @@ __kernel void in_out_width_buffer_to_image(KERNEL_ERROR_PARAMS
int w = get_global_id(0);
int h = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
#ifndef NON_UNIFORM_WORK_GROUP
if (w >= global_size_dim0 || h >= global_size_dim1) {
return;
}
#endif
#endif
const int width_blks = (width + 3) / 4;
const int batch_idx = h / height;
const int height_idx = h % height;
const int width_idx = (w % width_blks) << 2;
const int channel_idx = w / width_blks;
const int offset = input_offset + ((batch_idx * height + height_idx) * width + width_idx) * channels
const int offset = input_offset
+ ((batch_idx * height + height_idx) * width + width_idx) * channels
+ channel_idx;
int size = width - width_idx;
......@@ -436,6 +430,192 @@ __kernel void in_out_width_buffer_to_image(KERNEL_ERROR_PARAMS
WRITE_IMAGET(output, coord, values);
}
// Packs an OIHW weight buffer into a 2D image, four output channels per pixel.
//
// Work item (gx, gy) writes pixel (gx, gy). gy selects the group of output
// channels 4*gy .. 4*gy+3, and gx is decomposed as
//   gx = (row * width + col) * in_channels + in_chan.
// Lanes x/y/z/w receive the four output channels of that (in_chan, row, col)
// element; when fewer than four channels remain, the unused lanes stay zero.
//
// The dim-0 global size is used as the buffer stride between consecutive
// output channels.
// NOTE(review): this assumes the launch uses in_channels * height * width
// work items along dim 0 -- confirm against the host code.
__kernel void weight_height_buffer_to_image(KERNEL_ERROR_PARAMS
                                            GLOBAL_WORK_GROUP_SIZE_DIM2
                                            __global const DATA_TYPE *input, // OIHW
                                            __private const int input_offset,
                                            __private const int out_channels,
                                            __private const int in_channels,
                                            __private const int height,
                                            __private const int width,
                                            __write_only image2d_t output) {
  int gx = get_global_id(0);
  int gy = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
  // Padding work items beyond the logical global size do nothing.
  if (gx >= global_size_dim0 || gy >= global_size_dim1) {
    return;
  }
  const int o_stride = global_size_dim0;
#else
  const int o_stride = get_global_size(0);
#endif

  const int o_first = gy << 2;
  const int in_chan = gx % in_channels;
  const int spatial = gx / in_channels;
  const int row = spatial / width;
  const int col = spatial % width;

  // Offset of element (o_first, in_chan, row, col) in the OIHW buffer.
  const int within_filter = mad24(mad24(in_chan, height, row), width, col);
  const int base = input_offset + mad24(o_first, o_stride, within_filter);

  // Selector 0 stands for "all four lanes"; 1..3 is the tail count.
  int selector = out_channels - o_first;
  if (selector >= 4) {
    selector = 0;
  }

  DATA_TYPE4 texel = 0;
  switch (selector) {
    // Deliberate fall-through: each case also fills every lower lane.
    case 0:
      texel.w = input[base + o_stride * 3];
    case 3:
      texel.z = input[base + o_stride * 2];
    case 2:
      texel.y = input[base + o_stride];
    case 1:
      texel.x = input[base];
  }

  int2 coord = (int2)(gx, gy);
  WRITE_IMAGET(output, coord, texel);
}
// Unpacks a 2D weight image (four output channels per pixel) into an OIHW
// buffer.
//
// Work item (gx, gy) reads pixel (gx, gy). gy selects the group of output
// channels 4*gy .. 4*gy+3, and gx is decomposed as
//   gx = (row * width + col) * in_channels + in_chan.
// Lane x is always stored; lanes y/z/w are stored only while their output
// channel index is below out_channels.
//
// The dim-0 global size is used as the buffer stride between consecutive
// output channels.
// NOTE(review): this assumes the launch uses in_channels * height * width
// work items along dim 0 -- confirm against the host code.
__kernel void weight_height_image_to_buffer(KERNEL_ERROR_PARAMS
                                            GLOBAL_WORK_GROUP_SIZE_DIM2
                                            __global DATA_TYPE *output, //OIHW
                                            __private const int out_channels,
                                            __private const int in_channels,
                                            __private const int height,
                                            __private const int width,
                                            __read_only image2d_t input) {
  int gx = get_global_id(0);
  int gy = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
  // Padding work items beyond the logical global size do nothing.
  if (gx >= global_size_dim0 || gy >= global_size_dim1) {
    return;
  }
  const int o_stride = global_size_dim0;
#else
  const int o_stride = get_global_size(0);
#endif

  const int o_first = gy << 2;
  const int in_chan = gx % in_channels;
  const int spatial = gx / in_channels;
  const int row = spatial / width;
  const int col = spatial % width;

  // Offset of element (o_first, in_chan, row, col) in the OIHW buffer.
  const int within_filter = mad24(mad24(in_chan, height, row), width, col);
  const int base = mad24(o_first, o_stride, within_filter);

  int2 coord = (int2)(gx, gy);
  DATA_TYPE4 texel = READ_IMAGET(input, SAMPLER, coord);

  // Number of output channels still available starting at o_first.
  const int live = out_channels - o_first;

  output[base] = texel.x;
  if (live > 1) {
    output[base + o_stride] = texel.y;
  }
  if (live > 2) {
    output[base + 2 * o_stride] = texel.z;
  }
  if (live > 3) {
    output[base + 3 * o_stride] = texel.w;
  }
}
// Packs an OIHW weight buffer into a 2D image, four input channels per pixel.
//
// Work item (w, h) writes pixel (w, h). h is the output channel, and w is
// decomposed as
//   w = (height_idx * width + width_idx) * in_chan_blks + in_chan_block,
// where in_chan_blks = ceil(in_channels / 4). Lanes x/y/z/w receive input
// channels 4*in_chan_block .. +3 of that element, which lie height * width
// elements apart in the buffer; when fewer than four channels remain, the
// unused lanes stay zero.
__kernel void weight_width_buffer_to_image(KERNEL_ERROR_PARAMS
                                           GLOBAL_WORK_GROUP_SIZE_DIM2
                                           __global const DATA_TYPE *input, // OIHW
                                           __private const int input_offset,
                                           __private const int in_channels,
                                           __private const int height,
                                           __private const int width,
                                           __write_only image2d_t output) {
  int w = get_global_id(0);
  int h = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
  // Padding work items beyond the logical global size do nothing.
  if (w >= global_size_dim0 || h >= global_size_dim1) {
    return;
  }
#endif

  const int in_chan_blks = (in_channels + 3) >> 2;
  const int hw_size = height * width;
  const int inner_size = in_channels * hw_size;

  const int out_chan_idx = h;
  const int in_chan_idx = (w % in_chan_blks) << 2;
  const int hw_idx = w / in_chan_blks;
  const int height_idx = hw_idx / width;
  const int width_idx = hw_idx % width;

  // Offset of element (out_chan_idx, in_chan_idx, height_idx, width_idx).
  int offset = input_offset +
      mad24(out_chan_idx, inner_size,
            mad24(mad24(in_chan_idx, height, height_idx), width, width_idx));

  // Selector 0 stands for "all four lanes"; 1..3 is the tail count.
  int size = in_channels - in_chan_idx;
  size = size >= 4 ? 0 : size;

  DATA_TYPE4 values = 0;
  switch (size) {
    // Deliberate fall-through: each case also fills every lower lane.
    case 0:
      values.w = *(input + offset + hw_size * 3);
    case 3:
      values.z = *(input + offset + hw_size * 2);
    case 2:
      values.y = *(input + offset + hw_size);
    case 1:
      values.x = *(input + offset);
  }

  int2 coord = (int2)(w, h);
  WRITE_IMAGET(output, coord, values);
}
// Unpacks a 2D weight image (four input channels per pixel) into an OIHW
// buffer.
//
// Work item (w, h) reads pixel (w, h). h is the output channel, and w is
// decomposed as
//   w = (height_idx * width + width_idx) * in_chan_blks + in_chan_block,
// where in_chan_blks = ceil(in_channels / 4). Lane x is always stored; lanes
// y/z/w are stored only while their input channel index is below
// in_channels. Consecutive lanes land height * width elements apart.
__kernel void weight_width_image_to_buffer(KERNEL_ERROR_PARAMS
                                           GLOBAL_WORK_GROUP_SIZE_DIM2
                                           __global DATA_TYPE *output, // OIHW
                                           __private const int in_channels,
                                           __private const int height,
                                           __private const int width,
                                           __read_only image2d_t input) {
  int w = get_global_id(0);
  int h = get_global_id(1);

#ifndef NON_UNIFORM_WORK_GROUP
  // Padding work items beyond the logical global size do nothing.
  if (w >= global_size_dim0 || h >= global_size_dim1) {
    return;
  }
#endif

  const int in_chan_blks = (in_channels + 3) >> 2;
  const int hw_size = height * width;
  const int inner_size = in_channels * hw_size;

  const int out_chan_idx = h;
  const int in_chan_idx = (w % in_chan_blks) << 2;
  const int hw_idx = w / in_chan_blks;
  const int height_idx = hw_idx / width;
  const int width_idx = hw_idx % width;

  // Offset of element (out_chan_idx, in_chan_idx, height_idx, width_idx).
  int offset =
      mad24(out_chan_idx, inner_size,
            mad24(mad24(in_chan_idx, height, height_idx), width, width_idx));

  int2 coord = (int2)(w, h);
  DATA_TYPE4 values = READ_IMAGET(input, SAMPLER, coord);

  // Store lane by lane, stopping at the last valid input channel.
  output[offset] = values.x;
  if (in_chan_idx + 1 >= in_channels) return;
  offset += hw_size;
  output[offset] = values.y;
  if (in_chan_idx + 2 >= in_channels) return;
  offset += hw_size;
  output[offset] = values.z;
  if (in_chan_idx + 3 >= in_channels) return;
  offset += hw_size;
  output[offset] = values.w;
}
// only support 3x3 now
__kernel void winograd_filter_buffer_to_image(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
......
......@@ -83,8 +83,8 @@ void Conv2dFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
static const Conv2dOpenclFunction selector[5] = {
Conv2dOpenclK1x1, nullptr, Conv2dOpenclK3x3, nullptr, nullptr};
index_t kernel_h = filter->dim(0);
index_t kernel_w = filter->dim(1);
index_t kernel_h = filter->dim(2);
index_t kernel_w = filter->dim(3);
if (strides_[0] != strides_[1] ||
(dilations_[0] > 1 && (strides_[0] > 1 || kernel_h == 1))) {
LOG(WARNING) << "OpenCL conv2d kernel with "
......
......@@ -155,8 +155,8 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<uint32_t>(input_channel_blocks));
kernel->setArg(idx++, static_cast<uint32_t>(height));
kernel->setArg(idx++, static_cast<uint32_t>(width));
kernel->setArg(idx++, static_cast<uint32_t>(filter->dim(0)));
kernel->setArg(idx++, static_cast<uint32_t>(filter->dim(1)));
kernel->setArg(idx++, static_cast<uint32_t>(filter->dim(2)));
kernel->setArg(idx++, static_cast<uint32_t>(filter->dim(3)));
kernel->setArg(idx++, static_cast<uint32_t>(stride));
kernel->setArg(idx++, padding[0] / 2);
kernel->setArg(idx++, padding[1] / 2);
......@@ -169,9 +169,9 @@ extern void Conv2dOpencl(cl::Kernel *kernel,
std::string tuning_key =
Concat("conv2d_general_opencl_kernel", output->dim(0),
output->dim(1), output->dim(2), output->dim(3),
filter->dim(0), filter->dim(1));
filter->dim(2), filter->dim(3));
std::vector<uint32_t> lws =
LocalWS(gws, filter->dim(0) * filter->dim(1), *kwg_size);
LocalWS(gws, filter->dim(2) * filter->dim(3), *kwg_size);
TuningOrRun3DKernel(*kernel, tuning_key, gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
......
......@@ -52,7 +52,7 @@ void Deconv2dOpencl(cl::Kernel *kernel,
const int align_h = stride - 1 - padding_h;
const int align_w = stride - 1 - padding_w;
const int kernel_size = filter->dim(0) * filter->dim(1);
const int kernel_size = filter->dim(2) * filter->dim(3);
auto runtime = OpenCLRuntime::Global();
......@@ -127,8 +127,8 @@ void Deconv2dOpencl(cl::Kernel *kernel,
kernel->setArg(idx++, static_cast<int32_t>(align_w));
kernel->setArg(idx++, static_cast<int32_t>(padding_h));
kernel->setArg(idx++, static_cast<int32_t>(padding_w));
kernel->setArg(idx++, static_cast<int32_t>(filter->dim(0)));
kernel->setArg(idx++, static_cast<int32_t>(filter->dim(1)));
kernel->setArg(idx++, static_cast<int32_t>(filter->dim(2)));
kernel->setArg(idx++, static_cast<int32_t>(filter->dim(3)));
kernel->setArg(idx++, static_cast<int32_t>(kernel_size));
kernel->setArg(idx++, static_cast<int32_t>(input_channel_blocks));
kernel->setArg(idx++, static_cast<int32_t>(channel_blocks));
......
......@@ -73,7 +73,7 @@ static void DepthwiseConv2d(cl::Kernel *kernel,
const index_t channels = output->dim(3);
const index_t input_channels = input->dim(3);
const index_t multiplier = filter->dim(3);
const index_t multiplier = filter->dim(0);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t input_channel_blocks = RoundUpDiv4(input_channels);
......@@ -138,11 +138,11 @@ static void DepthwiseConv2d(cl::Kernel *kernel,
const index_t input_height = input->dim(1);
const index_t input_width = input->dim(2);
const index_t filter_height = filter->dim(0);
const index_t filter_width = filter->dim(1);
const index_t filter_height = filter->dim(2);
const index_t filter_width = filter->dim(3);
MACE_CHECK(multiplier == 1, "Multiplier > 1 not supported");
MACE_CHECK(multiplier * input_channels == channels);
MACE_CHECK(filter->dim(2) == input_channels, filter->dim(2), "!=",
MACE_CHECK(filter->dim(1) == input_channels, filter->dim(1), "!=",
input_channels);
uint32_t idx = 0;
......@@ -195,7 +195,7 @@ static void DepthwiseConv2d(cl::Kernel *kernel,
template <typename T>
void DepthwiseConv2dFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
const Tensor *filter,
const Tensor *filter, /* MIHW */
const Tensor *bias,
Tensor *output,
StatsFuture *future) {
......@@ -216,10 +216,10 @@ void DepthwiseConv2dFunctor<DeviceType::GPU, T>::operator()(
// Create a fake conv_2d filter to calculate the paddings and output size
std::vector<index_t> fake_filter_shape(4);
fake_filter_shape[0] = filter->shape()[0];
fake_filter_shape[1] = filter->shape()[1];
fake_filter_shape[2] = filter->shape()[2] * filter->shape()[3];
fake_filter_shape[3] = 1;
fake_filter_shape[0] = filter->dim(0) * filter->dim(1);
fake_filter_shape[1] = filter->dim(1);
fake_filter_shape[2] = filter->dim(2);
fake_filter_shape[3] = filter->dim(3);
std::vector<index_t> output_shape(4);
std::vector<int> paddings(2);
......
......@@ -32,8 +32,6 @@ void FCWXKernel(cl::Kernel *kernel,
const float relux_max_limit,
StatsFuture *future,
std::unique_ptr<BufferBase> *kernel_error) {
MACE_CHECK(input->dim(3) % 4 == 0)
<< "FC width kernel only support input with 4x channel.";
MACE_CHECK_NOTNULL(gws);
MACE_CHECK_NOTNULL(lws);
auto runtime = OpenCLRuntime::Global();
......@@ -294,15 +292,9 @@ void FullyConnectedFunctor<DeviceType::GPU, T>::operator()(
&output_image_shape);
output->ResizeImage(output_shape, output_image_shape);
if (weight_type_ == BufferType::WEIGHT_HEIGHT) {
FCWTXKernel<T>(&kernel_, input, weight, bias, &input_shape_, output,
activation_, &gws_, &lws_, relux_max_limit_, future,
&kernel_error_);
} else {
FCWXKernel<T>(&kernel_, input, weight, bias, &input_shape_, output,
activation_, &gws_, &lws_, relux_max_limit_, future,
&kernel_error_);
}
FCWXKernel<T>(&kernel_, input, weight, bias, &input_shape_, output,
activation_, &gws_, &lws_, relux_max_limit_, future,
&kernel_error_);
}
template struct FullyConnectedFunctor<DeviceType::GPU, float>;
......
......@@ -35,22 +35,22 @@ void CalInOutputImageShape(const std::vector<index_t> &shape, /* NHWC */
}
// [Ic, H * W * (Oc + 3) / 4]
void CalConv2dFilterImageShape(const std::vector<index_t> &shape, /* HWOI */
void CalConv2dFilterImageShape(const std::vector<index_t> &shape, /* OIHW */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = shape[3];
(*image_shape)[1] = shape[0] * shape[1] * RoundUpDiv4(shape[2]);
(*image_shape)[0] = shape[1];
(*image_shape)[1] = shape[2] * shape[3] * RoundUpDiv4(shape[0]);
}
// [H * W * M, (Ic + 3) / 4]
void CalDepthwiseConv2dFilterImageShape(
const std::vector<index_t> &shape, /* HWIM */
const std::vector<index_t> &shape, /* MIHW */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = shape[0] * shape[1] * shape[3];
(*image_shape)[1] = RoundUpDiv4(shape[2]);
(*image_shape)[0] = shape[0] * shape[2] * shape[3];
(*image_shape)[1] = RoundUpDiv4(shape[1]);
}
// [(size + 3) / 4, 1]
......@@ -91,21 +91,21 @@ void CalInOutWidthImageShape(const std::vector<index_t> &shape, /* NHWC */
(*image_shape)[1] = shape[0] * shape[1];
}
// [W, (H + 3) / 4]
void CalWeightHeightImageShape(const std::vector<index_t> &shape, /* HW */
// [Ic * H * W, (Oc + 3) / 4]
void CalWeightHeightImageShape(const std::vector<index_t> &shape, /* OIHW */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 2);
MACE_CHECK(shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = shape[1];
(*image_shape)[0] = shape[1] * shape[2] * shape[3];
(*image_shape)[1] = RoundUpDiv4(shape[0]);
}
// [(W + 3) / 4, H]
void CalWeightWidthImageShape(const std::vector<index_t> &shape, /* HW */
// [(Ic + 3) / 4 * H * W, Oc]
void CalWeightWidthImageShape(const std::vector<index_t> &shape, /* OIHW */
std::vector<size_t> *image_shape) {
MACE_CHECK(shape.size() == 2);
MACE_CHECK(shape.size() == 4);
image_shape->resize(2);
(*image_shape)[0] = RoundUpDiv4(shape[1]);
(*image_shape)[0] = RoundUpDiv4(shape[1]) * shape[2] * shape[3];
(*image_shape)[1] = shape[0];
}
} // namespace
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/image_to_buffer.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
namespace mace {
namespace kernels {
// Copies an OpenCL image tensor back into a buffer tensor by launching the
// "<layout>_image_to_buffer" kernel that corresponds to `type`.
//
// image:  source tensor; its opencl_image() is bound as the last kernel
//         argument.
// type:   image layout of `image`; selects the kernel and the shape arguments
//         passed to it. DW_CONV2D_FILTER and IN_OUT_WIDTH are rejected with
//         LOG(FATAL).
// buffer: destination tensor; resized to image->shape() before the launch.
// future: optional; when non-null, its wait_fn is set to wait on the launch
//         event and, if a CallStats is supplied, to fill it in.
template <typename T>
void ImageToBufferFunctor<DeviceType::GPU, T>::operator()(
    const Tensor *image,
    const BufferType type,
    Tensor *buffer,
    StatsFuture *future) {
  std::vector<size_t> image_shape;
  CalImage2DShape(image->shape(), type, &image_shape);
  buffer->Resize(image->shape());

  // The global work size starts out equal to the 2D image shape.
  uint32_t gws[2] = {static_cast<uint32_t>(image_shape[0]),
                     static_cast<uint32_t>(image_shape[1])};

  std::string kernel_name;
  switch (type) {
    case CONV2D_FILTER:
      kernel_name = "filter_image_to_buffer";
      break;
    case IN_OUT_CHANNEL:
      kernel_name = "in_out_image_to_buffer";
      break;
    case ARGUMENT:
      kernel_name = "arg_image_to_buffer";
      break;
    case IN_OUT_HEIGHT:
      kernel_name = "in_out_height_image_to_buffer";
      break;
    case WINOGRAD_FILTER:
      // NOTE(review): presumably the Winograd filter image stacks 16
      // transformed values along its height, so only 1/16 of the rows need a
      // work-item -- confirm against the kernel.
      gws[1] /= 16;
      kernel_name = "winograd_filter_image_to_buffer";
      break;
    case WEIGHT_HEIGHT:
      kernel_name = "weight_height_image_to_buffer";
      break;
    case WEIGHT_WIDTH:
      kernel_name = "weight_width_image_to_buffer";
      break;
    case DW_CONV2D_FILTER:
    case IN_OUT_WIDTH:
      // Both layouts share this branch, so the message names both of them.
      LOG(FATAL) << "DW_CONV2D_FILTER and IN_OUT_WIDTH only support "
                    "buffer to image now";
      break;
  }

  auto runtime = OpenCLRuntime::Global();

  // Assemble the compile options for the selected kernel.
  std::string obfuscated_kernel_name = MACE_OBFUSCATE_SYMBOL(kernel_name);
  std::set<std::string> built_options;
  std::stringstream kernel_name_ss;
  kernel_name_ss << "-D" << kernel_name << "=" << obfuscated_kernel_name;
  built_options.emplace(kernel_name_ss.str());
  if (runtime->IsNonUniformWorkgroupsSupported()) {
    built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
  }
  // A different pair of type-mapping helpers is used when the buffer and the
  // image do not share a data type.
  if (buffer->dtype() == image->dtype()) {
    built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
    built_options.emplace("-DCMD_DATA_TYPE=" +
                          DtToCLCMDDt(DataTypeToEnum<T>::value));
  } else {
    built_options.emplace("-DDATA_TYPE=" +
                          DtToUpstreamCLDt(DataTypeToEnum<T>::value));
    built_options.emplace("-DCMD_DATA_TYPE=" +
                          DtToUpstreamCLCMDDt(DataTypeToEnum<T>::value));
  }
  if (runtime->IsOutOfRangeCheckEnabled()) {
    built_options.emplace("-DOUT_OF_RANGE_CHECK");
    // Lazily create the one-byte error flag and clear it.
    if (!kernel_error_) {
      kernel_error_ = std::unique_ptr<Buffer>(
          new Buffer(GetDeviceAllocator(DeviceType::GPU), 1));
      kernel_error_->Map(nullptr);
      *(kernel_error_->mutable_data<char>()) = 0;
      kernel_error_->UnMap();
    }
  }

  // Both transfer directions are built from the "buffer_to_image" program.
  auto i2b_kernel = runtime->BuildKernel("buffer_to_image",
                                         obfuscated_kernel_name,
                                         built_options);

  // Argument order: [error flag] [gws0, gws1] buffer, shape args..., image.
  uint32_t idx = 0;
  if (runtime->IsOutOfRangeCheckEnabled()) {
    i2b_kernel.setArg(idx++,
                      *(static_cast<cl::Buffer *>(kernel_error_->buffer())));
  }
  if (!runtime->IsNonUniformWorkgroupsSupported()) {
    // The launch below is padded in this case, so the kernel is handed the
    // unpadded global sizes.
    i2b_kernel.setArg(idx++, gws[0]);
    i2b_kernel.setArg(idx++, gws[1]);
  }
  i2b_kernel.setArg(idx++, *(buffer->opencl_buffer()));
  if (type == CONV2D_FILTER) {
    const index_t inner_size =
        buffer->dim(1) * buffer->dim(2) * buffer->dim(3);
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(inner_size));
  } else if (type == ARGUMENT) {
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
  } else if (type == WEIGHT_HEIGHT) {
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(0)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
  } else {
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(1)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(2)));
    i2b_kernel.setArg(idx++, static_cast<uint32_t>(buffer->dim(3)));
  }
  i2b_kernel.setArg(idx++, *(image->opencl_image()));

  const uint32_t kwg_size =
      static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(i2b_kernel));
  // NOTE(review): lws[1] becomes 0 if kwg_size < 16; this relies on the
  // device reporting a maximum work-group size of at least 16.
  const std::vector<uint32_t> lws = {16, kwg_size / 16};

  cl::Event event;
  cl_int error;
  if (runtime->IsNonUniformWorkgroupsSupported()) {
    error = runtime->command_queue().enqueueNDRangeKernel(
        i2b_kernel, cl::NullRange, cl::NDRange(gws[0], gws[1]),
        cl::NDRange(lws[0], lws[1]), nullptr, &event);
  } else {
    // Pad each global dimension up to a multiple of the local size.
    std::vector<uint32_t> roundup_gws(lws.size());
    for (size_t i = 0; i < lws.size(); ++i) {
      roundup_gws[i] = RoundUp(gws[i], lws[i]);
    }
    error = runtime->command_queue().enqueueNDRangeKernel(
        i2b_kernel, cl::NullRange, cl::NDRange(roundup_gws[0], roundup_gws[1]),
        cl::NDRange(lws[0], lws[1]), nullptr, &event);
  }
  MACE_CHECK_CL_SUCCESS(error);

  if (runtime->IsOutOfRangeCheckEnabled()) {
    kernel_error_->Map(nullptr);
    char *kerror_code = kernel_error_->mutable_data<char>();
    // Stream the code as an integer; a raw char would be written as a
    // (usually unprintable) character instead of its numeric value.
    MACE_CHECK(*kerror_code == 0)
        << "Kernel error code: " << static_cast<int>(*kerror_code);
    kernel_error_->UnMap();
  }

  if (future != nullptr) {
    future->wait_fn = [runtime, event](CallStats *stats) {
      event.wait();
      if (stats != nullptr) {
        runtime->GetCallStats(event, stats);
      }
    };
  }
}
// Explicitly instantiate the GPU functor for the float and half element types.
template struct ImageToBufferFunctor<DeviceType::GPU, float>;
template struct ImageToBufferFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
......@@ -89,8 +89,8 @@ void PoolingFunctor<DeviceType::GPU, T>::operator()(const Tensor *input,
std::vector<uint32_t> gws;
if (!IsVecEqual(input_shape_, input->shape())) {
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {kernels_[0], kernels_[1],
input->dim(3), input->dim(3)};
std::vector<index_t> filter_shape = {input->dim(3), input->dim(3),
kernels_[0], kernels_[1]};
std::vector<int> paddings(2);
if (paddings_.empty()) {
......
......@@ -54,7 +54,7 @@ void WinogradTransformFunctor<DeviceType::GPU, T>::operator()(
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {3, 3, 1, input_tensor->dim(3)};
std::vector<index_t> filter_shape = {1, input_tensor->dim(3), 3, 3};
std::vector<int> paddings(2);
if (paddings_.empty()) {
kernels::CalcNHWCPaddingAndOutputSize(
......
......@@ -35,7 +35,7 @@ class BufferToImageOp : public Operator<D, T> {
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
Tensor *output = this->Output(OUTPUT);
functor_(const_cast<Tensor *>(input_tensor), type, output, future);
functor_(input_tensor, type, output, future);
return true;
}
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
// Benchmark driver for the BufferToImage op.
//
// Feeds the op a randomly filled 4-D tensor of shape
// {out_channel, in_channel, height, width}, runs it five times as warm-up and
// then `iters` times with the benchmark timer running. No "buffer_type"
// argument is set on the op definition.
template <DeviceType D, typename T>
void FilterBufferToImage(int iters,
                         int out_channel, int in_channel,
                         int height, int width) {
  // Keep net construction and warm-up out of the measured interval.
  mace::testing::StopTiming();

  OpsTestNet net;

  // Random input tensor.
  net.AddRandomInput<D, T>("Input",
                           {out_channel, in_channel, height, width});

  OpDefBuilder("BufferToImage", "BufferToImageBM")
      .Input("Input")
      .Output("Output")
      .Finalize(net.NewOperatorDef());

  // Untimed warm-up runs.
  net.Setup(D);
  for (int warmup_left = 5; warmup_left > 0; --warmup_left) {
    net.Run();
  }
  net.Sync();

  // Timed runs.
  mace::testing::StartTiming();
  for (; iters != 0; --iters) {
    net.Run();
  }
  net.Sync();
}
} // namespace
// Defines and registers one benchmark function for the given
// (O, I, H, W, TYPE, DEVICE) combination. The function reports
// iters * O * I * H * W as the MACC count, that value times sizeof(TYPE) as
// the byte count, and then runs FilterBufferToImage<DEVICE, TYPE>.
#define BM_B2I_MACRO(O, I, H, W, TYPE, DEVICE) \
static void BM_B2I_##O##_##I##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * O * I * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
FilterBufferToImage<DEVICE, TYPE>(iters, O, I, H, W); \
} \
BENCHMARK(BM_B2I_##O##_##I##_##H##_##W##_##TYPE##_##DEVICE)
// Registers the float and half GPU variants of the benchmark for one
// (O, I, H, W) shape.
#define BM_B2I(O, I, H, W) \
BM_B2I_MACRO(O, I, H, W, float, GPU); \
BM_B2I_MACRO(O, I, H, W, half, GPU);
BM_B2I(5, 3, 3, 3);
BM_B2I(5, 3, 7, 7);
BM_B2I(32, 16, 1, 1);
BM_B2I(32, 16, 3, 3);
BM_B2I(32, 16, 5, 5);
BM_B2I(32, 16, 7, 7);
BM_B2I(64, 32, 1, 1);
BM_B2I(64, 32, 3, 3);
BM_B2I(64, 32, 5, 5);
BM_B2I(64, 32, 7, 7);
BM_B2I(128, 64, 1, 1);
BM_B2I(128, 64, 3, 3);
BM_B2I(128, 32, 1, 1);
BM_B2I(128, 32, 3, 3);
BM_B2I(256, 32, 1, 1);
BM_B2I(256, 32, 3, 3);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -61,7 +61,7 @@ TEST(BufferToImageTest, ArgHalfSmall) {
TestBidirectionTransform<DeviceType::GPU, half>(kernels::ARGUMENT, {11});
}
TEST(BufferToImageTest, ArgMedia) {
TEST(BufferToImageTest, ArgMedium) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::ARGUMENT, {11});
}
......@@ -84,7 +84,7 @@ TEST(BufferToImageTest, InputSmallMultipleBatchAndChannel) {
{3, 2, 3, 3});
}
TEST(BufferToImageTest, InputMedia) {
TEST(BufferToImageTest, InputMedium) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::IN_OUT_CHANNEL,
{3, 13, 17, 128});
}
......@@ -96,32 +96,62 @@ TEST(BufferToImageTest, InputLarge) {
TEST(BufferToImageTest, Filter1x1Small) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::CONV2D_FILTER,
{1, 1, 3, 5});
{5, 3, 1, 1});
}
TEST(BufferToImageTest, Filter1x1Media) {
TEST(BufferToImageTest, Filter1x1Medium) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::CONV2D_FILTER,
{1, 1, 13, 17});
{13, 17, 1, 1});
}
TEST(BufferToImageTest, Filter1x1Large) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::CONV2D_FILTER,
{1, 1, 128, 512});
{512, 128, 1, 1});
}
TEST(BufferToImageTest, Filter3x3Small) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::CONV2D_FILTER,
{3, 3, 3, 5});
{3, 5, 3, 3});
}
TEST(BufferToImageTest, Filter3x3Meida) {
TEST(BufferToImageTest, Filter3x3Medium) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::CONV2D_FILTER,
{3, 3, 13, 17});
{17, 13, 3, 3});
}
TEST(BufferToImageTest, Filter3x3Large) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::CONV2D_FILTER,
{3, 3, 128, 256});
{256, 128, 3, 3});
}
// WEIGHT_WIDTH buffer type exercised with 4-D float shapes of increasing size.
// NOTE(review): TestBidirectionTransform is defined outside this view;
// presumably it checks a buffer -> image -> buffer round trip -- confirm.
TEST(BufferToImageTest, WeightWidthSmall) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_WIDTH,
{1, 3, 3, 3});
}
TEST(BufferToImageTest, WeightWidthMedium) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_WIDTH,
{11, 13, 13, 17});
}
TEST(BufferToImageTest, WeightWidthLarge) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_WIDTH,
{64, 128, 11, 13});
}
// WEIGHT_HEIGHT buffer type exercised with 4-D float shapes of increasing
// size.
// NOTE(review): TestBidirectionTransform is defined outside this view;
// presumably it checks a buffer -> image -> buffer round trip -- confirm.
TEST(BufferToImageTest, WeightHeightSmall) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_HEIGHT,
{2, 1, 1, 1});
}
TEST(BufferToImageTest, WeightHeightMedium) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_HEIGHT,
{11, 13, 13, 17});
}
TEST(BufferToImageTest, WeightHeightLarge) {
TestBidirectionTransform<DeviceType::GPU, float>(kernels::WEIGHT_HEIGHT,
{64, 32, 11, 13});
}
namespace {
......@@ -159,7 +189,7 @@ void TestDiffTypeBidirectionTransform(const int type,
TEST(BufferToImageTest, ArgFloatToHalfSmall) {
TestDiffTypeBidirectionTransform<DeviceType::GPU, half>(kernels::ARGUMENT,
{11});
{11});
}
namespace {
......
......@@ -43,19 +43,15 @@ void Conv2d(int iters,
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h,
kernel_w});
net.AddRandomInput<D, float>("Bias", {output_channels});
} else if (D == DeviceType::GPU) {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, output_channels,
channels});
net.AddRandomInput<D, float>("Bias", {output_channels});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h,
kernel_w});
net.AddRandomInput<D, float>("Bias", {output_channels});
if (D == DeviceType::CPU) {
OpDefBuilder("Conv2D", "Conv2dTest")
......
......@@ -33,7 +33,7 @@ void TestNHWCSimple3x3VALID() {
"Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 1, 2},
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
......@@ -43,13 +43,9 @@ void TestNHWCSimple3x3VALID() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
......@@ -104,9 +100,8 @@ void TestNHWCSimple3x3SAME() {
"Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, T>("Bias", {1}, {0.1f});
......@@ -115,13 +110,9 @@ void TestNHWCSimple3x3SAME() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
......@@ -191,9 +182,8 @@ void TestNHWCSimple3x3WithoutBias() {
"Input", {1, 3, 3, 2},
{1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::CPU) {
......@@ -201,13 +191,9 @@ void TestNHWCSimple3x3WithoutBias() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
......@@ -272,10 +258,11 @@ void TestNHWCCombined3x3() {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, T>(
"Filter", {3, 3, 2, 2},
{1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f,
1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f,
1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f, 1.0f, 1.0f, 0.5f, 0.5f});
"Filter", {2, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f,
0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f, 0.5f});
net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f});
if (D == DeviceType::CPU) {
......@@ -283,13 +270,9 @@ void TestNHWCCombined3x3() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {2, 2})
......@@ -356,9 +339,8 @@ void TestFusedNHWCSimple3x3VALID() {
{-1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1});
net.AddInputFromArray<D, float>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
net.AddInputFromArray<D, float>("Bias", {1}, {-0.1f});
......@@ -367,13 +349,9 @@ void TestFusedNHWCSimple3x3VALID() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
......@@ -431,9 +409,8 @@ void TestFusedNHWCSimple3x3WithoutBias() {
-1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1});
net.AddInputFromArray<D, float>(
"Filter", {3, 3, 1, 2},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
"Filter", {1, 2, 3, 3},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f,
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f});
if (D == DeviceType::CPU) {
......@@ -441,13 +418,9 @@ void TestFusedNHWCSimple3x3WithoutBias() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
......@@ -523,7 +496,7 @@ void TestConv1x1() {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
net.AddInputFromArray<D, float>(
"Filter", {1, 1, 2, 5},
"Filter", {2, 5, 1, 1},
{1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 2.0f, 2.0f});
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f});
......@@ -532,13 +505,9 @@ void TestConv1x1() {
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
......@@ -615,21 +584,17 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
"Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......@@ -733,7 +698,7 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>(
"Input", {batch, height, width, input_channels}, float_input_data);
net.AddInputFromArray<D, float>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels},
"Filter", {output_channels, input_channels, kernel_h, kernel_w},
float_filter_data);
net.AddInputFromArray<D, float>("Bias", {output_channels}, float_bias_data);
......@@ -741,14 +706,10 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......@@ -876,22 +837,18 @@ void TestDilationConvNxN(const std::vector<index_t> &shape,
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
"Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......@@ -984,22 +941,17 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
net.AddRandomInput<D, float>("Input",
{batch, height, width, input_channels});
net.AddRandomInput<D, float>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
"Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, float>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......@@ -1080,21 +1032,17 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
"Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
// Construct graph
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......
......@@ -43,15 +43,12 @@ static void Deconv2d(int iters,
// Add input data
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input", {batch, channels, height, width});
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h,
kernel_w});
} else {
net.AddRandomInput<D, float>("Input", {batch, height, width, channels});
net.AddRandomInput<D, float>("Filter",
{kernel_h, kernel_w, output_channels,
channels});
}
net.AddRandomInput<D, float>("Filter",
{output_channels, channels, kernel_h,
kernel_w});
if (D == DeviceType::GPU) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
......@@ -122,7 +119,7 @@ static void Deconv2d(int iters,
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, float, GPU); \
BM_DECONV_2D_MACRO(N, C, H, W, KH, KW, S, OH, OW, P, OC, half, GPU);
BM_DECONV_2D(1, 512, 15, 15, 1, 1, 1, 15, 15, VALID, 1024);
BM_DECONV_2D(1, 128, 15, 15, 1, 1, 1, 15, 15, VALID, 256);
BM_DECONV_2D(1, 32, 60, 60, 1, 1, 1, 60, 60, VALID, 128);
BM_DECONV_2D(1, 128, 60, 60, 3, 3, 1, 62, 62, VALID, 128);
......
......@@ -24,6 +24,7 @@ namespace test {
class Deconv2dOpTest : public OpsTestBase {};
namespace {
template<DeviceType D>
void RunTestSimple(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
......@@ -39,21 +40,25 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
// Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input_data);
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
net.TransformDataFormat<D, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
if (D == DeviceType::GPU) {
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Filter", "FilterImage",
kernels::BufferType::CONV2D_FILTER);
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "FilterOIHW", "FilterImage",
kernels::BufferType::CONV2D_FILTER);
OpDefBuilder("Deconv2D", "Deconv2dTest")
.Input("InputImage")
.Input("FilterImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("padding_values", padding_size)
.AddIntsArg("output_shape", output_shape)
.Finalize(net.NewOperatorDef());
.Input("InputImage")
.Input("FilterImage")
.Output("OutputImage")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("padding_values", padding_size)
.AddIntsArg("output_shape", output_shape)
.Finalize(net.NewOperatorDef());
net.RunOp(D);
......@@ -65,19 +70,15 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
OpDefBuilder("Deconv2D", "Deconv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("padding_values", padding_size)
.AddIntsArg("output_shape", output_shape)
.Finalize(net.NewOperatorDef());
.Input("InputNCHW")
.Input("FilterOIHW")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride, stride})
.AddIntArg("padding", padding)
.AddIntsArg("padding_values", padding_size)
.AddIntsArg("output_shape", output_shape)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
......@@ -392,6 +393,7 @@ void TestNHWCSimple2x2VALID() {
1.f, 1.f, 2.f, 1.f, 1.f,
1.f, 1.f, 2.f, 1.f, 1.f});
}
} // namespace
TEST_F(Deconv2dOpTest, CPUSimple3X3PaddingSame_S1) {
TestNHWCSimple3x3SAME_S1<DeviceType::CPU>();
......@@ -451,34 +453,30 @@ TEST_F(Deconv2dOpTest, OPENCLSimple3X3PaddingValid_S2) {
namespace {
template<DeviceType D, typename T>
void TestComplexDeconvNxNS12(const std::vector<int> &shape,
void TestComplexDeconvNxNS12(const int batch,
const std::vector<int> &shape,
const int stride) {
testing::internal::LogToStderr();
auto func = [&](int kernel_h, int kernel_w, int stride_h, int stride_w,
Padding type, int padding) {
// generate random input
static unsigned int seed = time(NULL);
int batch = 3 + (rand_r(&seed) % 10);
int height = shape[0];
int width = shape[1];
int input_channels = shape[2] + (rand_r(&seed) % 10);
int output_channels = shape[3] + (rand_r(&seed) % 10);
int input_channels = shape[2];
int output_channels = shape[3];
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, input_channels});
net.AddRandomInput<D, T>(
"Filter", {kernel_h, kernel_w, output_channels, input_channels});
"Filter", {output_channels, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, T>("Bias", {output_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWOI,
"FilterOIHW",
OIHW);
int out_h = 0;
int out_w = 0;
......@@ -506,7 +504,7 @@ void TestComplexDeconvNxNS12(const std::vector<int> &shape,
// Construct graph
OpDefBuilder("Deconv2D", "Deconv2dTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......@@ -562,32 +560,33 @@ void TestComplexDeconvNxNS12(const std::vector<int> &shape,
func(kernel_size, kernel_size, stride, stride, SAME, -1);
func(kernel_size, kernel_size, stride, stride, VALID, 1);
func(kernel_size, kernel_size, stride, stride, VALID, 2);
func(kernel_size, kernel_size, stride, stride, VALID, 3);
func(kernel_size, kernel_size, stride, stride, VALID, 4);
}
}
} // namespace
TEST_F(Deconv2dOpTest, OPENCLAlignedDeconvNxNS12) {
TestComplexDeconvNxNS12<DeviceType::GPU, float>({32, 16, 16, 32}, 1);
TestComplexDeconvNxNS12<DeviceType::GPU, float>({32, 16, 16, 32}, 2);
TestComplexDeconvNxNS12<DeviceType::GPU, float>({33, 17, 16, 32}, 1);
TestComplexDeconvNxNS12<DeviceType::GPU, float>({33, 17, 16, 32}, 2);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {32, 16, 16, 32}, 1);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {32, 16, 16, 32}, 2);
}
TEST_F(Deconv2dOpTest, OPENCLAlignedDeconvNxNS34) {
TestComplexDeconvNxNS12<DeviceType::GPU, float>({32, 16, 16, 32}, 3);
TestComplexDeconvNxNS12<DeviceType::GPU, float>({32, 16, 16, 32}, 4);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {32, 16, 16, 32}, 3);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {32, 16, 16, 32}, 4);
}
TEST_F(Deconv2dOpTest, OPENCLUnalignedDeconvNxNS12) {
TestComplexDeconvNxNS12<DeviceType::GPU, float>({17, 113, 5, 7}, 1);
TestComplexDeconvNxNS12<DeviceType::GPU, float>({17, 113, 5, 7}, 2);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {17, 113, 5, 7}, 1);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {17, 113, 5, 7}, 2);
}
TEST_F(Deconv2dOpTest, OPENCLUnalignedDeconvNxNS34) {
TestComplexDeconvNxNS12<DeviceType::GPU, float>({17, 113, 5, 7}, 3);
TestComplexDeconvNxNS12<DeviceType::GPU, float>({17, 113, 5, 7}, 4);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {17, 113, 5, 7}, 3);
TestComplexDeconvNxNS12<DeviceType::GPU, float>(1, {17, 113, 5, 7}, 4);
}
// Runs the randomized GPU deconvolution check with batch sizes greater than
// one on a shape whose dimensions are not multiples of four.
TEST_F(Deconv2dOpTest, OPENCLUnalignedDeconvNxNMultiBatch) {
  // Layout read by the helper: {height, width, input_channels, output_channels}.
  const std::vector<int> shape = {17, 13, 5, 7};
  // Arguments: batch, shape, stride.
  TestComplexDeconvNxNS12<DeviceType::GPU, float>(3, shape, 1);
  TestComplexDeconvNxNS12<DeviceType::GPU, float>(5, shape, 2);
}
} // namespace test
......
......@@ -43,18 +43,15 @@ void DepthwiseConv2d(int iters,
if (D == DeviceType::CPU) {
net.AddRandomInput<D, float>("Input",
{batch, input_channels, height, width});
net.AddRandomInput<D, float>(
"Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
} else if (D == DeviceType::GPU) {
net.AddRandomInput<D, float>("Input",
{batch, height, width, input_channels});
net.AddRandomInput<D, float>(
"Filter", {kernel_h, kernel_w, input_channels, multiplier});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
} else {
MACE_NOT_IMPLEMENTED;
}
net.AddRandomInput<D, float>(
"Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<D, float>("Bias", {input_channels * multiplier});
if (D == DeviceType::CPU) {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2dTest")
......
......@@ -33,20 +33,16 @@ void SimpleValidTest() {
"Input", {1, 3, 3, 2},
{1, 2, 2, 4, 3, 6, 4, 8, 5, 10, 6, 12, 7, 14, 8, 16, 9, 18});
net.AddInputFromArray<D, float>(
"Filter", {2, 2, 2, 1}, {1.0f, 2.0f, 2.0f, 4.0f, 3.0f, 6.0f, 4.0f, 8.0f});
"Filter", {1, 2, 2, 2}, {1.0f, 2.0f, 3.0f, 4.0f, 2.0f, 4.0f, 6.0f, 8.0f});
net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f});
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWIO,
"FilterOIHW",
OIHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {1, 1})
......@@ -127,10 +123,10 @@ void ComplexValidTest(index_t batch, index_t channel, index_t height,
net.AddInputFromArray<D, float>("Input", {batch, height, width, channel},
input_data);
std::vector<float> filter_data(kernel * kernel * channel * multiplier);
GenerateRandomRealTypeData({kernel, kernel, channel, multiplier},
GenerateRandomRealTypeData({multiplier, channel, kernel, kernel},
&filter_data);
net.AddInputFromArray<D, float>("Filter",
{kernel, kernel, channel, multiplier},
{multiplier, channel, kernel, kernel},
filter_data);
std::vector<float> bias_data(channel * multiplier);
GenerateRandomRealTypeData({channel * multiplier}, &bias_data);
......@@ -142,13 +138,9 @@ void ComplexValidTest(index_t batch, index_t channel, index_t height,
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWIO,
"FilterOIHW",
OIHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride, stride})
......@@ -214,7 +206,7 @@ void ComplexValidTest(index_t batch, index_t channel, index_t height,
index_t in_offset =
((b * height + ih) * width + iw) * channel + c;
index_t filter_offset =
(((kh * kernel) + kw) * channel + c) * multiplier + o;
((o * channel + c) * kernel + kh) * kernel + kw;
sum += input_data[in_offset] * filter_data[filter_offset];
}
}
......@@ -275,22 +267,18 @@ void TestNxNS12(const index_t height, const index_t width) {
{batch, height, width,
input_channels});
net.AddRandomInput<DeviceType::GPU, float>(
"Filter", {kernel_h, kernel_w, input_channels, multiplier});
"Filter", {multiplier, input_channels, kernel_h, kernel_w});
net.AddRandomInput<DeviceType::GPU, float>("Bias",
{multiplier
* input_channels});
{multiplier
* input_channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Filter",
HWIO,
"FilterOIHW",
OIHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW")
.Input("FilterOIHW")
.Input("Filter")
.Input("Bias")
.Output("OutputNCHW")
.AddIntsArg("strides", {stride_h, stride_w})
......@@ -336,7 +324,6 @@ void TestNxNS12(const index_t height, const index_t width) {
"DeviceOutput",
kernels::BufferType::IN_OUT_CHANNEL);
// Check
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(expected, *net.GetOutput("DeviceOutput"),
......
......@@ -28,33 +28,42 @@ class FullyConnectedOp : public Operator<D, T> {
public:
FullyConnectedOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetSingleArgument<int>(
"weight_type",
// TODO(liuqi): 8 is stand for kernels::WEIGHT_WIDTH
8 /*static_cast<int>(kernels::WEIGHT_WIDTH)*/),
kernels::StringToActivationType(
functor_(kernels::StringToActivationType(
OperatorBase::GetSingleArgument<std::string>("activation",
"NOOP")),
OperatorBase::GetSingleArgument<float>("max_limit", 0.0f)) {}
bool Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
const Tensor *weight = this->Input(WEIGHT);
const Tensor *weight = this->Input(WEIGHT); // OIHW
const Tensor *bias = this->InputSize() >= 3 ? this->Input(BIAS) : nullptr;
Tensor *output = this->Output(OUTPUT);
const index_t input_size = input->dim(1) * input->dim(2) * input->dim(3);
MACE_CHECK(input_size == weight->dim(1) && weight->dim(0) == bias->dim(0),
"The size of Input: ",
input_size,
" Weight: ",
weight->dim(1),
",",
weight->dim(
0),
" and Bias ",
bias->dim(0),
" don't match.");
if (D == DeviceType::CPU) {
MACE_CHECK(input->dim(1) == weight->dim(1)
&& input->dim(2) == weight->dim(2)
&& input->dim(3) == weight->dim(3)
&& weight->dim(0) == bias->dim(0),
"The shape of Input: ",
MakeString(input->shape()),
"The shape of Weight: ",
MakeString(weight->shape()),
" and Bias ",
bias->dim(0),
" don't match.");
} else {
MACE_CHECK(input->dim(1) == weight->dim(2)
&& input->dim(2) == weight->dim(3)
&& input->dim(3) == weight->dim(1)
&& weight->dim(0) == bias->dim(0),
"The shape of Input: ",
MakeString(input->shape()),
"The shape of Weight: ",
MakeString(weight->shape()),
" and Bias ",
bias->dim(0),
" don't match.");
}
functor_(input, weight, bias, output, future);
return true;
......
......@@ -33,12 +33,16 @@ void FCBenchmark(
// Add input data
net.AddRandomInput<D, float>("Input", {batch, height, width, channel});
net.AddRandomInput<D, float>("Weight",
{out_channel, height * width * channel});
{out_channel, channel, height, width});
net.AddRandomInput<D, float>("Bias", {out_channel});
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("InputNCHW")
.Input("Weight")
.Input("Bias")
.Output("Output")
......
......@@ -41,7 +41,6 @@ void Simple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Bias", bias_shape, bias_value);
if (D == DeviceType::CPU) {
net.Transpose2D<D, float>("Weight", "WeightTranspose");
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
......@@ -55,7 +54,7 @@ void Simple(const std::vector<index_t> &input_shape,
BufferToImage<D, float>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<D, float>(&net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_HEIGHT);
kernels::BufferType::WEIGHT_WIDTH);
BufferToImage<D, float>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
......@@ -64,7 +63,6 @@ void Simple(const std::vector<index_t> &input_shape,
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
......@@ -84,141 +82,52 @@ void Simple(const std::vector<index_t> &input_shape,
} // namespace
TEST_F(FullyConnectedOpTest, SimpleCPU) {
Simple<DeviceType::CPU>({1, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 8},
Simple<DeviceType::CPU>({1, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 2, 2},
{1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1},
{206});
Simple<DeviceType::CPU>(
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 1, 2, 5},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::CPU>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 1, 2, 3},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
}
TEST_F(FullyConnectedOpTest, SimpleCPUWithBatch) {
Simple<DeviceType::CPU>({2, 1, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 4},
Simple<DeviceType::CPU>({2, 1, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 1, 2, 2},
{1, 2, 3, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72});
}
TEST_F(FullyConnectedOpTest, SimpleOPENCL) {
Simple<DeviceType::GPU>({1, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 8},
{1, 2, 3, 4, 5, 6, 7, 8}, {1}, {2}, {1, 1, 1, 1},
Simple<DeviceType::GPU>({1, 2, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 2, 2},
{1, 3, 5, 7, 2, 4, 6, 8}, {1}, {2}, {1, 1, 1, 1},
{206});
Simple<DeviceType::GPU>(
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 10},
{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100},
{1, 1, 2, 5}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}, {2, 5, 1, 2},
{1, 6, 2, 7, 3, 8, 4, 9, 5, 10, 10, 60, 20, 70, 30, 80, 40, 90, 50, 100},
{2}, {2, 3}, {1, 1, 1, 2}, {387, 3853});
Simple<DeviceType::GPU>(
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 6},
{1, 2, 3, 4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3,
4, 5, 6, 10, 20, 30, 40, 50, 60, 1, 2, 3, 4, 5, 6},
{1, 1, 2, 3}, {1, 2, 3, 4, 5, 6}, {5, 3, 1, 2},
{1, 4, 2, 5, 3, 6, 10, 40, 20, 50, 30, 60, 1, 4, 2, 5, 3, 6,
10, 40, 20, 50, 30, 60, 1, 4, 2, 5, 3, 6},
{5}, {1, 2, 3, 4, 5}, {1, 1, 1, 5}, {92, 912, 94, 914, 96});
}
TEST_F(FullyConnectedOpTest, SimpleGPUWithBatch) {
Simple<DeviceType::GPU>({2, 1, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 4},
{1, 2, 3, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72});
}
namespace {
template<typename T>
void Complex(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<DeviceType::GPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>(
"Weight", {out_channel, height * width * channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel});
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("Weight")
.Input("Bias")
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// run cpu
net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC);
// Check
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
// Run on opencl
BufferToImage<DeviceType::GPU, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<DeviceType::GPU, T>(&net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_HEIGHT);
BufferToImage<DeviceType::GPU, float>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputImage")
.Input("WeightImage")
.Input("BiasImage")
.Output("OutputImage")
.AddIntArg("weight_type", kernels::BufferType::WEIGHT_HEIGHT)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
net.RunOp(DeviceType::GPU);
ImageToBuffer<DeviceType::GPU, float>(&net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
1e-1, 1e-1);
} else {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
1e-5, 1e-4);
}
}
} // namespace
TEST_F(FullyConnectedOpTest, OPENCLAlignedWithoutBatch) {
Complex<float>(1, 16, 16, 32, 16);
Complex<float>(1, 16, 32, 32, 32);
}
TEST_F(FullyConnectedOpTest, OPENCLUnAlignedWithoutBatch) {
Complex<float>(1, 13, 11, 11, 17);
Complex<float>(1, 23, 29, 23, 113);
}
TEST_F(FullyConnectedOpTest, OPENCLUnAlignedWithBatch) {
Complex<float>(16, 11, 13, 23, 17);
Complex<float>(31, 13, 11, 29, 113);
}
TEST_F(FullyConnectedOpTest, OPENCLHalfAlignedWithoutBatch) {
Complex<half>(1, 16, 16, 32, 16);
Complex<half>(1, 16, 32, 32, 32);
}
TEST_F(FullyConnectedOpTest, OPENCLHalfUnAlignedWithBatch) {
Complex<half>(2, 11, 13, 61, 17);
Complex<half>(16, 13, 12, 31, 113);
Complex<half>(31, 21, 11, 23, 103);
Simple<DeviceType::GPU>({2, 1, 2, 2}, {1, 2, 3, 4, 5, 6, 7, 8}, {1, 2, 1, 2},
{1, 3, 2, 4}, {1}, {2}, {2, 1, 1, 1}, {32, 72});
}
namespace {
template<typename T>
void TestWXFormat(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
void Random(const index_t batch,
const index_t height,
const index_t width,
const index_t channels,
const index_t out_channel) {
srand(time(NULL));
// Construct graph
......@@ -228,11 +137,15 @@ void TestWXFormat(const index_t batch,
net.AddRandomInput<DeviceType::GPU, float>(
"Input", {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>(
"Weight", {out_channel, height * width * channels});
"Weight", {out_channel, channels, height, width});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel});
net.TransformDataFormat<DeviceType::CPU, float>("Input",
NHWC,
"InputNCHW",
NCHW);
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input")
.Input("InputNCHW")
.Input("Weight")
.Input("Bias")
.Output("OutputNCHW")
......@@ -249,11 +162,11 @@ void TestWXFormat(const index_t batch,
// Run on opencl
BufferToImage<DeviceType::GPU, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
kernels::BufferType::IN_OUT_CHANNEL);
BufferToImage<DeviceType::GPU, T>(&net, "Weight", "WeightImage",
kernels::BufferType::WEIGHT_WIDTH);
kernels::BufferType::WEIGHT_WIDTH);
BufferToImage<DeviceType::GPU, T>(&net, "Bias", "BiasImage",
kernels::BufferType::ARGUMENT);
kernels::BufferType::ARGUMENT);
OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputImage")
......@@ -267,7 +180,7 @@ void TestWXFormat(const index_t batch,
net.RunOp(DeviceType::GPU);
ImageToBuffer<DeviceType::GPU, float>(&net, "OutputImage", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DataType::DT_HALF) {
ExpectTensorNear<float>(expected, *net.GetOutput("OPENCLOutput"),
1e-1, 1e-1);
......@@ -278,22 +191,31 @@ void TestWXFormat(const index_t batch,
}
} // namespace
TEST_F(FullyConnectedOpTest, OPENCLWidthFormatAligned) {
TestWXFormat<float>(1, 7, 7, 32, 16);
TestWXFormat<float>(1, 7, 7, 512, 128);
TestWXFormat<float>(1, 1, 1, 2048, 1024);
// Single-batch CPU-vs-GPU comparison on random data.
TEST_F(FullyConnectedOpTest, ComplexAligned) {
  // Each row: {height, width, channels, out_channel}; batch is always 1.
  const index_t cases[][4] = {
      {16, 16, 32, 16}, {7, 7, 32, 16}, {7, 7, 512, 128}, {1, 1, 2048, 1024}};
  for (const auto &c : cases) {
    Random<float>(1, c[0], c[1], c[2], c[3]);
  }
}
// Single-batch CPU-vs-GPU comparison on odd-sized random data.
TEST_F(FullyConnectedOpTest, ComplexUnAlignedWithoutBatch) {
  // Each row: {height, width, channels, out_channel}; batch is always 1.
  const index_t cases[][4] = {
      {13, 11, 11, 17}, {23, 29, 23, 113}, {14, 14, 13, 23}};
  for (const auto &c : cases) {
    Random<float>(1, c[0], c[1], c[2], c[3]);
  }
}
TEST_F(FullyConnectedOpTest, OPENCLWidthFormatMultiBatch) {
TestWXFormat<float>(11, 7, 7, 32, 16);
TestWXFormat<float>(5, 7, 7, 512, 128);
TestWXFormat<float>(3, 1, 1, 2048, 1024);
// CPU-vs-GPU comparison on random data with batch sizes greater than one.
TEST_F(FullyConnectedOpTest, ComplexMultiBatch) {
  // Each row: {batch, height, width, channels, out_channel}.
  const index_t cases[][5] = {{11, 7, 7, 32, 16},
                              {5, 7, 7, 512, 128},
                              {3, 1, 1, 2048, 1024},
                              {7, 14, 14, 13, 23}};
  for (const auto &c : cases) {
    Random<float>(c[0], c[1], c[2], c[3], c[4]);
  }
}
TEST_F(FullyConnectedOpTest, OPENCLHalfWidthFormatAligned) {
TestWXFormat<half>(1, 2, 2, 512, 2);
TestWXFormat<half>(1, 11, 11, 32, 16);
TestWXFormat<half>(1, 16, 32, 32, 32);
// Single-batch comparison with the GPU side instantiated for half precision.
TEST_F(FullyConnectedOpTest, ComplexHalfWidthFormatAligned) {
  // Each row: {height, width, channels, out_channel}; batch is always 1.
  const index_t cases[][4] = {
      {2, 2, 512, 2}, {11, 11, 32, 16}, {16, 32, 32, 32}, {14, 14, 13, 23}};
  for (const auto &c : cases) {
    Random<half>(1, c[0], c[1], c[2], c[3]);
  }
}
} // namespace test
......
......@@ -16,7 +16,7 @@
#define MACE_OPS_IMAGE_TO_BUFFER_H_
#include "mace/core/operator.h"
#include "mace/kernels/buffer_to_image.h"
#include "mace/kernels/image_to_buffer.h"
namespace mace {
namespace ops {
......@@ -25,21 +25,21 @@ template <DeviceType D, typename T>
class ImageToBufferOp : public Operator<D, T> {
public:
ImageToBufferOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), functor_(true) {}
: Operator<D, T>(op_def, ws) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
const Tensor *input = this->Input(INPUT);
Tensor *output = this->Output(OUTPUT);
kernels::BufferType type =
static_cast<kernels::BufferType>(OperatorBase::GetSingleArgument<int>(
"buffer_type", static_cast<int>(kernels::CONV2D_FILTER)));
functor_(output, type, const_cast<Tensor *>(input_tensor), future);
functor_(input, type, output, future);
return true;
}
private:
kernels::BufferToImageFunctor<D, T> functor_;
kernels::ImageToBufferFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
......
......@@ -59,11 +59,8 @@ void WinogradConvolution(const index_t batch,
// Construct graph
OpsTestNet net;
// Add input data
std::vector<float> filter_data;
std::vector<index_t> filter_shape = {3, 3, out_channels, in_channels};
GenerateRandomRealTypeData<float>(filter_shape, &filter_data);
net.AddRandomInput<D, float>("Input", {batch, height, width, in_channels});
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
net.AddRandomInput<D, float>("Filter", {out_channels, in_channels, 3, 3});
net.AddRandomInput<D, T>("Bias", {out_channels});
BufferToImage<D, T>(&net, "Input", "InputImage",
......@@ -79,12 +76,13 @@ void WinogradConvolution(const index_t batch,
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
net.RunOp(D);
// Transfer output
ImageToBuffer<D, T>(&net, "OutputImage", "ConvOutput",
ImageToBuffer<D, float>(&net, "OutputImage", "ConvOutput",
kernels::BufferType::IN_OUT_CHANNEL);
Tensor expected;
expected.Copy(*net.GetOutput("ConvOutput"));
......@@ -92,11 +90,7 @@ void WinogradConvolution(const index_t batch,
// Winograd convolution
// transform filter
std::vector<float> wino_filter_data;
TransposeFilter(filter_data, filter_shape, &wino_filter_data);
net.AddInputFromArray<D, float>(
"WinoFilterData", {out_channels, in_channels, 3, 3}, wino_filter_data);
BufferToImage<D, T>(&net, "WinoFilterData", "WinoFilter",
BufferToImage<D, T>(&net, "Filter", "WinoFilter",
kernels::BufferType::WINOGRAD_FILTER);
// transform input
......@@ -128,6 +122,7 @@ void WinogradConvolution(const index_t batch,
.AddIntArg("height", output_shape[1])
.AddIntArg("width", output_shape[2])
.Output("WinoOutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
// Run on opencl
......@@ -180,12 +175,9 @@ void WinogradConvolutionWithPad(const index_t batch,
// Construct graph
OpsTestNet net;
// Add input data
std::vector<float> filter_data;
std::vector<index_t> filter_shape = {3, 3, out_channels, in_channels};
GenerateRandomRealTypeData<float>(filter_shape, &filter_data);
net.AddRandomInput<D, float>("Input", {batch, height, width, in_channels});
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data);
net.AddRandomInput<D, T>("Bias", {out_channels});
net.AddRandomInput<D, float>("Filter", {out_channels, in_channels, 3, 3});
net.AddRandomInput<D, float>("Bias", {out_channels});
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
......@@ -200,12 +192,13 @@ void WinogradConvolutionWithPad(const index_t batch,
.AddIntsArg("strides", {1, 1})
.AddIntsArg("padding_values", {padding, padding})
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
net.RunOp(D);
// Transfer output
ImageToBuffer<D, T>(&net, "OutputImage", "ConvOutput",
ImageToBuffer<D, float>(&net, "OutputImage", "ConvOutput",
kernels::BufferType::IN_OUT_CHANNEL);
Tensor expected;
expected.Copy(*net.GetOutput("ConvOutput"));
......@@ -213,11 +206,7 @@ void WinogradConvolutionWithPad(const index_t batch,
// Winograd convolution
// transform filter
std::vector<float> wino_filter_data;
TransposeFilter(filter_data, filter_shape, &wino_filter_data);
net.AddInputFromArray<D, float>(
"WinoFilterData", {out_channels, in_channels, 3, 3}, wino_filter_data);
BufferToImage<D, T>(&net, "WinoFilterData", "WinoFilter",
BufferToImage<D, T>(&net, "Filter", "WinoFilter",
kernels::BufferType::WINOGRAD_FILTER);
// transform input
......@@ -248,6 +237,7 @@ void WinogradConvolutionWithPad(const index_t batch,
.AddIntArg("batch", batch)
.AddIntArg("height", output_shape[1])
.AddIntArg("width", output_shape[2])
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("WinoOutputImage")
.Finalize(net.NewOperatorDef());
......@@ -267,6 +257,27 @@ void WinogradConvolutionWithPad(const index_t batch,
}
} // namespace
// Winograd-with-padding check on GPU for float and half, using the same
// tensor sizes and a different final argument per precision.
// NOTE(review): argument order is presumably (batch, height, width,
// in_channels, out_channels, padding) -- confirm against the helper signature.
TEST_F(WinogradConvlutionTest, AlignedConvolutionWithPad) {
  WinogradConvolutionWithPad<DeviceType::GPU, float>(1, 32, 32, 32, 16, 1);
  WinogradConvolutionWithPad<DeviceType::GPU, half>(1, 32, 32, 32, 16, 2);
}
// Winograd-with-padding check on GPU for float and half with odd sizes.
// NOTE(review): argument order is presumably (batch, height, width,
// in_channels, out_channels, padding) -- confirm against the helper signature.
TEST_F(WinogradConvlutionTest, UnAlignedConvolutionWithPad) {
  WinogradConvolutionWithPad<DeviceType::GPU, float>(1, 61, 67, 31, 37, 1);
  WinogradConvolutionWithPad<DeviceType::GPU, half>(1, 61, 67, 37, 31, 2);
}
// Winograd-with-padding check on GPU for float and half with a first
// argument (batch) greater than one.
// NOTE(review): the remaining arguments are presumably (height, width,
// in_channels, out_channels, padding) -- confirm against the helper signature.
TEST_F(WinogradConvlutionTest, BatchConvolutionWithPad) {
  WinogradConvolutionWithPad<DeviceType::GPU, float>(3, 64, 64, 32, 32, 1);
  WinogradConvolutionWithPad<DeviceType::GPU, half>(5, 61, 67, 37, 31, 2);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -16,6 +16,7 @@ import argparse
import sys
import hashlib
import os.path
import copy
from mace.proto import mace_pb2
from mace.python.tools import tf_dsp_converter_lib
......@@ -25,6 +26,7 @@ from mace.python.tools.converter_tool import base_converter as cvt
from mace.python.tools.converter_tool import tensorflow_converter
from mace.python.tools.converter_tool import caffe_converter
from mace.python.tools.converter_tool import transformer
from mace.python.tools.convert_util import mace_check
# ./bazel-bin/mace/python/tools/tf_converter --model_file quantized_test.pb \
......@@ -34,11 +36,14 @@ from mace.python.tools.converter_tool import transformer
FLAGS = None
data_type_map = {'DT_HALF': mace_pb2.DT_HALF,
'DT_FLOAT': mace_pb2.DT_FLOAT}
# Maps the value of the --runtime flag to the matching mace_pb2 device enum.
device_type_map = {'cpu': mace_pb2.CPU,
'gpu': mace_pb2.GPU,
'dsp': mace_pb2.HEXAGON}
# Data type stored in option.data_type for each target device:
# float for CPU, half for GPU, uint8 for the Hexagon DSP.
device_data_type_map = {
mace_pb2.CPU: mace_pb2.DT_FLOAT,
mace_pb2.GPU: mace_pb2.DT_HALF,
mace_pb2.HEXAGON: mace_pb2.DT_UINT8
}
def file_checksum(fname):
......@@ -81,7 +86,7 @@ def main(unused_args):
if FLAGS.platform not in ['tensorflow', 'caffe']:
print ("platform %s is not supported." % FLAGS.platform)
sys.exit(-1)
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp']:
if FLAGS.runtime not in ['cpu', 'gpu', 'dsp', '']:
print ("runtime %s is not supported." % FLAGS.runtime)
sys.exit(-1)
......@@ -95,8 +100,6 @@ def main(unused_args):
sys.exit(-1)
else:
option = cvt.ConverterOption()
option.data_type = data_type_map[FLAGS.data_type]
option.device = device_type_map[FLAGS.runtime]
option.winograd_enabled = bool(FLAGS.winograd)
input_node_names = FLAGS.input_node.split(',')
......@@ -117,8 +120,8 @@ def main(unused_args):
print("Convert model to mace model.")
if FLAGS.platform == 'tensorflow':
converter = tensorflow_converter.TensorflowConverter(option,
FLAGS.model_file) # noqa
converter = tensorflow_converter.TensorflowConverter(
option, FLAGS.model_file)
elif FLAGS.platform == 'caffe':
converter = caffe_converter.CaffeConverter(option,
FLAGS.model_file,
......@@ -126,16 +129,49 @@ def main(unused_args):
output_graph_def = converter.run()
print("Transform model to one that can better run on device.")
# TODO(liuqi/liyin): transform gpu/cpu and merge their ops
mace_transformer = transformer.Transformer(option, output_graph_def)
output_graph_def = mace_transformer.run()
if not FLAGS.runtime:
cpu_graph_def = copy.deepcopy(output_graph_def)
option.device = mace_pb2.CPU
option.data_type = device_data_type_map[mace_pb2.CPU]
option.disable_transpose_filters()
mace_cpu_transformer = transformer.Transformer(
option, cpu_graph_def)
cpu_graph_def = mace_cpu_transformer.run()
print "start optimize cpu memory."
memory_optimizer.optimize_cpu_memory(cpu_graph_def)
print "CPU memory optimization done."
print "start optimize memory."
if FLAGS.runtime == 'gpu':
memory_optimizer.optimize_gpu_memory(output_graph_def)
elif FLAGS.runtime == 'cpu':
memory_optimizer.optimize_cpu_memory(output_graph_def)
print "Memory optimization done."
option.device = mace_pb2.GPU
option.data_type = device_data_type_map[mace_pb2.GPU]
option.enable_transpose_filters()
mace_gpu_transformer = transformer.Transformer(
option, output_graph_def)
output_gpu_graph_def = mace_gpu_transformer.run()
print "start optimize gpu memory."
memory_optimizer.optimize_gpu_memory(output_gpu_graph_def)
print "GPU memory optimization done."
print "Merge cpu and gpu ops together"
output_graph_def.op.extend(cpu_graph_def.op)
output_graph_def.mem_arena.mem_block.extend(
cpu_graph_def.mem_arena.mem_block)
print "Merge done"
else:
option.device = device_type_map[FLAGS.runtime]
option.data_type = device_data_type_map[option.device]
mace_transformer = transformer.Transformer(
option, output_graph_def)
output_graph_def = mace_transformer.run()
print "start optimize memory."
if FLAGS.runtime == 'gpu':
memory_optimizer.optimize_gpu_memory(output_graph_def)
elif FLAGS.runtime == 'cpu':
memory_optimizer.optimize_cpu_memory(output_graph_def)
else:
mace_check(False, "runtime only support [gpu|cpu|dsp]")
print "Memory optimization done."
if FLAGS.output_type == 'source':
source_converter_lib.convert_to_source(
......@@ -188,7 +224,7 @@ def parse_args():
default="",
help="File to save the output graph to.")
parser.add_argument(
"--runtime", type=str, default="cpu", help="Runtime: cpu/gpu/dsp")
"--runtime", type=str, default="", help="Runtime: cpu/gpu/dsp")
parser.add_argument(
"--input_node",
type=str,
......@@ -196,11 +232,6 @@ def parse_args():
help="e.g., input_node")
parser.add_argument(
"--output_node", type=str, default="softmax", help="e.g., softmax")
parser.add_argument(
"--data_type",
type=str,
default='DT_FLOAT',
help="e.g., DT_HALF/DT_FLOAT")
parser.add_argument(
"--output_type", type=str, default="pb", help="output type: source/pb")
parser.add_argument(
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from enum import Enum
from mace.proto import mace_pb2
......@@ -117,6 +132,27 @@ class MaceKeyword(object):
mace_axis_str = 'axis'
mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed'
mace_device = 'device'
class TransformerRule(Enum):
    """Identifiers for the individual graph-transformation rules.

    ConverterOption keeps a list of these members to record which rules
    are enabled. Every member carries a plain, unique integer value; the
    members must not be followed by a trailing comma, which would turn
    the value into a one-element tuple and break lookup by value, e.g.
    TransformerRule(3).
    """
    REMOVE_IDENTITY_OP = 0
    TRANSFORM_GLOBAL_POOLING = 1
    FOLD_SOFTMAX = 2
    FOLD_BATCHNORM = 3
    FOLD_CONV_AND_BN = 4
    FOLD_DEPTHWISE_CONV_AND_BN = 5
    TRANSFORM_GPU_WINOGRAD = 6
    TRANSFORM_ADD_TO_BIASADD = 7
    FOLD_BIASADD = 8
    FOLD_ACTIVATION = 9
    TRANSPOSE_FILTERS = 10
    RESHAPE_FC_WEIGHT = 11
    TRANSPOSE_DATA_FORMAT = 12
    TRANSFORM_GLOBAL_CONV_TO_FC = 13
    TRANSFORM_BUFFER_IMAGE = 14
    ADD_DEVICE_AND_DATA_TYPE = 15
    SORT_BY_EXECUTION = 16
class ConverterInterface(object):
......@@ -162,6 +198,25 @@ class ConverterOption(object):
self._data_type = mace_pb2.DT_FLOAT
self._device = mace_pb2.CPU
self._winograd_enabled = False
self._transformer_option = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.TRANSPOSE_DATA_FORMAT,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE,
TransformerRule.SORT_BY_EXECUTION,
]
@property
def input_nodes(self):
......@@ -183,6 +238,10 @@ class ConverterOption(object):
def winograd_enabled(self):
return self._winograd_enabled
@property
def transformer_option(self):
    """Return the list of transformer rules currently enabled.

    The stored list object itself is returned, not a copy.
    """
    enabled_rules = self._transformer_option
    return enabled_rules
@input_nodes.setter
def input_nodes(self, input_nodes):
for node in input_nodes:
......@@ -211,6 +270,14 @@ class ConverterOption(object):
def winograd_enabled(self, winograd_enabled):
self._winograd_enabled = winograd_enabled
def disable_transpose_filters(self):
    """Drop the TRANSPOSE_FILTERS rule from the enabled transformer rules.

    Does nothing when the rule is already absent.
    """
    rule = TransformerRule.TRANSPOSE_FILTERS
    enabled_rules = self._transformer_option
    if rule not in enabled_rules:
        return
    enabled_rules.remove(rule)
def enable_transpose_filters(self):
    """Make sure TRANSPOSE_FILTERS is among the enabled transformer rules.

    The rule is appended only when it is missing, so repeated calls do
    not create duplicate entries.
    """
    rule = TransformerRule.TRANSPOSE_FILTERS
    enabled_rules = self._transformer_option
    if rule in enabled_rules:
        return
    enabled_rules.append(rule)
class ConverterUtil(object):
@staticmethod
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import google.protobuf.text_format
......@@ -325,10 +340,6 @@ class CaffeConverter(base_converter.ConverterInterface):
op.input.extend(caffe_op.layer.bottom)
op.output.extend(caffe_op.layer.top)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
return op
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import numpy as np
import tensorflow as tf
......@@ -197,11 +212,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
for tf_output in tf_op.outputs:
output_shape = op.output_shape.add()
output_shape.dims.extend(tf_output.shape.as_list())
op.output_type.append(self._option.data_type)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
data_type_arg.i = self._option.data_type
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
......@@ -289,7 +299,6 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.input.extend([scale_name, offset_name])
del op.output[1:]
del op.output_shape[1:]
del op.output_type[1:]
def convert_pooling(self, tf_op):
op = self.convert_general_op(tf_op)
......
# Copyright 2018 Xiaomi, Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import numpy as np
......@@ -11,6 +26,7 @@ from mace.python.tools.converter_tool.base_converter import FilterFormat
from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import ConverterUtil
from mace.python.tools.converter_tool.base_converter import TransformerRule
from mace.python.tools.convert_util import mace_check
OPENCL_IMAGE_MAX_SIZE = 16384
......@@ -36,23 +52,52 @@ class Transformer(base_converter.ConverterInterface):
def __init__(self, option, model):
# DO NOT reorder the following transformers
self._registered_transformers = [
self.remove_identity_op,
self.transform_global_pooling,
self.fold_softmax,
self.fold_batchnorm,
self.fold_conv_and_bn, # data_format related
self.fold_depthwise_conv_and_bn, # data_format related
self.transform_gpu_winograd, # data_format related
self.transform_add_to_biasadd,
self.fold_biasadd,
self.fold_activation,
self.transpose_filters,
self.transpose_data_format,
self.transform_global_conv_to_fc,
self.transform_buffer_image,
self.sort_by_execution,
self._registered_transformers_order = [
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
TransformerRule.TRANSFORM_GPU_WINOGRAD,
TransformerRule.TRANSFORM_ADD_TO_BIASADD,
TransformerRule.FOLD_BIASADD,
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.TRANSPOSE_DATA_FORMAT,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.TRANSFORM_BUFFER_IMAGE,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE,
TransformerRule.SORT_BY_EXECUTION,
]
self._registered_transformers = {
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling,
TransformerRule.FOLD_SOFTMAX: self.fold_softmax,
TransformerRule.FOLD_BATCHNORM: self.fold_batchnorm,
TransformerRule.FOLD_CONV_AND_BN:
self.fold_conv_and_bn, # data_format related
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN:
self.fold_depthwise_conv_and_bn, # data_format related
TransformerRule.TRANSFORM_GPU_WINOGRAD:
self.transform_gpu_winograd, # data_format related
TransformerRule.TRANSFORM_ADD_TO_BIASADD:
self.transform_add_to_biasadd,
TransformerRule.FOLD_BIASADD: self.fold_biasadd,
TransformerRule.FOLD_ACTIVATION: self.fold_activation,
TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight,
TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC:
self.transform_global_conv_to_fc,
TransformerRule.TRANSFORM_BUFFER_IMAGE:
self.transform_buffer_image,
TransformerRule.ADD_DEVICE_AND_DATA_TYPE:
self.add_device_and_data_type,
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
}
self._option = option
self._model = model
......@@ -67,12 +112,14 @@ class Transformer(base_converter.ConverterInterface):
self._target_data_format = DataFormat.NCHW
def run(self):
for transformer in self._registered_transformers:
while True:
self.construct_ops_and_consumers()
changed = transformer()
if not changed:
break
for key in self._registered_transformers_order:
if key in self._option.transformer_option:
transformer = self._registered_transformers[key]
while True:
self.construct_ops_and_consumers()
changed = transformer()
if not changed:
break
return self._model
......@@ -404,19 +451,16 @@ class Transformer(base_converter.ConverterInterface):
wt_output_shape.dims.extend(
[16, in_channels, wt_output_width, 1])
arg = wt_op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
if ConverterUtil.get_arg(op,
MaceKeyword.mace_padding_str) \
is not None:
padding_arg = wt_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_str
padding_arg.i = ConverterUtil.get_arg(op,
MaceKeyword.mace_padding_str).i # noqa
elif ConverterUtil.get_arg(op,
MaceKeyword.mace_padding_values_str) is not None: # noqa
padding_arg.i = ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_str).i
elif ConverterUtil.get_arg(
op, MaceKeyword.mace_padding_values_str)\
is not None:
padding_arg = wt_op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_values_str
padding_arg.ints.extend(ConverterUtil.get_arg(
......@@ -432,9 +476,6 @@ class Transformer(base_converter.ConverterInterface):
matmul_output_shape.dims.extend(
[16, out_channels, wt_output_width, 1])
arg = matmul_op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
arg = matmul_op.arg.add()
arg.name = MaceKeyword.mace_winograd_filter_transformed
arg.i = 1
......@@ -451,9 +492,6 @@ class Transformer(base_converter.ConverterInterface):
iwt_output_shape = iwt_op.output_shape.add()
iwt_output_shape.dims.extend(op.output_shape[0].dims)
arg = iwt_op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
batch_arg = iwt_op.arg.add()
batch_arg.name = 'batch'
batch_arg.i = batch
......@@ -618,10 +656,6 @@ class Transformer(base_converter.ConverterInterface):
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 3, 1, 2])
arg = op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
......@@ -639,75 +673,43 @@ class Transformer(base_converter.ConverterInterface):
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 2, 3, 1])
arg = op.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
return False
def transpose_filters(self):
net = self._model
filter_format = self.filter_format()
# TODO(liyin/liuqi): remove this if-condition after combine cpu/gpu
if self._option.device == mace_pb2.CPU:
print("Transpose filters to OIHW")
# transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
if filter_format == FilterFormat.HWIO:
for op in net.op:
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
if ConverterUtil.get_arg(op,
MaceKeyword.mace_winograd_filter_transformed) is None: # noqa
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
self.set_filter_format(FilterFormat.OIHW)
elif self._option.device == mace_pb2.GPU:
# TODO(liyin/liuqi): remove this whole logic after combine cpu/gpu
print("Transpose filters to HWOI/HWIM")
print("Transpose filters to OIHW")
# transpose filter to OIHW/MIHW for tensorflow (HWIO/HWIM)
if filter_format == FilterFormat.HWIO:
for op in net.op:
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name \
or op.type == MaceOp.DepthwiseConv2d.name:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
# transpose filter to HWOI/HWIM for
# tensorflow and caffe (OIHW/MIHW)
if filter_format == FilterFormat.HWIO \
and (op.type == MaceOp.Conv2D.name
or op.type == MaceOp.Deconv2D.name):
filter_data = filter_data.transpose(0, 1, 3, 2)
if ConverterUtil.get_arg(
op, MaceKeyword.mace_winograd_filter_transformed)\
is None:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
elif filter_format == FilterFormat.OIHW:
if op.type == MaceOp.Conv2D.name \
or op.type == MaceOp.Deconv2D.name:
filter_data = filter_data.transpose(2, 3, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
elif op.type == MaceOp.DepthwiseConv2d.name:
filter_data = filter_data.transpose(2, 3, 1, 0)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
input_shape = list(self._producer[op.input[0]]
.output_shape[0].dims)
weight_shape = [weight.dims[0]] + input_shape[1:]
# OCHW -> OHWC
weight_data = np.array(weight.float_data).reshape(
weight_shape)
weight_data = weight_data.transpose(0, 2, 3, 1)
weight.float_data[:] = weight_data.flat
self.set_filter_format(FilterFormat.HWOI)
self.set_filter_format(FilterFormat.OIHW)
return False
def reshape_fc_weight(self):
    """Rewrite the dims of every FullyConnected weight tensor.

    For each FullyConnected op, the weight's dims become its own first
    dim followed by all but the first dim of the output shape of the op
    that produces the FC input. Only `dims` is touched; the weight data
    is left as is.

    Returns:
        Always False (the caller's loop stops re-running a pass once it
        returns a falsy value).
    """
    fc_type_name = MaceOp.FullyConnected.name
    for op in self._model.op:
        if op.type != fc_type_name:
            continue
        weight = self._consts[op.input[1]]
        producer = self._producer[op.input[0]]
        # Producer output is noted as NCHW in the original code.
        trailing_dims = list(producer.output_shape[0].dims)[1:]
        new_dims = [weight.dims[0]] + trailing_dims
        # Replace the repeated field's contents in place.
        del weight.dims[:]
        weight.dims.extend(new_dims)

    return False
......@@ -727,9 +729,6 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add()
arg.name = MaceKeyword.mace_mode
arg.i = 0
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
op.input[input_idx] = output_name
......@@ -788,9 +787,6 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add()
arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \
......@@ -806,9 +802,6 @@ class Transformer(base_converter.ConverterInterface):
arg = op_def.arg.add()
arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
arg = op_def.arg.add()
arg.name = 'T'
arg.i = self._option.data_type
return False
......@@ -885,6 +878,19 @@ class Transformer(base_converter.ConverterInterface):
in_channels * filter_width
* filter_height][:]
def add_device_and_data_type(self):
    """Append a device argument and a 'T' data-type argument to every op.

    Both values are taken from the converter option. The arguments are
    added unconditionally to each op in the model.

    Returns:
        Always False.
    """
    # TODO(liuqi) add device definition in OperatorDef
    option = self._option
    for op in self._model.op:
        device_arg = op.arg.add()
        device_arg.name = MaceKeyword.mace_device
        device_arg.i = option.device

        type_arg = op.arg.add()
        type_arg.name = 'T'
        type_arg.i = option.data_type

    return False
def sort_dfs(self, op, visited, sorted_nodes):
visited.update([op.name])
if len(op.input) > 0:
......
......@@ -167,7 +167,6 @@ def convert_to_source(net_def, model_checksum, weight_checksum, template_dir,
tensor_info=tensor_info,
tensor=t,
tag=model_tag,
runtime=runtime,
offset=offset,
)
model_data.extend(tensor_info.data)
......
......@@ -17,8 +17,11 @@ import functools
import argparse
import sys
import six
import copy
import tensorflow as tf
from tensorflow import gfile
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
# ./bazel-bin/mace/python/tools/tf_ops_stats --input model.pb
......@@ -39,6 +42,26 @@ def to_int_list(long_list):
return int_list
def add_shape_info(input_graph_def, input_nodes, input_shapes):
    """Return a new GraphDef whose input nodes carry explicit shapes.

    Every node of `input_graph_def` is deep-copied into a fresh GraphDef.
    Nodes whose name appears in `input_nodes` have their attributes
    cleared and replaced by a 'shape' attribute built from the matching
    entry of `input_shapes`, plus the node's original 'dtype' attribute.
    All other nodes are copied unchanged.

    Args:
        input_graph_def: source GraphDef; it is only read.
        input_nodes: list of input node names.
        input_shapes: list of shapes (lists of ints), index-aligned with
            `input_nodes`.

    Returns:
        The newly built GraphDef.

    Raises:
        IndexError: if `input_shapes` has no entry at the index of a
            matched input node.
    """
    inputs_replaced_graph = graph_pb2.GraphDef()
    for node in input_graph_def.node:
        if node.name in input_nodes:
            idx = input_nodes.index(node.name)
            input_shape = input_shapes[idx]
            # Function-call form: valid on Python 3, and for a single
            # argument prints the same text on Python 2.
            print(input_shape)
            placeholder_node = copy.deepcopy(node)
            placeholder_node.attr.clear()
            placeholder_node.attr['shape'].shape.dim.extend([
                tensor_shape_pb2.TensorShapeProto.Dim(size=i)
                for i in input_shape
            ])
            # attr was cleared above, so restore dtype from the original.
            placeholder_node.attr['dtype'].CopyFrom(node.attr['dtype'])
            inputs_replaced_graph.node.extend([placeholder_node])
        else:
            inputs_replaced_graph.node.extend([copy.deepcopy(node)])
    return inputs_replaced_graph
def main(unused_args):
if not FLAGS.input or not gfile.Exists(FLAGS.input):
print('Input graph file ' + FLAGS.input + ' does not exist!')
......@@ -49,6 +72,16 @@ def main(unused_args):
data = f.read()
input_graph_def.ParseFromString(data)
input_nodes = [x for x in FLAGS.input_tensors.split(',')]
input_shapes = []
if FLAGS.input_shapes != "":
input_shape_strs = [x for x in FLAGS.input_shapes.split(':')]
for shape_str in input_shape_strs:
input_shapes.extend([[int(x) for x in shape_str.split(',')]])
input_graph_def = add_shape_info(
input_graph_def, input_nodes, input_shapes)
with tf.Session() as session:
with session.graph.as_default() as graph:
tf.import_graph_def(input_graph_def, name='')
......@@ -79,15 +112,12 @@ def main(unused_args):
strides = to_int_list(op.get_attr('strides'))
data_format = op.get_attr('data_format')
ksize = 'Unknown'
for input in op.inputs:
input_name = input.name
if input_name.endswith('weights/read:0'):
ksize = input.shape.as_list()
break
if input_name.endswith(
'weights:0') and input_name in tensor_shapes:
ksize = tensor_shapes[input_name]
break
input = op.inputs[1]
input_name = input.name
if input_name.endswith('read:0'):
ksize = input.shape.as_list()
elif input_name in tensor_shapes:
ksize = tensor_shapes[input_name]
print(
'%s(padding=%s, strides=%s, ksize=%s, format=%s) %s => %s'
% (op.type, padding, strides, ksize, data_format,
......@@ -189,6 +219,16 @@ def parse_args():
type=str,
default='',
help='TensorFlow \'GraphDef\' file to load.')
parser.add_argument(
'--input_tensors',
type=str,
default='',
help='input tensor names split by comma.')
parser.add_argument(
'--input_shapes',
type=str,
default='',
help='input tensor shapes split by colon and comma.')
return parser.parse_known_args()
......
......@@ -55,6 +55,7 @@ void BufferToImage(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def,
const int mode = NetMode::NORMAL) {
OperatorDef operator_def;
......@@ -64,6 +65,7 @@ void BufferToImage(const std::string &input_name,
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("mode", mode)
.Finalize(&operator_def);
......@@ -76,6 +78,7 @@ template <typename T>
void ImageToBuffer(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
const DeviceType device_type,
NetDef *net_def) {
OperatorDef operator_def;
......@@ -84,6 +87,7 @@ void ImageToBuffer(const std::string &input_name,
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
......@@ -94,6 +98,7 @@ void Conv3x3(const std::string &input_name,
const std::string &filter_name,
const std::string &output_name,
const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
......@@ -104,6 +109,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def);
operator_def.set_mem_id(mem_ids);
......@@ -113,6 +119,7 @@ void Conv3x3(const std::string &input_name,
template <typename T>
void Relu(const std::string &input_name,
const std::string &output_name,
const DeviceType device_type,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Activation", "ReluTest")
......@@ -120,6 +127,7 @@ void Relu(const std::string &input_name,
.Output(output_name)
.AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
......@@ -195,7 +203,8 @@ std::map<std::string, int> AddMemoryOptimization(
const std::vector<std::vector<int64_t>> &output_shapes,
NetDef *net_def) {
std::map<std::string, int> res;
int mem_id = 0;
// TODO(liuqi) refactor based on PB
int mem_id = 20000;
size_t input_shape_size = input_shapes.size();
uint32_t in_mem_block_x = 0;
uint32_t in_mem_block_y = 0;
......@@ -250,7 +259,7 @@ void MaceRunFunc(const int in_out_size) {
const std::vector<std::vector<int64_t>> input_shapes = {{1, 32, 32, 16}};
const std::vector<std::vector<int64_t>> output_shapes = {{1, 32, 32, 16}};
const std::vector<int64_t> filter_shape = {3, 3, 16, 16};
const std::vector<int64_t> filter_shape = {16, 16, 3, 3};
NetDef net_def;
......@@ -269,21 +278,25 @@ void MaceRunFunc(const int in_out_size) {
BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]},
device,
&net_def);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {},
mace::kernels::CONV2D_FILTER, {}, device,
&net_def, NetMode::INIT);
for (size_t i = 0; i < output_names.size(); ++i) {
Conv3x3<half>(input_names[i], filter_tensor_img_name,
output_names[i], {mem_map[output_names[i]]},
device,
&net_def);
}
for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_",
output_names[i]);
ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def);
mace::kernels::IN_OUT_CHANNEL,
device,
&net_def);
}
const std::string file_path ="/data/local/tmp/mace";
......
......@@ -65,6 +65,7 @@ void BufferToImage(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def,
const int mode = NetMode::NORMAL) {
OperatorDef operator_def;
......@@ -74,6 +75,7 @@ void BufferToImage(const std::string &input_name,
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("mode", mode)
.Finalize(&operator_def);
......@@ -86,6 +88,7 @@ template <typename T>
void ImageToBuffer(const std::string &input_name,
const std::string &output_name,
const int buffer_type,
const DeviceType device_type,
NetDef *net_def) {
OperatorDef operator_def;
......@@ -94,6 +97,7 @@ void ImageToBuffer(const std::string &input_name,
.Output(output_name)
.AddIntArg("buffer_type", buffer_type)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
......@@ -104,6 +108,7 @@ void Conv3x3(const std::string &input_name,
const std::string &filter_name,
const std::string &output_name,
const std::vector<int> &mem_ids,
const DeviceType device_type,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Conv2D", "Conv2dOp")
......@@ -114,6 +119,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def);
operator_def.set_mem_id(mem_ids);
......@@ -123,6 +129,7 @@ void Conv3x3(const std::string &input_name,
template <typename T>
void Relu(const std::string &input_name,
const std::string &output_name,
const DeviceType device_type,
NetDef *net_def) {
OperatorDef operator_def;
ops::test::OpDefBuilder("Activation", "ReluTest")
......@@ -130,6 +137,7 @@ void Relu(const std::string &input_name,
.Output(output_name)
.AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type))
.Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def);
......@@ -205,7 +213,8 @@ std::map<std::string, int> AddMemoryOptimization(
const std::vector<std::vector<int64_t>> &output_shapes,
NetDef *net_def) {
std::map<std::string, int> res;
int mem_id = 0;
// TODO(liuqi) refactor based on PB
int mem_id = 20000;
size_t input_shape_size = input_shapes.size();
uint32_t in_mem_block_x = 0;
uint32_t in_mem_block_y = 0;
......@@ -279,21 +288,24 @@ void MaceRun(const int in_out_size,
BufferToImage<half>(input_name, input_names[i],
mace::kernels::IN_OUT_CHANNEL,
{mem_map[input_names[i]]},
device,
&net_def);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {},
mace::kernels::CONV2D_FILTER, {}, device,
&net_def, NetMode::INIT);
for (size_t i = 0; i < output_names.size(); ++i) {
Conv3x3<half>(input_names[i], filter_tensor_img_name,
output_names[i], {mem_map[output_names[i]]},
&net_def);
device, &net_def);
}
for (size_t i = 0; i < output_names.size(); ++i) {
std::string output_name = MakeString("mace_output_node_",
output_names[i]);
ImageToBuffer<float>(output_names[i], output_name,
mace::kernels::IN_OUT_CHANNEL, &net_def);
mace::kernels::IN_OUT_CHANNEL,
device,
&net_def);
}
MaceEngine engine(&net_def, device, input_names, output_names);
......@@ -318,30 +330,30 @@ void MaceRun(const int in_out_size,
} // namespace
TEST_F(MaceAPITest, GPUSingleInputOutput) {
MaceRun<float>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16});
MaceRun<half>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16});
MaceRun<float>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {16, 16, 3, 3});
MaceRun<half>(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {16, 16, 3, 3});
}
TEST_F(MaceAPITest, GPUMultipleInputOutput) {
MaceRun<float>(2,
{{1, 16, 32, 16}},
{{1, 16, 32, 16}},
{3, 3, 16, 16});
{16, 16, 3, 3});
MaceRun<half>(2,
{{1, 16, 32, 16}},
{{1, 16, 32, 16}},
{3, 3, 16, 16});
{16, 16, 3, 3});
}
TEST_F(MaceAPITest, GPUVariableInputShape) {
MaceRun<float>(1,
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
{16, 16, 3, 3});
MaceRun<half>(2,
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{3, 3, 16, 16});
{16, 16, 3, 3});
}
} // namespace test
} // namespace mace
......@@ -62,27 +62,23 @@ def get_target_socs(configs):
return target_socs
def get_data_and_device_type(runtime):
data_type = ""
def parse_device_type(runtime):
device_type = ""
if runtime == "dsp":
data_type = "DT_UINT8"
device_type = "HEXAGON"
elif runtime == "gpu":
data_type = "DT_HALF"
device_type = "GPU"
elif runtime == "cpu":
data_type = "DT_FLOAT"
device_type = "CPU"
return data_type, device_type
return device_type
def get_hexagon_mode(configs):
runtime_list = []
for model_name in configs["models"]:
model_runtime = configs["models"][model_name]["runtime"]
model_runtime = configs["models"][model_name].get("runtime", "")
runtime_list.append(model_runtime.lower())
global_runtime = ""
......@@ -114,7 +110,7 @@ def model_benchmark_stdout_processor(stdout,
abi,
serialno,
model_name,
runtime):
device_type):
metrics = [0] * 3
for line in stdout.split('\n'):
line = line.strip()
......@@ -138,14 +134,14 @@ def model_benchmark_stdout_processor(stdout,
f.write("model_name,device_name,soc,abi,runtime,"
"init,warmup,run_avg\n")
data_str = "{model_name},{device_name},{soc},{abi},{runtime}," \
data_str = "{model_name},{device_name},{soc},{abi},{device_type}," \
"{init},{warmup},{run_avg}\n" \
.format(
model_name=model_name,
device_name=device_name,
soc=target_soc,
abi=abi,
runtime=runtime,
device_type=device_type,
init=metrics[0],
warmup=metrics[1],
run_avg=metrics[2]
......@@ -154,8 +150,7 @@ def model_benchmark_stdout_processor(stdout,
f.write(data_str)
def tuning_run(runtime,
target_abi,
def tuning_run(target_abi,
serialno,
vlog_level,
embed_model_data,
......@@ -205,7 +200,7 @@ def tuning_run(runtime,
if running_round > 0 and FLAGS.collect_report:
model_benchmark_stdout_processor(
stdout, target_abi, serialno, model_name, runtime)
stdout, target_abi, serialno, model_name, device_type)
def build_mace_run_prod(hexagon_mode, runtime, target_abi,
......@@ -222,7 +217,7 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
strip = "never"
debug = True
if runtime == "gpu":
if not runtime or runtime == "gpu":
gen_opencl_and_tuning_code(target_abi, serialno, [], False)
sh_commands.bazel_build(
mace_run_target,
......@@ -234,19 +229,14 @@ def build_mace_run_prod(hexagon_mode, runtime, target_abi,
sh_commands.update_mace_run_lib(model_output_dir,
model_name, embed_model_data)
tuning_run(runtime, target_abi, serialno, vlog_level, embed_model_data,
device_type = parse_device_type("gpu")
tuning_run(target_abi, serialno, vlog_level, embed_model_data,
model_output_dir, input_nodes, output_nodes, input_shapes,
output_shapes, model_name, device_type, running_round=0,
restart_round=1, out_of_range_check=False,
phone_data_dir=phone_data_dir, tuning=tuning,
limit_opencl_kernel_time=limit_opencl_kernel_time)
tuning_run(runtime, target_abi, serialno, vlog_level, embed_model_data,
model_output_dir, input_nodes, output_nodes, input_shapes,
output_shapes, model_name, device_type, running_round=0,
restart_round=1, out_of_range_check=True,
phone_data_dir=phone_data_dir, tuning=False)
gen_opencl_and_tuning_code(target_abi, serialno, [model_output_dir],
True)
sh_commands.bazel_build(
......@@ -391,8 +381,7 @@ def parse_model_configs():
print("'platform' must be 'tensorflow' or 'caffe'")
exit(1)
for key in ["model_file_path", "model_sha256_checksum",
"runtime"]:
for key in ["model_file_path", "model_sha256_checksum"]:
value = model_config.get(key, "")
if value == "":
print("CONFIG ERROR:")
......@@ -529,6 +518,11 @@ def parse_args():
type=str,
default="",
help="Valgrind command args.")
parser.add_argument(
"--validation_runtime",
type=str,
default="cpu",
help="validation runtime.")
return parser.parse_known_args()
......@@ -541,9 +535,11 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
print '===================', model_name, '==================='
model_config = configs["models"][model_name]
input_file_list = model_config["validation_inputs_data"]
data_type, device_type = get_data_and_device_type(
model_config["runtime"])
model_runtime = model_config.get("runtime", "")
model_device_type = parse_device_type(model_runtime)
run_device_type = model_device_type
if not run_device_type:
run_device_type = parse_device_type(FLAGS.validation_runtime)
# Create model build directory
model_path_digest = md5sum(model_config["model_file_path"])
model_output_base_dir = "%s/%s/%s/%s/%s" % (
......@@ -581,7 +577,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
if FLAGS.mode == "build" or FLAGS.mode == "all":
build_mace_run_prod(hexagon_mode,
model_config["runtime"],
model_runtime,
target_abi,
serialno,
vlog_level,
......@@ -592,7 +588,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"],
model_config["output_shapes"],
model_name,
device_type,
model_device_type,
FLAGS.round,
FLAGS.restart_round,
FLAGS.tuning,
......@@ -607,8 +603,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
if FLAGS.mode == "run" or FLAGS.mode == "validate" or \
FLAGS.mode == "all":
tuning_run(model_config["runtime"],
target_abi,
tuning_run(target_abi,
serialno,
vlog_level,
embed_model_data,
......@@ -618,7 +613,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"],
model_config["output_shapes"],
model_name,
device_type,
run_device_type,
FLAGS.round,
FLAGS.restart_round,
FLAGS.out_of_range_check,
......@@ -641,7 +636,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_config["input_shapes"],
model_config["output_shapes"],
model_name,
device_type,
run_device_type,
phone_data_dir,
FLAGS.omp_num_threads,
FLAGS.cpu_affinity_policy,
......@@ -654,7 +649,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_file_path,
weight_file_path,
model_config["platform"],
model_config["runtime"],
run_device_type,
model_config["input_nodes"],
model_config["output_nodes"],
model_config["input_shapes"],
......@@ -746,8 +741,7 @@ def main(unused_args):
for model_name in configs["models"]:
print '===================', model_name, '==================='
model_config = configs["models"][model_name]
data_type, device_type = get_data_and_device_type(
model_config["runtime"])
runtime = model_config.get("runtime", "")
# Create model build directory
model_path_digest = md5sum(model_config["model_file_path"])
......@@ -778,8 +772,7 @@ def main(unused_args):
model_config["model_sha256_checksum"],
",".join(model_config["input_nodes"]),
",".join(model_config["output_nodes"]),
data_type,
model_config["runtime"],
runtime,
model_name,
":".join(model_config["input_shapes"]),
model_config["dsp_mode"],
......
......@@ -465,7 +465,6 @@ def gen_model_code(model_codegen_dir,
model_sha256_checksum,
input_nodes,
output_nodes,
data_type,
runtime,
model_tag,
input_shapes,
......@@ -489,7 +488,6 @@ def gen_model_code(model_codegen_dir,
"--output=%s" % model_codegen_dir + "/model.cc",
"--input_node=%s" % input_nodes,
"--output_node=%s" % output_nodes,
"--data_type=%s" % data_type,
"--runtime=%s" % runtime,
"--output_type=source",
"--template=%s" % "mace/python/tools",
......@@ -703,7 +701,7 @@ def validate_model(abi,
model_file_path,
weight_file_path,
platform,
runtime,
device_type,
input_nodes,
output_nodes,
input_shapes,
......@@ -727,7 +725,7 @@ def validate_model(abi,
if platform == "tensorflow":
validate(platform, model_file_path, "",
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime,
"%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes))
elif platform == "caffe":
......@@ -743,7 +741,8 @@ def validate_model(abi,
logger.error('There is no caffe python module.')
validate(platform, model_file_path, weight_file_path,
"%s/%s" % (model_output_dir, input_file_name),
"%s/%s" % (model_output_dir, output_file_name), runtime,
"%s/%s" % (model_output_dir, output_file_name),
device_type,
":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes))
elif caffe_env == common.CaffeEnvType.DOCKER:
......@@ -806,7 +805,7 @@ def validate_model(abi,
"--weight_file=/mace/%s" % weight_file_name,
"--input_file=/mace/%s" % input_file_name,
"--mace_out_file=/mace/%s" % output_file_name,
"--mace_runtime=%s" % runtime,
"--device_type=%s" % device_type,
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
......
......@@ -44,7 +44,7 @@ def load_data(file):
return np.empty([0])
def compare_output(platform, mace_runtime, output_name, mace_out_value,
def compare_output(platform, device_type, output_name, mace_out_value,
out_value):
if mace_out_value.size != 0:
out_value = out_value.reshape(-1)
......@@ -53,9 +53,9 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value,
similarity = (1 - spatial.distance.cosine(out_value, mace_out_value))
print output_name, 'MACE VS', platform.upper(
), 'similarity: ', similarity
if (mace_runtime == "cpu" and similarity > 0.999) or \
(mace_runtime == "gpu" and similarity > 0.995) or \
(mace_runtime == "dsp" and similarity > 0.930):
if (device_type == "CPU" and similarity > 0.999) or \
(device_type == "GPU" and similarity > 0.995) or \
(device_type == "HEXAGON" and similarity > 0.930):
print '===================Similarity Test Passed=================='
else:
print '===================Similarity Test Failed=================='
......@@ -65,7 +65,7 @@ def compare_output(platform, mace_runtime, output_name, mace_out_value,
sys.exit(-1)
def validate_tf_model(platform, mace_runtime, model_file, input_file,
def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, output_names):
import tensorflow as tf
if not os.path.isfile(model_file):
......@@ -100,11 +100,11 @@ def validate_tf_model(platform, mace_runtime, model_file, input_file,
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i],
compare_output(platform, device_type, output_names[i],
mace_out_value, output_values[i])
def validate_caffe_model(platform, mace_runtime, model_file, input_file,
def validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names, input_shapes,
output_names, output_shapes):
os.environ['GLOG_minloglevel'] = '1' # suppress Caffe verbose prints
......@@ -144,12 +144,12 @@ def validate_caffe_model(platform, mace_runtime, model_file, input_file,
output_file_name = common.formatted_file_name(
mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
compare_output(platform, mace_runtime, output_names[i], mace_out_value,
compare_output(platform, device_type, output_names[i], mace_out_value,
value)
def validate(platform, model_file, weight_file, input_file, mace_out_file,
mace_runtime, input_shape, output_shape, input_node, output_node):
device_type, input_shape, output_shape, input_node, output_node):
input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')]
......@@ -158,14 +158,14 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
assert len(input_names) == len(input_shapes)
if platform == 'tensorflow':
validate_tf_model(platform, mace_runtime, model_file, input_file,
validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes,
output_names)
elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
for shape in output_shape_strs]
validate_caffe_model(platform, mace_runtime, model_file, input_file,
validate_caffe_model(platform, device_type, model_file, input_file,
mace_out_file, weight_file, input_names,
input_shapes, output_names, output_shapes)
......@@ -194,7 +194,7 @@ def parse_args():
default="",
help="mace output file to load.")
parser.add_argument(
"--mace_runtime", type=str, default="gpu", help="mace runtime device.")
"--device_type", type=str, default="", help="mace runtime device.")
parser.add_argument(
"--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument(
......@@ -214,7 +214,7 @@ if __name__ == '__main__':
FLAGS.weight_file,
FLAGS.input_file,
FLAGS.mace_out_file,
FLAGS.mace_runtime,
FLAGS.device_type,
FLAGS.input_shape,
FLAGS.output_shape,
FLAGS.input_node,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册