提交 9472dcee 编写于 作者: L luxuhui

feat: support extract_image_patches op and support dynamic filter of conv2d/deconv2d

N/A
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 1b436f5d
......@@ -106,7 +106,7 @@ mace_cc_test:
GIT_SSH_COMMAND="ssh -o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no" git clone git@v9.git.n.xiaomi.com:deep-computing/generic-mobile-devices.git
DEVICE_CONF_FILE=generic-mobile-devices/devices.yml
fi
- python tools/bazel_adb_run.py --target="//test/ccunit:mace_cc_test" --device_yml=${DEVICE_CONF_FILE} --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a,arm64 --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//test/ccunit:mace_cc_test" --device_yml=${DEVICE_CONF_FILE} --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS
- python tools/bazel_adb_run.py --target="//micro/test/ccunit:micro_ops_test" --run_target=True --stdout_processor=ops_benchmark_stdout_processor --target_abis=arm64-v8a
mace_cc_benchmark:
......@@ -133,7 +133,7 @@ model_tests:
fi
- if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi
- python tools/converter.py convert --config=${CONF_FILE} --target_socs=$TARGET_SOCS --model_graph_format=file --model_data_format=file --cl_mem_type=buffer
- python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --target_abis=armeabi-v7a,arm64 --validate --model_graph_format=file --model_data_format=file
- python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --device_yml=${DEVICE_CONF_FILE} --round=1 --target_abis=armeabi-v7a --validate --model_graph_format=file --model_data_format=file
- CONF_FILE=mace-models/mobilenet-v2/mobilenet-v2-host.yml
- python tools/converter.py convert --config=${CONF_FILE} --target_socs=$TARGET_SOCS --model_graph_format=file --model_data_format=file
- python tools/converter.py run --config=${CONF_FILE} --target_socs=$TARGET_SOCS --round=1 --validate --model_graph_format=file --model_data_format=file --address_sanitizer
......
......@@ -15,6 +15,8 @@
#include "mace/core/net_def_adapter.h"
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "mace/core/ops/operator.h"
......@@ -27,6 +29,40 @@
namespace mace {
namespace {
struct FilterRulerInfo {
const int filter_idx;
const std::vector<int> nhwc_oihw;
const std::vector<int> nchw_oihw;
FilterRulerInfo(int filter_index, const std::vector<int> nhwc2oihw,
const std::vector<int> nchw2oihw)
: filter_idx(filter_index),
nhwc_oihw(std::move(nhwc2oihw)), nchw_oihw(std::move(nchw2oihw)) {}
};
typedef std::unordered_map<
std::string, std::unique_ptr<FilterRulerInfo>> FilterTransposeRuler;
FilterTransposeRuler GetFilterTransposeRuler() {
FilterTransposeRuler filter_ruler;
// filter's src format is actually HWIO in tf, OIHW in others
// for Conv2D in MACE, the dst format is OIHW
filter_ruler.emplace("Conv2D", make_unique<FilterRulerInfo>(
1, std::vector<int>({3, 2, 0, 1}), std::vector<int>({})));
// filter's src format is actually HWOI in tf, MIHW in others
filter_ruler.emplace("Deconv2D", make_unique<FilterRulerInfo>(
1, std::vector<int>({2, 3, 0, 1}), std::vector<int>({})));
filter_ruler.emplace("DepthwiseConv2d", make_unique<FilterRulerInfo>(
1, std::vector<int>({3, 2, 0, 1}), std::vector<int>({})));
filter_ruler.emplace("DepthwiseDeconv2d", make_unique<FilterRulerInfo>(
1, std::vector<int>({2, 3, 0, 1}), std::vector<int>({})));
return filter_ruler;
}
DataFormat GetDefaultDataFormat(DeviceType device_type,
bool is_quantized_model) {
if (device_type == CPU) {
......@@ -323,7 +359,7 @@ MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context,
VLOG(3) << "Adapt device for op " << op_def->name();
DeviceType target_device_type = target_device->device_type();
DeviceType device_type = DeviceType::CPU;
context->set_device(cpu_device);
context->set_device(target_device);
if (target_device_type != DeviceType::CPU) {
std::vector<DeviceType> producer_devices;
for (auto input : op_def->input()) {
......@@ -344,9 +380,8 @@ MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context,
target_device_type,
available_devices,
producer_devices);
if (device_type == target_device_type) {
context->set_device(target_device);
} else {
if (device_type != target_device_type) {
context->set_device(cpu_device);
LOG(INFO) << "Op " << op_def->name() << "(" << op_def->type() << ")"
<< " fall back to CPU";
}
......@@ -411,7 +446,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
target_mem_type = MemoryType::GPU_BUFFER;
}
auto inputs_data_format = op_registry_->InputsDataFormat(op_def->type(),
context);
context);
DataFormat src_df, dst_df;
int input_size = op_def->input_size();
for (int i = 0; i < input_size; ++i) {
......@@ -425,71 +460,13 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
}
src_df = output_map->at(op_def->input(i)).data_format;
dst_df = inputs_data_format[i];
if (src_df != DataFormat::NONE
&& dst_df != DataFormat::NONE
&& output_map->at(op_def->input(i)).shape.size() == 4
&& src_df != dst_df) {
std::string transformed_name = TransformedName(op_def->input(i),
"data_format", static_cast<int>(dst_df));
if (transformed_set->count(transformed_name) == 0) {
VLOG(1) << "Add Transpose operation " << op_def->name()
<< " to transpose tensor "
<< op_def->input(i) << "', from data format "
<< static_cast<int>(src_df) << " to "
<< static_cast<int>(dst_df);
// Only support transpose between NHWC and NCHW for now.
std::vector<int> dst_dims;
if (src_df == DataFormat::NCHW && dst_df == DataFormat::NHWC) {
dst_dims = {0, 2, 3, 1};
} else if (src_df == DataFormat::NHWC && dst_df == DataFormat::NCHW) {
dst_dims = {0, 3, 1, 2};
} else {
LOG(FATAL) << "Encounter unsupported data format transpose from "
<< static_cast<int>(src_df) << " to "
<< static_cast<int>(dst_df);
}
auto &input_info = output_map->at(op_def->input(i));
auto output_shape = input_info.shape.empty() ?
std::vector<index_t>() :
TransposeShape<index_t, index_t>(input_info.shape,
dst_dims);
OperatorDef *transpose_op_def = target_net_def->add_op();
BuildTransposeOpDef(
op_def->input(i),
transformed_name,
output_shape,
dst_dims,
input_info.dtype,
DeviceType::CPU,
transpose_op_def);
// set data format arg
SetProtoArg<int>(transpose_op_def,
"data_format",
static_cast<int>(dst_df));
// set output memory type argument
SetProtoArg<int>(transpose_op_def,
OutputMemoryTypeTagName(),
target_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map
output_map->emplace(
transformed_name,
InternalOutputInfo(
target_mem_type,
input_info.dtype,
dst_df,
output_shape,
target_net_def->op_size() - 1));
// update tensor shape map
tensor_shape_map->emplace(transformed_name, output_shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
// update original op_def's input
op_def->set_input(i, transformed_name);
const std::vector<int> dst_dims =
GetDstDimsFromTransposeRuler(output_map, op_def, i, src_df, dst_df);
if (dst_dims.size() > 0) {
AddTranposeOpForDataFormat(output_map, tensor_shape_map, transformed_set,
target_net_def, target_mem_type,
op_def, i, dst_df, dst_dims);
}
}
return MaceStatus::MACE_SUCCESS;
......@@ -590,6 +567,86 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
return MaceStatus::MACE_SUCCESS;
}
std::vector<int> NetDefAdapter::GetDstDimsFromTransposeRuler(
TensorInfoMap *output_map, const OperatorDef *op_def, const int input_idx,
const DataFormat src_df, const DataFormat dst_df) {
std::vector<int> dst_dims;
if (src_df == DataFormat::NONE || dst_df == DataFormat::NONE
|| output_map->at(op_def->input(input_idx)).shape.size() != 4) {
return dst_dims;
}
if (src_df != dst_df) { // for other operators
bool transposable = false;
if (src_df == DataFormat::NCHW && dst_df == DataFormat::NHWC) {
dst_dims = {0, 2, 3, 1};
transposable = true;
} else if (src_df == DataFormat::NHWC && dst_df == DataFormat::NCHW) {
dst_dims = {0, 3, 1, 2};
transposable = true;
} else if (dst_df == DataFormat::OIHW) {
static const auto filter_transpose_ruler = GetFilterTransposeRuler();
auto &op_type = op_def->type();
MACE_CHECK((filter_transpose_ruler.count(op_type) > 0) &&
filter_transpose_ruler.at(op_type)->filter_idx == input_idx);
if (src_df == DataFormat::NCHW) {
dst_dims = filter_transpose_ruler.at(op_type)->nchw_oihw;
transposable = true;
} else if (src_df == DataFormat::NHWC) {
dst_dims = filter_transpose_ruler.at(op_type)->nhwc_oihw;
transposable = true;
}
}
if (!transposable) {
LOG(FATAL) << "Encounter unsupported data format transpose from "
<< static_cast<int>(src_df) << " to "
<< static_cast<int>(dst_df);
}
}
return dst_dims;
}
MaceStatus NetDefAdapter::AddTranposeOpForDataFormat(
TensorInfoMap *output_map, TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set, NetDef *target_net_def,
MemoryType target_mem_type, OperatorDef *op_def, const int i,
const DataFormat dst_df, const std::vector<int> &dst_dims) {
std::string transformed_name = TransformedName(
op_def->input(i), "data_format", MakeString(dst_dims));
if (transformed_set->count(transformed_name) == 0) {
auto &input_info = output_map->at(op_def->input(i));
auto output_shape = input_info.shape.empty() ?
std::vector<index_t>() :
TransposeShape<index_t, index_t>(input_info.shape,
dst_dims);
OperatorDef *transpose_op_def = target_net_def->add_op();
BuildTransposeOpDef(op_def->input(i), transformed_name, output_shape,
dst_dims, input_info.dtype, DeviceType::CPU,
transpose_op_def);
// set data format arg
SetProtoArg<int>(transpose_op_def, "data_format", static_cast<int>(dst_df));
// set output memory type argument
SetProtoArg<int>(transpose_op_def,
OutputMemoryTypeTagName(), target_mem_type);
// update tensor consumer information
output_map->at(op_def->input(i)).consumer_op_indices.push_back(
target_net_def->op_size() - 1);
// update output information map
output_map->emplace(transformed_name, InternalOutputInfo(
target_mem_type, input_info.dtype, dst_df, output_shape,
target_net_def->op_size() - 1));
// update tensor shape map
tensor_shape_map->emplace(transformed_name, output_shape);
// record transformed tensors
transformed_set->insert(transformed_name);
}
// update original op_def's input
op_def->set_input(i, transformed_name);
return MaceStatus::MACE_SUCCESS;
}
std::string NetDefAdapter::DebugString(const NetDef *net_def) {
std::stringstream sstream;
auto DeviceTypeToStrFunc = [](DeviceType device_type) -> std::string {
......@@ -630,9 +687,9 @@ std::string NetDefAdapter::DebugString(const NetDef *net_def) {
for (auto &op : net_def->op()) {
std::string device_type = DeviceTypeToStrFunc(
static_cast<DeviceType>(op.device_type()));
std::string data_type = DataTypeToString(static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT))));
auto dt = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "T", static_cast<int>(DT_FLOAT));
std::string data_type = DataTypeToString(static_cast<DataType>(dt));
std::string mem_type = MemoryTypeToStrFunc(
static_cast<MemoryType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
......@@ -645,10 +702,10 @@ std::string NetDefAdapter::DebugString(const NetDef *net_def) {
sstream << std::endl;
sstream << "{" << std::endl;
sstream << " name: " << op.name() << std::endl;
sstream << " type: " << op.type() << std::endl;
sstream << " device: " << device_type << std::endl;
sstream << " data type: " << data_type << std::endl;
sstream << " name: " << op.name() << std::endl;
sstream << " type: " << op.type() << std::endl;
sstream << " device: " << device_type << std::endl;
sstream << " data type: " << data_type << std::endl;
sstream << " data format: " << data_format << std::endl;
sstream << " memory type: " << mem_type << std::endl;
sstream << " inputs: [";
......
......@@ -34,7 +34,6 @@ class OperatorDef;
class OpRegistry;
class Workspace;
///////////////////////////////////////////////////////////////////////////////
/// Conventions
///
......@@ -68,8 +67,8 @@ class NetDefAdapter {
NetDef *target_net_def);
public:
NetDefAdapter(const NetDefAdapter&) = delete;
NetDefAdapter(const NetDefAdapter&&) = delete;
NetDefAdapter(const NetDefAdapter &) = delete;
NetDefAdapter(const NetDefAdapter &&) = delete;
NetDefAdapter &operator=(const NetDefAdapter &) = delete;
NetDefAdapter &operator=(const NetDefAdapter &&) = delete;
......@@ -122,6 +121,15 @@ class NetDefAdapter {
MemoryType *op_output_mem_types,
NetDef *target_net_def);
std::vector<int> GetDstDimsFromTransposeRuler(
TensorInfoMap *output_map, const OperatorDef *op_def, const int input_idx,
const DataFormat src_df, const DataFormat dst_df);
MaceStatus AddTranposeOpForDataFormat(
TensorInfoMap *output_map, TensorShapeMap *tensor_shape_map,
std::unordered_set<std::string> *transformed_set, NetDef *target_net_def,
MemoryType target_mem_type, OperatorDef *op_def, const int i,
const DataFormat dst_df, const std::vector<int> &dst_dims);
std::string DebugString(const NetDef *net_def);
private:
......
......@@ -18,6 +18,9 @@
#include <cmath>
#include <vector>
#include "mace/core/ops/op_condition_builder.h"
#include "mace/core/registry/ops_registry.h"
namespace mace {
namespace ops {
......@@ -435,5 +438,37 @@ void CalDeconvOutputShapeAndPadSize(const std::vector<index_t> &input_shape,
}
}
#ifdef MACE_ENABLE_OPENCL
void SetFilterMemoryType(OpConditionContext *context,
OpenCLBufferType buffer_type) {
MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->device_type() == DeviceType::GPU) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
} else {
mem_type = MemoryType::GPU_BUFFER;
}
auto filter_tensor = context->workspace()->GetTensor(
context->operator_def()->input(1));
if (filter_tensor == nullptr || !filter_tensor->is_weight()) {
context->SetInputOpenCLBufferType(1, buffer_type);
}
}
context->set_output_mem_type(mem_type);
}
#endif // MACE_ENABLE_OPENCL
void RegisterFilterDataFormat(OpRegistry *op_registry, const char *op_name) {
auto builder = OpConditionBuilder(op_name).SetInputsDataFormatSelector(
[](OpConditionContext *context) -> std::vector<DataFormat> {
DataFormat op_data_format = static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return {op_data_format, DataFormat::OIHW, DataFormat::NONE};
});
MACE_REGISTER_OP_CONDITION(op_registry, builder);
}
} // namespace ops
} // namespace mace
......@@ -16,6 +16,10 @@
#define MACE_OPS_COMMON_CONV_POOL_2D_UTIL_H_
#include <vector>
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h"
#endif // MACE_ENABLE_OPENCL
#include "mace/core/tensor.h"
namespace mace {
......@@ -31,6 +35,9 @@ enum RoundType {
CEIL = 1,
};
class OpConditionContext;
class OpRegistry;
namespace ops {
void CalcPaddingAndOutputSize(const index_t *input_shape,
......@@ -98,6 +105,13 @@ void CalDeconvOutputShapeAndPadSize(const std::vector<index_t> &input_shape,
FrameworkType framework_type,
DataFormat data_format);
#ifdef MACE_ENABLE_OPENCL
void SetFilterMemoryType(OpConditionContext *context,
OpenCLBufferType buffer_type);
#endif // MACE_ENABLE_OPENCL
void RegisterFilterDataFormat(OpRegistry *op_registry, const char *op_name);
} // namespace ops
} // namespace mace
......
......@@ -445,33 +445,41 @@ class Conv2dOp<DeviceType::GPU, float> : public ConvPool2dOpBase {
kernel_ = make_unique<opencl::buffer::Conv2dKernel>();
}
// Transform filter tensor to target format
if ((wino_block_size_ == 2 || wino_block_size_ == 4) &&
(kernel_->CheckUseWinograd(
context->device()->gpu_runtime()->opencl_runtime(),
context->workspace()->GetTensor(
operator_def_->input(1))->shape(),
std::vector<index_t>(operator_def_->output_shape(0).dims().begin(),
operator_def_->output_shape(0).dims().end()),
strides_.data(),
dilations_.data(),
&wino_block_size_))) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::WINOGRAD_FILTER, mem_type, wino_block_size_)
== MaceStatus::MACE_SUCCESS);
auto *filter_tensor =
context->workspace()->GetTensor(operator_def_->input(FILTER));
if (filter_tensor != nullptr && filter_tensor->is_weight()) {
if ((wino_block_size_ == 2 || wino_block_size_ == 4) &&
(kernel_->CheckUseWinograd(
context->device()->gpu_runtime()->opencl_runtime(),
filter_tensor->shape(),
std::vector<index_t>(operator_def_->output_shape(0).dims().begin(),
operator_def_->output_shape(0).dims().end()),
strides_.data(),
dilations_.data(),
&wino_block_size_))) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::WINOGRAD_FILTER, mem_type, wino_block_size_)
== MaceStatus::MACE_SUCCESS);
} else {
wino_block_size_ = 0;
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::CONV2D_FILTER, mem_type)
== MaceStatus::MACE_SUCCESS);
}
} else {
// we don't know whether the kernal support winograd, so disable it.
wino_block_size_ = 0;
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::CONV2D_FILTER, mem_type)
== MaceStatus::MACE_SUCCESS);
}
if (operator_def_->input_size() > 2) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 2, OpenCLBufferType::ARGUMENT, mem_type)
== MaceStatus::MACE_SUCCESS);
auto ret = TransformFilter(context, operator_def_.get(), 2,
OpenCLBufferType::ARGUMENT, mem_type);
MACE_CHECK(ret == MaceStatus::MACE_SUCCESS);
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(INPUT);
const Tensor *filter = this->Input(FILTER);
......@@ -506,6 +514,17 @@ void RegisterConv2D(OpRegistry *op_registry) {
#endif // MACE_ENABLE_QUANTIZE
MACE_REGISTER_GPU_OP(op_registry, "Conv2D", Conv2dOp);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Conv2D").SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
SetFilterMemoryType(context, OpenCLBufferType::CONV2D_FILTER);
}));
#endif // MACE_ENABLE_OPENCL
RegisterFilterDataFormat(op_registry, "Conv2D");
}
} // namespace ops
......
......@@ -154,10 +154,14 @@ class Deconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
} else {
MACE_NOT_IMPLEMENTED;
}
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::CONV2D_FILTER, mem_type)
== MaceStatus::MACE_SUCCESS);
auto *filter_tensor =
context->workspace()->GetTensor(operator_def_->input(1));
if (filter_tensor != nullptr && filter_tensor->is_weight()) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::CONV2D_FILTER, mem_type)
== MaceStatus::MACE_SUCCESS);
}
if (model_type_ == FrameworkType::TENSORFLOW) {
if (operator_def_->input_size() >= 4) {
MACE_CHECK(TransformFilter(
......@@ -238,34 +242,28 @@ void RegisterDeconv2D(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "Deconv2D", Deconv2dOp, DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "Deconv2D", Deconv2dOp, DeviceType::CPU);
MACE_REGISTER_GPU_OP(op_registry, "Deconv2D", Deconv2dOp);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Deconv2D")
.SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->device_type() == DeviceType::GPU) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
} else {
MACE_NOT_IMPLEMENTED;
}
context->set_output_mem_type(mem_type);
FrameworkType framework_type =
static_cast<FrameworkType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*(context->operator_def()), "framework_type",
FrameworkType::TENSORFLOW));
if (framework_type == FrameworkType::TENSORFLOW) {
context->SetInputInfo(2, MemoryType::CPU_BUFFER,
DataType::DT_INT32);
}
} else {
context->set_output_mem_type(mem_type);
}
}));
OpConditionBuilder("Deconv2D").SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
SetFilterMemoryType(context, OpenCLBufferType::DW_CONV2D_FILTER);
if (context->device()->device_type() == DeviceType::GPU) {
FrameworkType framework_type =
static_cast<FrameworkType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*(context->operator_def()), "framework_type",
FrameworkType::TENSORFLOW));
if (framework_type == FrameworkType::TENSORFLOW) {
context->SetInputInfo(2, MemoryType::CPU_BUFFER,
DataType::DT_INT32);
}
}
}));
#endif // MACE_ENABLE_OPENCL
RegisterFilterDataFormat(op_registry, "Deconv2D");
}
} // namespace ops
......
......@@ -417,38 +417,13 @@ void RegisterDepthwiseConv2d(OpRegistry *op_registry) {
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("DepthwiseConv2d")
.SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->device_type() == DeviceType::GPU) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
} else {
mem_type = MemoryType::GPU_BUFFER;
}
auto filter_tensor = context->workspace()->GetTensor(
context->operator_def()->input(1));
if (filter_tensor == nullptr || !filter_tensor->is_weight()) {
context->SetInputOpenCLBufferType(
1, OpenCLBufferType::DW_CONV2D_FILTER);
}
}
context->set_output_mem_type(mem_type);
}));
OpConditionBuilder("DepthwiseConv2d").SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
SetFilterMemoryType(context, OpenCLBufferType::DW_CONV2D_FILTER);
}));
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("DepthwiseConv2d")
.SetInputsDataFormatSelector(
[](OpConditionContext *context) -> std::vector<DataFormat> {
DataFormat op_data_format =
static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return {op_data_format, DataFormat::OIHW, DataFormat::NONE};
}));
RegisterFilterDataFormat(op_registry, "DepthwiseConv2d");
}
} // namespace ops
......
......@@ -160,10 +160,13 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, float> : public Deconv2dOpBase {
} else {
MACE_NOT_IMPLEMENTED;
}
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 1,
OpenCLBufferType::DW_CONV2D_FILTER, mem_type)
== MaceStatus::MACE_SUCCESS);
auto *filter_tensor =
context->workspace()->GetTensor(operator_def_->input(1));
if (filter_tensor != nullptr && filter_tensor->is_weight()) {
auto ret = TransformFilter(context, operator_def_.get(), 1,
OpenCLBufferType::DW_CONV2D_FILTER, mem_type);
MACE_CHECK(ret == MaceStatus::MACE_SUCCESS);
}
if (operator_def_->input_size() >= 3) {
MACE_CHECK(TransformFilter(
context, operator_def_.get(), 2,
......@@ -223,6 +226,17 @@ void RegisterDepthwiseDeconv2d(OpRegistry *op_registry) {
DepthwiseDeconv2dOp, DeviceType::CPU);
MACE_REGISTER_GPU_OP(op_registry, "DepthwiseDeconv2d", DepthwiseDeconv2dOp);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("DepthwiseDeconv2d").SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
SetFilterMemoryType(context, OpenCLBufferType::DW_CONV2D_FILTER);
}));
#endif // MACE_ENABLE_OPENCL
RegisterFilterDataFormat(op_registry, "DepthwiseDeconv2d");
}
} // namespace ops
......
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// This Op is for fused StatisticsExtraction and StatisticsPooling
// Components in Kaldi.
// This op is used to extract moving-average mean and standard-deviation
// statistics of input data.
// 'forward_indexes' indicates which frames of input will be used for
// extraction.
// save statistics results.
// 'forward_indexes' and 'count' were from precomputed index in kaldi.
// Reference to tools/extract_pooling.py and
// http://kaldi-asr.org/doc/nnet-general-component_8h_source.html#l00158
#include <functional>
#include <memory>
#include "mace/core/ops/operator.h"
#include "mace/core/registry/ops_registry.h"
#include "mace/core/tensor.h"
#include "mace/ops/conv_pool_2d_base.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/ops/opencl/image/extract_image_patches.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace ops {
template<DeviceType D, class T>
class ExtractImagePatchesOp;
template<class T>
class ExtractImagePatchesOp<DeviceType::CPU, T> : public ConvPool2dOpBase {
public:
explicit ExtractImagePatchesOp(OpConstructContext *context)
: ConvPool2dOpBase(context),
kernels_(Operation::GetRepeatedArgs<int>("kernels")) {}
MaceStatus Run(OpContext *context) override {
const Tensor *input_tensor = this->Input(0);
Tensor *output_tensor = this->Output(0);
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {
input_tensor->dim(1), input_tensor->dim(1), kernels_[0], kernels_[1]};
std::vector<int> paddings(2);
if (paddings_.empty()) {
ops::CalcNCHWPaddingAndOutputSize(
input_tensor->shape().data(), filter_shape.data(), dilations_.data(),
strides_.data(), padding_type_, output_shape.data(), paddings.data());
} else {
paddings = paddings_;
CalcNCHWOutputSize(input_tensor->shape().data(), filter_shape.data(),
paddings_.data(), dilations_.data(), strides_.data(),
RoundType::FLOOR, output_shape.data());
}
output_shape[1] *= kernels_[0] * kernels_[1];
MACE_RETURN_IF_ERROR(output_tensor->Resize(output_shape));
Tensor::MappingGuard input_guard(input_tensor);
Tensor::MappingGuard output_guard(output_tensor);
const T *input = input_tensor->data<T>();
MACE_CHECK(output_tensor->dtype() == DataTypeToEnum<T>::value);
T *output = output_tensor->mutable_data<T>();
const index_t *input_shape = input_tensor->shape().data();
int pad_hw[2] = {paddings[0] / 2, paddings[1] / 2};
return ExtractImagePatches(context, input, input_shape, output_shape.data(),
kernels_.data(), strides_.data(),
dilations_.data(), pad_hw, output);
}
private:
MaceStatus ExtractImagePatches(const OpContext *context,
const T *input,
const index_t *in_shape,
const index_t *out_shape,
const int *filter_hw,
const int *stride_hw,
const int *dilation_hw,
const int *pad_hw,
T *output) {
const index_t batch = out_shape[0];
const index_t out_channels = out_shape[1];
const index_t out_height = out_shape[2];
const index_t out_width = out_shape[3];
const index_t in_channels = in_shape[1];
const index_t in_height = in_shape[2];
const index_t in_width = in_shape[3];
const index_t in_image_size = in_height * in_width;
const index_t out_image_size = out_height * out_width;
const index_t in_batch_size = in_channels * in_image_size;
const index_t out_batch_size = out_channels * out_image_size;
utils::ThreadPool
&thread_pool = context->device()->cpu_runtime()->thread_pool();
thread_pool.Compute2D([=](index_t start0, index_t end0, index_t step0,
index_t start1, index_t end1, index_t step1) {
for (index_t b = start0; b < end0; b += step0) {
for (index_t c = start1; c < end1; c += step1) {
const index_t in_c = c % in_channels;
const index_t filter_idx = c / in_channels;
const index_t out_base = b * out_batch_size + c * out_image_size;
const index_t in_base = b * in_batch_size + in_c * in_image_size;
for (index_t h = 0; h < out_height; ++h) {
for (index_t w = 0; w < out_width; ++w) {
index_t out_offset = out_base + h * out_width + w;
index_t fh = filter_idx / filter_hw[1];
index_t fw = filter_idx % filter_hw[1];
index_t inh = h * stride_hw[0] + dilation_hw[0] * fh - pad_hw[0];
index_t inw = w * stride_hw[1] + dilation_hw[1] * fw - pad_hw[1];
if (inh >= 0 && inh < in_height && inw >= 0 && inw < in_width) {
index_t input_offset = in_base + inh * in_width + inw;
output[out_offset] = input[input_offset];
} else {
output[out_offset] = 0;
}
}
}
}
}
}, 0, batch, 1, 0, out_channels, 1);
return MaceStatus::MACE_SUCCESS;
}
private:
std::vector<int> kernels_;
MACE_OP_INPUT_TAGS(INPUT);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
#ifdef MACE_ENABLE_OPENCL
template<>
class ExtractImagePatchesOp<DeviceType::GPU, float> : public ConvPool2dOpBase {
public:
explicit ExtractImagePatchesOp(OpConstructContext *context)
: ConvPool2dOpBase(context),
kernels_(Operation::GetRepeatedArgs<int>("kernels")) {
if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ExtractImagePatchesKernel>();
} else {
MACE_NOT_IMPLEMENTED;
}
}
MaceStatus Run(OpContext *context) override {
const Tensor *input = this->Input(0);
Tensor *output = this->Output(0);
return kernel_->Compute(context, input, kernels_.data(), strides_.data(),
padding_type_, paddings_, dilations_.data(),
output);
}
private:
std::vector<int> kernels_;
std::unique_ptr<OpenCLExtractImagePatchesKernel> kernel_;
};
#endif // MACE_ENABLE_OPENCL
void RegisterExtractImagePatches(OpRegistry *op_registry) {
MACE_REGISTER_OP(op_registry, "ExtractImagePatches", ExtractImagePatchesOp,
DeviceType::CPU, float);
MACE_REGISTER_BF16_OP(op_registry, "ExtractImagePatches",
ExtractImagePatchesOp, DeviceType::CPU);
MACE_REGISTER_GPU_OP(op_registry, "ExtractImagePatches",
ExtractImagePatchesOp);
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("ExtractImagePatches").SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU, DeviceType::GPU};
}
auto kernels = ProtoArgHelper::GetRepeatedArgs<OperatorDef, int>(
*op, "kernels");
auto &output_shape = op->output_shape(0);
auto &output_dims = output_shape.dims();
auto in_channel = output_dims[3] / kernels[0] / kernels[1];
if (output_shape.dims_size() != 4 || in_channel % 4 != 0) {
return {DeviceType::CPU};
}
#ifdef MACE_ENABLE_OPENCL
if (context->device()->device_type() == DeviceType::GPU) {
auto opencl_runtime =
context->device()->gpu_runtime()->opencl_runtime();
auto max_2d_size = opencl_runtime->GetMaxImage2DSize();
auto image_width = output_dims[2] * output_dims[3] / 4;
if (image_width > static_cast<index_t>(max_2d_size[0])) {
return {DeviceType::CPU};
}
}
#endif // MACE_ENABLE_OPENCL
return {DeviceType::CPU, DeviceType::GPU};
}));
}
} // namespace ops
} // namespace mace
......@@ -85,6 +85,7 @@ __kernel void deconv_2d(OUT_OF_RANGE_PARAMS
for (int f_y = f_start_y, idx_h = start_y ; f_y >= 0; f_y -= stride_h, ++idx_h) {
index_y = mad24(b, in_height, idx_h);
in_pos.y = select(index_y, -1, idx_h < 0 || idx_h >= in_height);
#pragma unroll
for (int f_x = f_start_x, idx_w = start_x; f_x >= 0; f_x -= stride_w, ++idx_w) {
f_pos_y = mad24(f_y, kernel_w, f_x);
f_pos_y = mad24(c, kernel_size, f_pos_y);
......
#include <common.h>
__kernel void extract_image_patches(OUT_OF_RANGE_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int in_height,
__private const int in_width,
__private const int out_height,
__private const int pad_top,
__private const int pad_left,
__private const int stride_h,
__private const int stride_w,
__private const int kernel_h,
__private const int kernel_w,
__write_only image2d_t output) {
const int out_chan_idx = get_global_id(0);
const int out_width_idx = get_global_id(1);
const int out_hb_idx = get_global_id(2);
#ifndef NON_UNIFORM_WORK_GROUP
if (out_chan_idx >= global_size_dim0 || out_width_idx >= global_size_dim1
|| out_hb_idx >= global_size_dim2) {
return;
}
#endif
const int out_width = global_size_dim1;
const int kernel_size = kernel_h * kernel_w;
const int in_channel = global_size_dim0 / kernel_size;
const int n_b = out_hb_idx / out_height;
const int mod_b = out_hb_idx - mul24(n_b, out_height);
const int in_batch_base = mul24(n_b, in_height);
const int in_height_start = mad24(mod_b, stride_h, -pad_top);
const int in_width_start = mad24(out_width_idx, stride_w, -pad_left);
const int in_chan_idx = out_chan_idx % in_channel;
const int in_channel_base = mul24(in_chan_idx, in_width);
const int kernel_base = out_chan_idx / in_channel;
const int kernel_h_idx = kernel_base / kernel_w;
const int kernel_w_idx = kernel_base % kernel_w;
int in_height_idx = in_height_start + kernel_h_idx;
in_height_idx = select(in_batch_base + in_height_idx, -1,
(in_height_idx < 0 || in_height_idx >= in_height));
int in_width_idx = in_width_start + kernel_w_idx;
in_width_idx = select(in_channel_base + in_width_idx, -1,
(in_width_idx < 0 || in_width_idx >= in_width));
const int pos = mad24(out_chan_idx, out_width, out_width_idx);
if (in_height_idx != -1 && in_width_idx != -1) {
DATA_TYPE4 in = READ_IMAGET(input, SAMPLER, (int2)(in_width_idx, in_height_idx));
WRITE_IMAGET(output, (int2)(pos, out_hb_idx), in);
} else {
WRITE_IMAGET(output, (int2)(pos, out_hb_idx), 0);
}
}
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_EXTRACT_IMAGE_PATCHES_H_
#define MACE_OPS_OPENCL_EXTRACT_IMAGE_PATCHES_H_
#include <vector>
#include "mace/ops/common/conv_pool_2d_util.h"
namespace mace {
class OpContext;
class Tensor;
namespace ops {
class OpenCLExtractImagePatchesKernel {
public:
virtual MaceStatus Compute(
OpContext *context,
const Tensor *input,
const int *kernels,
const int *strides,
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
Tensor *output) = 0;
MACE_EMPTY_VIRTUAL_DESTRUCTOR(OpenCLExtractImagePatchesKernel);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_EXTRACT_IMAGE_PATCHES_H_
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/opencl/image/extract_image_patches.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
MaceStatus ExtractImagePatchesKernel::Compute(
OpContext *context,
const Tensor *input,
const int *kernels,
const int *strides,
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
Tensor *output) {
MACE_CHECK(dilations[0] == 1 && dilations[1] == 1,
"ExtractImagePatches opencl kernel not support dilation yet");
MACE_CHECK(input->dim(3) % 4 == 0, "Only support channel % 4 == 0");
std::vector<index_t> output_shape(4);
std::vector<index_t> filter_shape = {input->dim(3), input->dim(3),
kernels[0], kernels[1]};
std::vector<int> paddings(2);
if (padding_data.empty()) {
ops::CalcNHWCPaddingAndOutputSize(
input->shape().data(), filter_shape.data(), dilations, strides,
padding_type, output_shape.data(), paddings.data());
} else {
paddings = padding_data;
CalcOutputSize(input->shape().data(), filter_shape.data(),
padding_data.data(), dilations, strides, RoundType::FLOOR,
output_shape.data());
}
output_shape[3] *= kernels[0] * kernels[1];
std::vector<size_t> output_image_shape;
OpenCLUtil::CalImage2DShape(output_shape, OpenCLBufferType::IN_OUT_CHANNEL,
&output_image_shape);
MACE_RETURN_IF_ERROR(output->ResizeImage(output_shape, output_image_shape));
auto runtime = context->device()->gpu_runtime()->opencl_runtime();
MACE_OUT_OF_RANGE_DEFINITION;
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
MACE_OUT_OF_RANGE_CONFIG;
MACE_NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("extract_image_patches");
built_options.emplace("-Dextract_image_patches=" + kernel_name);
if (input->dtype() == output->dtype()) {
auto data_dt = input->dtype();
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(data_dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(data_dt));
} else {
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DT_FLOAT));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToCLCMDDt(DT_FLOAT));
}
MACE_RETURN_IF_ERROR(runtime->BuildKernel("extract_image_patches",
kernel_name,
built_options,
&kernel_));
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
const uint32_t gws[3] = {
static_cast<uint32_t>(RoundUpDiv4(output->dim(3))),
static_cast<uint32_t>(output->dim(2)),
static_cast<uint32_t>(output->dim(0) * output->dim(1)),
};
MACE_OUT_OF_RANGE_INIT(kernel_);
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
MACE_OUT_OF_RANGE_SET_ARGS(kernel_);
MACE_SET_3D_GWS_ARGS(kernel_, gws);
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(1)));
kernel_.setArg(idx++, static_cast<int32_t>(input->dim(2)));
kernel_.setArg(idx++, static_cast<int32_t>(output->dim(1)));
kernel_.setArg(idx++, paddings[0] / 2);
kernel_.setArg(idx++, paddings[1] / 2);
kernel_.setArg(idx++, strides[0]);
kernel_.setArg(idx++, strides[1]);
kernel_.setArg(idx++, kernels[0]);
kernel_.setArg(idx++, kernels[1]);
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
const std::vector<uint32_t> lws = Default3DLocalWS(runtime, gws, kwg_size_);
std::string tuning_key =
Concat("extract_image_patches_opencl_kernel_", output->dim(0),
output->dim(1), output->dim(2), output->dim(3));
MACE_RETURN_IF_ERROR(TuningOrRun3DKernel(runtime, kernel_, tuning_key,
gws, lws, context->future()));
MACE_OUT_OF_RANGE_VALIDATION;
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
// Copyright 2020 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_OPENCL_IMAGE_EXTRACT_IMAGE_PATCHES_H_
#define MACE_OPS_OPENCL_IMAGE_EXTRACT_IMAGE_PATCHES_H_
#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "mace/core/ops/op_context.h"
#include "mace/core/runtime/opencl/opencl_helper.h"
#include "mace/core/tensor.h"
#include "mace/ops/opencl/extract_image_patches.h"
namespace mace {
namespace ops {
namespace opencl {
namespace image {
class ExtractImagePatchesKernel : public OpenCLExtractImagePatchesKernel {
public:
MaceStatus Compute(
OpContext *context,
const Tensor *input,
const int *kernels,
const int *strides,
const Padding &padding_type,
const std::vector<int> &padding_data,
const int *dilations,
Tensor *output) override;
private:
cl::Kernel kernel_;
uint32_t kwg_size_;
std::vector<index_t> input_shape_;
};
} // namespace image
} // namespace opencl
} // namespace ops
} // namespace mace
#endif // MACE_OPS_OPENCL_IMAGE_EXTRACT_IMAGE_PATCHES_H_
......@@ -56,6 +56,7 @@ extern void RegisterPad(OpRegistry *op_registry);
extern void RegisterPadContext(OpRegistry *op_registry);
extern void RegisterPNorm(OpRegistry *op_registry);
extern void RegisterPooling(OpRegistry *op_registry);
extern void RegisterExtractImagePatches(OpRegistry *op_registry);
extern void RegisterReduce(OpRegistry *op_registry);
extern void RegisterReplaceIndex(OpRegistry *op_registry);
extern void RegisterPriorBox(OpRegistry *op_registry);
......@@ -136,6 +137,7 @@ void RegisterAllOps(OpRegistry *registry) {
ops::RegisterPadContext(registry);
ops::RegisterPNorm(registry);
ops::RegisterPooling(registry);
ops::RegisterExtractImagePatches(registry);
ops::RegisterReduce(registry);
ops::RegisterReplaceIndex(registry);
ops::RegisterPriorBox(registry);
......
......@@ -46,6 +46,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/depthwise_conv2d.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/depthwise_conv2d_buffer.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/eltwise.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/extract_image_patches.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/fully_connected.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/lstmcell.cl"))
unused_var = repository_ctx.path(Label("//:mace/ops/opencl/cl/matmul.cl"))
......
// Copyright 2018 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "mace/ops/common/conv_pool_2d_util.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class ExtractImagePatchesOpTest : public OpsTestBase {};
TEST_F(ExtractImagePatchesOpTest, VALID) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 4, 4, 2},
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check
auto expected =
net.CreateTensor<float>({1, 2, 2, 8},
{0, 16, 1, 17, 4, 20, 5, 21, 2, 18, 3, 19, 6, 22, 7, 23,
8, 24, 9, 25, 12, 28, 13, 29, 10, 26, 11, 27, 14, 30, 15, 31});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ExtractImagePatchesOpTest, SAME) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8});
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check
auto expected = net.CreateTensor<float>(
{1, 2, 2, 4}, {0, 1, 3, 4, 2, 0, 5, 0, 6, 7, 0, 0, 8, 0, 0, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ExtractImagePatchesOpTest, VALID_DILATION) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 4, 4, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {1, 1})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {2, 2})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check
auto expected = net.CreateTensor<float>(
{1, 2, 2, 4}, {0, 2, 8, 10, 1, 3, 9, 11, 4, 6, 12, 14, 5, 7, 13, 15});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
TEST_F(ExtractImagePatchesOpTest, k2x2s2x2) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<DeviceType::CPU, float>(
"Input", {1, 2, 9, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("kernels", {2, 2})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check
auto expected = net.CreateTensor<float>(
{1, 1, 5, 4},
{0, 1, 9, 10, 2, 3, 11, 12, 4, 5, 13, 14, 6, 7, 15, 16, 8, 0, 17, 0});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
namespace {
template <DeviceType D>
void SimpleExtractImagePatches3S2() {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>(
"Input", {1, 3, 9, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Run
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) {
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("Input")
.Output("Output")
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", {2, 2})
.AddIntArg("padding", Padding::VALID)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
net.RunOp(D);
}
// Check
auto expected = net.CreateTensor<float>({1, 1, 4, 9},
{0, 1, 2, 9, 10, 11, 18, 19, 20,
2, 3, 4, 11, 12, 13, 20, 21, 22,
4, 5, 6, 13, 14, 15, 22, 23, 24,
6, 7, 8, 15, 16, 17, 24, 25, 26});
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
} // namespace
TEST_F(ExtractImagePatchesOpTest, CPUSimpleExtractImagePatches3S2) {
SimpleExtractImagePatches3S2<CPU>();
}
namespace {
template <DeviceType D, typename T>
void ExtractImagePatches3S2(const std::vector<index_t> &input_shape,
const std::vector<int> &strides,
Padding padding) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("InputNCHW")
.Output("OutputNCHW")
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef());
// run on cpu
net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output"));
OpDefBuilder("ExtractImagePatches", "ExtractImagePatchesTest")
.Input("Input")
.Output("Output")
.AddIntsArg("kernels", {3, 3})
.AddIntsArg("strides", strides)
.AddIntArg("padding", padding)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
net.RunOp(D);
if (DataTypeToEnum<T>::value == DT_HALF) {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-3,
1e-4);
} else {
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
}
}
} // namespace
TEST_F(ExtractImagePatchesOpTest, OPENCLAlignedExtractImagePatches3S2) {
ExtractImagePatches3S2<GPU, float>({3, 64, 32, 32}, {1, 1}, Padding::VALID);
ExtractImagePatches3S2<GPU, float>({3, 64, 32, 32}, {2, 2}, Padding::VALID);
ExtractImagePatches3S2<GPU, float>({3, 64, 32, 32}, {1, 2}, Padding::VALID);
ExtractImagePatches3S2<GPU, float>({3, 64, 32, 32}, {1, 1}, Padding::SAME);
ExtractImagePatches3S2<GPU, float>({3, 64, 32, 32}, {2, 2}, Padding::SAME);
ExtractImagePatches3S2<GPU, float>({3, 64, 32, 32}, {2, 1}, Padding::SAME);
ExtractImagePatches3S2<GPU, float>({3, 63, 31, 32}, {2, 2}, Padding::VALID);
ExtractImagePatches3S2<GPU, float>({3, 65, 27, 32}, {2, 1}, Padding::SAME);
}
TEST_F(ExtractImagePatchesOpTest, OPENCLHalfAlignedExtractImagePatches3S2) {
ExtractImagePatches3S2<GPU, half>({3, 64, 32, 32}, {1, 1}, Padding::VALID);
ExtractImagePatches3S2<GPU, half>({3, 64, 32, 32}, {2, 2}, Padding::VALID);
ExtractImagePatches3S2<GPU, half>({3, 64, 32, 32}, {1, 2}, Padding::VALID);
ExtractImagePatches3S2<GPU, half>({3, 64, 32, 32}, {1, 1}, Padding::SAME);
ExtractImagePatches3S2<GPU, half>({3, 64, 32, 32}, {2, 2}, Padding::SAME);
ExtractImagePatches3S2<GPU, half>({3, 64, 32, 32}, {2, 1}, Padding::SAME);
ExtractImagePatches3S2<GPU, half>({3, 63, 31, 32}, {2, 2}, Padding::VALID);
ExtractImagePatches3S2<GPU, half>({3, 65, 27, 32}, {2, 1}, Padding::SAME);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -109,6 +109,7 @@ MaceSupportedOps = [
'Dequantize',
'Eltwise',
'ExpandDims',
'ExtractImagePatches',
'ExtractPooling',
'Fill',
'FullyConnected',
......@@ -173,6 +174,7 @@ MaceFixedDataFormatOps = [MaceOp.BatchNorm,
MaceOp.DepthwiseDeconv2d,
MaceOp.FullyConnected,
MaceOp.Pooling,
MaceOp.ExtractImagePatches,
MaceOp.ResizeBicubic,
MaceOp.ResizeBilinear,
MaceOp.ResizeNearestNeighbor,
......
......@@ -72,6 +72,7 @@ TFSupportedOps = [
'Div',
'Equal',
'ExpandDims',
'ExtractImagePatches',
'FakeQuantWithMinMaxVars',
'FakeQuantWithMinMaxArgs',
'Fill',
......@@ -231,6 +232,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Div.name: self.convert_elementwise,
TFOpType.Equal.name: self.convert_elementwise,
TFOpType.ExpandDims.name: self.convert_expand_dims,
TFOpType.ExtractImagePatches.name:
self.convert_extract_image_patches,
TFOpType.FakeQuantWithMinMaxVars.name: self.convert_fake_quantize,
TFOpType.FakeQuantWithMinMaxArgs.name: self.convert_fake_quantize,
TFOpType.Fill.name: self.convert_fill,
......@@ -714,6 +717,27 @@ class TensorflowConverter(base_converter.ConverterInterface):
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(tf_op.get_attr(tf_kernel_str)[1:3])
def convert_extract_image_patches(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.ExtractImagePatches.name
padding_arg = op.arg.add()
padding_arg.name = MaceKeyword.mace_padding_str
padding_arg.i = self.padding_mode[tf_op.get_attr(tf_padding_str)].value
strides_arg = op.arg.add()
strides_arg.name = MaceKeyword.mace_strides_str
strides_arg.ints.extend(tf_op.get_attr(tf_strides_str)[1:3])
dilation_arg = op.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str
dilations = tf_op.get_attr('rates')[1:3]
dilation_arg.ints.extend(dilations)
kernels_arg = op.arg.add()
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(tf_op.get_attr('ksizes')[1:3])
def convert_softmax(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Softmax.name
......
......@@ -1218,6 +1218,7 @@ class Transformer(base_converter.ConverterInterface):
for op in net.op:
if op.type in [MaceOp.Deconv2D.name,
MaceOp.DepthwiseDeconv2d] \
and op.input[1] in self._consts \
and op.input[1] not in transposed_deconv_filter:
filter = self._consts[op.input[1]]
filter_data = np.array(filter.float_data).reshape(
......
......@@ -36,10 +36,11 @@ def execute(cmd, verbose=True):
universal_newlines=True)
if not verbose:
# use p.communicate instead of p.wait to avoid such situation: pipe is filled and the child process is blocked.
# use p.communicate instead of p.wait to avoid such situation:
# pipe is filled and the child process is blocked.
out, err = p.communicate()
if p.returncode != 0:
raise Exception("errorcode: {}".format(p.returncode) )
raise Exception("errorcode: {}".format(p.returncode))
return out
buf = []
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册