提交 95b32c24 编写于 作者: L liuqi

Update the memory type choose logic and polish some code.

1. Change DataFormat from enum to enum class.
上级 74dcd617
...@@ -83,7 +83,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) { ...@@ -83,7 +83,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
} else if (data_format_str == "OIHW") { } else if (data_format_str == "OIHW") {
return DataFormat::OIHW; return DataFormat::OIHW;
} else { } else {
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
......
...@@ -123,14 +123,13 @@ MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true) ...@@ -123,14 +123,13 @@ MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true)
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, float, f) \ MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, float, f) \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, bool, i) \ MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, bool, i) \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int, i) \ MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int, i) \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int64_t, i) \ MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int64_t, i)
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, std::string, s)
MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(OperatorDef) MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(OperatorDef)
MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(NetDef) MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(NetDef)
#undef MACE_SET_OPTIONAL_ARGUMENT_FUNC #undef MACE_SET_OPTIONAL_ARGUMENT_FUNC
std::string OutputMemoryTypeTagName() { const std::string OutputMemoryTypeTagName() {
static const char *kOutputMemTypeArgName = "output_mem_type"; static const char *kOutputMemTypeArgName = "output_mem_type";
return kOutputMemTypeArgName; return kOutputMemTypeArgName;
} }
......
...@@ -65,7 +65,7 @@ void SetProtoArg(NetDef *op_def, ...@@ -65,7 +65,7 @@ void SetProtoArg(NetDef *op_def,
const std::string &arg_name, const std::string &arg_name,
const T&value); const T&value);
std::string OutputMemoryTypeTagName(); const std::string OutputMemoryTypeTagName();
bool IsQuantizedModel(const NetDef &def); bool IsQuantizedModel(const NetDef &def);
......
...@@ -126,7 +126,8 @@ void MemoryOptimizer::Optimize( ...@@ -126,7 +126,8 @@ void MemoryOptimizer::Optimize(
DataFormat data_format = static_cast<DataFormat>( DataFormat data_format = static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op_def, "data_format", DataFormat::DF_NONE)); *op_def, "data_format",
static_cast<int>(DataFormat::NONE)));
int output_size = op_def->output_size(); int output_size = op_def->output_size();
for (int i = 0; i < output_size; ++i) { for (int i = 0; i < output_size; ++i) {
if (i < op_def->output_type_size()) { if (i < op_def->output_type_size()) {
......
...@@ -76,7 +76,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, ...@@ -76,7 +76,7 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
if (target_device_->device_type() == DeviceType::GPU) { if (target_device_->device_type() == DeviceType::GPU) {
// update the map : output_tensor -> Operation // update the map : output_tensor -> MemoryType
MemoryType out_mem_type = MemoryType out_mem_type =
static_cast<MemoryType>( static_cast<MemoryType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
......
...@@ -37,7 +37,7 @@ DataFormat GetDefaultDataFormat(DeviceType device_type, ...@@ -37,7 +37,7 @@ DataFormat GetDefaultDataFormat(DeviceType device_type,
return DataFormat::NHWC; return DataFormat::NHWC;
} else { } else {
LOG(FATAL) << "MACE do not support the device " << device_type; LOG(FATAL) << "MACE do not support the device " << device_type;
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
...@@ -50,19 +50,21 @@ std::string TransformedName(const std::string &input_name, ...@@ -50,19 +50,21 @@ std::string TransformedName(const std::string &input_name,
return ss.str(); return ss.str();
} }
#ifdef MACE_ENABLE_OPENCL
bool TransformRequiredOp(const std::string &op_type) { bool TransformRequiredOp(const std::string &op_type) {
static const std::unordered_set<std::string> kNoTransformOp = { static const std::unordered_set<std::string> kNoTransformOp = {
"Shape", "InferConv2dShape" "Shape", "InferConv2dShape"
}; };
return kNoTransformOp.count(op_type) == 0; return kNoTransformOp.count(op_type) == 0;
} }
#endif // MACE_ENABLE_OPENCL
void BuildTransposeOpDef( void BuildTransposeOpDef(
const std::string &input_name, const std::string &input_name,
const std::string &output_name, const std::string &output_name,
const std::vector<mace::index_t> &output_shape, const std::vector<index_t> &output_shape,
const std::vector<int> dst_dims, const std::vector<int> dst_dims,
const mace::DataType dt, const DataType dt,
DeviceType device_type, DeviceType device_type,
OperatorDef *op_def) { OperatorDef *op_def) {
std::string op_name = "mace_node_" + output_name; std::string op_name = "mace_node_" + output_name;
...@@ -89,21 +91,13 @@ void BuildTransposeOpDef( ...@@ -89,21 +91,13 @@ void BuildTransposeOpDef(
} // namespace } // namespace
NetDefAdapter::NetDefAdapter(const mace::OpRegistryBase *op_registry, NetDefAdapter::NetDefAdapter(const OpRegistryBase *op_registry,
const mace::Workspace *ws) const Workspace *ws)
: op_registry_(op_registry), ws_(ws) {} : op_registry_(op_registry), ws_(ws) {}
// Adapt original net_def to a better net.
// 1. Adapt device: choose best device for every op in the net.
// 2. Adapt data type: Add data type related transform ops
// for mixing precision.
// 3. Adapt data format: confirm data format of every op
// and add transpose if necessary.
// 4. Adapt memory type: Add BufferTransform if necessary
// for transforming memory type between ops.
MaceStatus NetDefAdapter::AdaptNetDef( MaceStatus NetDefAdapter::AdaptNetDef(
const mace::NetDef *net_def, const NetDef *net_def,
mace::Device *target_device, Device *target_device,
NetDef *target_net_def) { NetDef *target_net_def) {
MACE_LATENCY_LOGGER(1, "Adapting original NetDef"); MACE_LATENCY_LOGGER(1, "Adapting original NetDef");
// Copy from original op_def, leave ops alone. // Copy from original op_def, leave ops alone.
...@@ -115,7 +109,7 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -115,7 +109,7 @@ MaceStatus NetDefAdapter::AdaptNetDef(
std::unique_ptr<CPUDevice> cpu_device = make_unique<CPUDevice>( std::unique_ptr<CPUDevice> cpu_device = make_unique<CPUDevice>(
target_device->cpu_runtime()->num_threads(), target_device->cpu_runtime()->num_threads(),
target_device->cpu_runtime()->policy(), target_device->cpu_runtime()->policy(),
target_device->cpu_runtime()->use_gemmlowp()); &(target_device->cpu_runtime()->thread_pool()));
// quantize model flag // quantize model flag
bool is_quantized_model = IsQuantizedModel(*net_def); bool is_quantized_model = IsQuantizedModel(*net_def);
...@@ -131,40 +125,40 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -131,40 +125,40 @@ MaceStatus NetDefAdapter::AdaptNetDef(
std::vector<index_t>(tensor.dims().begin(), tensor.dims().end()); std::vector<index_t>(tensor.dims().begin(), tensor.dims().end());
} }
MemoryType mem_type = MemoryType::CPU_BUFFER;
if (target_device->device_type() == DeviceType::CPU) {
mem_type = MemoryType::CPU_BUFFER;
} else if (target_device->device_type() == DeviceType::GPU) {
mem_type = MemoryType::GPU_BUFFER;
} else {
LOG(FATAL) << "MACE do not support the device type: "
<< target_device->device_type();
}
int input_size = target_net_def->input_info_size(); int input_size = target_net_def->input_info_size();
for (int i = 0; i < input_size; ++i) { for (int i = 0; i < input_size; ++i) {
auto input_info = target_net_def->mutable_input_info(i); auto input_info = target_net_def->mutable_input_info(i);
MemoryType mem_type = MemoryType::CPU_BUFFER; auto input_data_format = static_cast<DataFormat>(
if (target_device->device_type() == DeviceType::CPU) {
mem_type = MemoryType::CPU_BUFFER;
} else if (target_device->device_type() == DeviceType::GPU) {
mem_type = MemoryType::GPU_BUFFER;
} else {
LOG(FATAL) << "MACE do not support the device type: "
<< target_device->device_type();
}
DataFormat input_data_format = static_cast<DataFormat>(
input_info->data_format()); input_info->data_format());
DataFormat expected_data_format = GetDefaultDataFormat( DataFormat expected_data_format = GetDefaultDataFormat(
target_device->device_type(), is_quantized_model); target_device->device_type(), is_quantized_model);
std::vector<index_t> input_shape = std::vector<index_t> input_shape(input_info->dims().begin(),
std::vector<index_t>(input_info->dims().begin(), input_info->dims().end());
input_info->dims().end()); if (input_data_format != DataFormat::NONE
if (input_data_format != DataFormat::DF_NONE
&& input_data_format != expected_data_format && input_data_format != expected_data_format
&& input_shape.size() == 4) { && input_shape.size() == 4) {
if (input_data_format == DataFormat::NHWC if (input_data_format == DataFormat::NHWC
&& expected_data_format == DataFormat::NCHW) { && expected_data_format == DataFormat::NCHW) {
std::vector<int> dst_dims = {0, 3, 1, 2}; std::vector<int> dst_dims{0, 3, 1, 2};
input_data_format = DataFormat::NCHW; input_data_format = DataFormat::NCHW;
input_shape = TransposeShape<index_t, index_t>(input_shape, dst_dims); input_shape = TransposeShape<index_t, index_t>(input_shape, dst_dims);
} else if (input_data_format == DataFormat::NCHW } else if (input_data_format == DataFormat::NCHW
&& expected_data_format == DataFormat::NHWC) { && expected_data_format == DataFormat::NHWC) {
std::vector<int> dst_dims = {0, 2, 3, 1}; std::vector<int> dst_dims{0, 2, 3, 1};
input_data_format = DataFormat::NHWC; input_data_format = DataFormat::NHWC;
input_shape = TransposeShape<index_t, index_t>(input_shape, dst_dims); input_shape = TransposeShape<index_t, index_t>(input_shape, dst_dims);
} }
input_info->set_data_format(input_data_format); input_info->set_data_format(static_cast<int>(input_data_format));
int input_shape_size = input_shape.size(); int input_shape_size = input_shape.size();
for (int j = 0; j < input_shape_size; ++j) { for (int j = 0; j < input_shape_size; ++j) {
input_info->set_dims(j, input_shape[j]); input_info->set_dims(j, input_shape[j]);
...@@ -287,9 +281,10 @@ MaceStatus NetDefAdapter::AdaptNetDef( ...@@ -287,9 +281,10 @@ MaceStatus NetDefAdapter::AdaptNetDef(
internal_output_info.data_format, internal_output_info.data_format,
transformed_op_def); transformed_op_def);
// set data format arg // set data format arg
SetProtoArg<int>(transformed_op_def, SetProtoArg<int>(
"data_format", transformed_op_def,
internal_output_info.data_format); "data_format",
static_cast<int>(internal_output_info.data_format));
// set output memory type argument // set output memory type argument
SetProtoArg<int>(transformed_op_def, SetProtoArg<int>(transformed_op_def,
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
...@@ -309,7 +304,7 @@ MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context, ...@@ -309,7 +304,7 @@ MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context,
const TensorInfoMap &output_map, const TensorInfoMap &output_map,
const NetDef *net_def, const NetDef *net_def,
OperatorDef *op_def) { OperatorDef *op_def) {
VLOG(1) << "Adapt device for op " << op_def->name(); VLOG(3) << "Adapt device for op " << op_def->name();
DeviceType target_device_type = target_device->device_type(); DeviceType target_device_type = target_device->device_type();
DeviceType device_type = DeviceType::CPU; DeviceType device_type = DeviceType::CPU;
context->set_device(cpu_device); context->set_device(cpu_device);
...@@ -335,15 +330,18 @@ MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context, ...@@ -335,15 +330,18 @@ MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context,
producer_devices); producer_devices);
if (device_type == target_device_type) { if (device_type == target_device_type) {
context->set_device(target_device); context->set_device(target_device);
} else {
LOG(INFO) << "Op " << op_def->name() << " fall back to CPU";
} }
} }
op_def->set_device_type(device_type); op_def->set_device_type(device_type);
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
MaceStatus NetDefAdapter::AdaptDataType(mace::OpConditionContext *context, MaceStatus NetDefAdapter::AdaptDataType(OpConditionContext *context,
mace::OperatorDef *op_def) { OperatorDef *op_def) {
MACE_UNUSED(context); MACE_UNUSED(context);
// Where to add logic to support mixing precision
// Adjust data type of op ran on CPU // Adjust data type of op ran on CPU
DataType dtype = static_cast<DataType>( DataType dtype = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
...@@ -355,20 +353,20 @@ MaceStatus NetDefAdapter::AdaptDataType(mace::OpConditionContext *context, ...@@ -355,20 +353,20 @@ MaceStatus NetDefAdapter::AdaptDataType(mace::OpConditionContext *context,
} }
MaceStatus NetDefAdapter::AdaptDataFormat( MaceStatus NetDefAdapter::AdaptDataFormat(
mace::OpConditionContext *context, OpConditionContext *context,
mace::OperatorDef *op_def, OperatorDef *op_def,
bool is_quantized_model, bool is_quantized_model,
TensorInfoMap *output_map, TensorInfoMap *output_map,
std::unordered_set<std::string> *transformed_set, std::unordered_set<std::string> *transformed_set,
DataFormat *op_output_df, DataFormat *op_output_df,
mace::NetDef *target_net_def) { NetDef *target_net_def) {
VLOG(1) << "Adapt data format for op " << op_def->name(); VLOG(3) << "Adapt data format for op " << op_def->name();
MACE_UNUSED(context);
DataFormat op_data_format = DataFormat op_data_format =
static_cast<DataFormat>(ProtoArgHelper::GetOptionalArg<OperatorDef, int>( static_cast<DataFormat>(ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op_def, "data_format", 0)); *op_def, "data_format",
static_cast<int>(DataFormat::NONE)));
// adjust the data format of operation // adjust the data format of operation
if (op_data_format == DataFormat::DF_AUTO) { if (op_data_format == DataFormat::AUTO) {
op_data_format = GetDefaultDataFormat( op_data_format = GetDefaultDataFormat(
static_cast<DeviceType>(op_def->device_type()), is_quantized_model); static_cast<DeviceType>(op_def->device_type()), is_quantized_model);
SetProtoArg<int>(op_def, "data_format", static_cast<int>(op_data_format)); SetProtoArg<int>(op_def, "data_format", static_cast<int>(op_data_format));
...@@ -376,14 +374,15 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -376,14 +374,15 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
int output_shape_size = op_def->output_shape_size(); int output_shape_size = op_def->output_shape_size();
for (int i = 0; i < output_shape_size; ++i) { for (int i = 0; i < output_shape_size; ++i) {
auto output_shape = op_def->mutable_output_shape(i); auto output_shape = op_def->mutable_output_shape(i);
if (output_shape->dims_size() == 4) { MACE_CHECK(output_shape->dims_size() == 4,
// transpose output shape format from NHWC to NCHW "Output shape should be 4D if the of has data format. ",
int64_t height = output_shape->dims(1); op_def->name());
int64_t width = output_shape->dims(2); // transpose output shape format from NHWC to NCHW
output_shape->set_dims(1, output_shape->dims(3)); int64_t height = output_shape->dims(1);
output_shape->set_dims(2, height); int64_t width = output_shape->dims(2);
output_shape->set_dims(3, width); output_shape->set_dims(1, output_shape->dims(3));
} output_shape->set_dims(2, height);
output_shape->set_dims(3, width);
} }
} }
} }
...@@ -394,8 +393,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -394,8 +393,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
if (op_def->device_type() == DeviceType::GPU) { if (op_def->device_type() == DeviceType::GPU) {
target_mem_type = MemoryType::GPU_BUFFER; target_mem_type = MemoryType::GPU_BUFFER;
} }
// Use op's data format as inputs' data format for now. auto inputs_data_format = op_registry_->InputsDataFormat(op_def->type(),
// Could move the logic to OpRegistry if necessary. context);
DataFormat src_df, dst_df; DataFormat src_df, dst_df;
int input_size = op_def->input_size(); int input_size = op_def->input_size();
for (int i = 0; i < input_size; ++i) { for (int i = 0; i < input_size; ++i) {
...@@ -408,20 +407,21 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -408,20 +407,21 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
continue; continue;
} }
src_df = output_map->at(op_def->input(i)).data_format; src_df = output_map->at(op_def->input(i)).data_format;
dst_df = op_data_format; dst_df = inputs_data_format[i];
if (src_df == DataFormat::DF_NONE if (src_df == DataFormat::NONE
|| dst_df == DataFormat::DF_NONE || dst_df == DataFormat::NONE
|| output_map->at(op_def->input(i)).shape.size() != 4) { || output_map->at(op_def->input(i)).shape.size() != 4) {
continue; continue;
} }
if (src_df != dst_df) { if (src_df != dst_df) {
std::string transformed_name = TransformedName(op_def->input(i), std::string transformed_name = TransformedName(op_def->input(i),
"data_format", dst_df); "data_format", static_cast<int>(dst_df));
if (transformed_set->count(transformed_name) == 0) { if (transformed_set->count(transformed_name) == 0) {
VLOG(1) << "Add Transpose operation " << op_def->name() VLOG(1) << "Add Transpose operation " << op_def->name()
<< " to transpose tensor " << " to transpose tensor "
<< op_def->input(i) << "', from data format " << op_def->input(i) << "', from data format "
<< src_df << " to " << dst_df; << static_cast<int>(src_df) << " to "
<< static_cast<int>(dst_df);
// Only support transpose between NHWC and NCHW for now. // Only support transpose between NHWC and NCHW for now.
std::vector<int> dst_dims; std::vector<int> dst_dims;
if (src_df == DataFormat::NCHW && dst_df == DataFormat::NHWC) { if (src_df == DataFormat::NCHW && dst_df == DataFormat::NHWC) {
...@@ -430,7 +430,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -430,7 +430,8 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
dst_dims = {0, 3, 1, 2}; dst_dims = {0, 3, 1, 2};
} else { } else {
LOG(FATAL) << "Encounter unsupported data format transpose from " LOG(FATAL) << "Encounter unsupported data format transpose from "
<< src_df << " to " << dst_df; << static_cast<int>(src_df) << " to "
<< static_cast<int>(dst_df);
} }
auto &input_info = output_map->at(op_def->input(i)); auto &input_info = output_map->at(op_def->input(i));
auto output_shape = input_info.shape.empty() ? auto output_shape = input_info.shape.empty() ?
...@@ -449,7 +450,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -449,7 +450,7 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
// set data format arg // set data format arg
SetProtoArg<int>(transpose_op_def, SetProtoArg<int>(transpose_op_def,
"data_format", "data_format",
dst_df); static_cast<int>(dst_df));
// set output memory type argument // set output memory type argument
SetProtoArg<int>(transpose_op_def, SetProtoArg<int>(transpose_op_def,
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
...@@ -475,20 +476,20 @@ MaceStatus NetDefAdapter::AdaptDataFormat( ...@@ -475,20 +476,20 @@ MaceStatus NetDefAdapter::AdaptDataFormat(
} }
MaceStatus NetDefAdapter::AdaptMemoryType( MaceStatus NetDefAdapter::AdaptMemoryType(
mace::OpConditionContext *context, OpConditionContext *context,
mace::OperatorDef *op_def, OperatorDef *op_def,
mace::NetDefAdapter::TensorInfoMap *output_map, NetDefAdapter::TensorInfoMap *output_map,
std::unordered_set<std::string> *transformed_set, std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types, MemoryType *op_output_mem_types,
mace::NetDef *target_net_def) { NetDef *target_net_def) {
VLOG(1) << "Adapt memory type for op " << op_def->name(); VLOG(3) << "Adapt memory type for op " << op_def->name();
// Get expected output memory type // Get expected output memory type
// (only support one kind of memory type for multiple outputs) // (only support one kind of memory type for multiple outputs)
op_registry_->GetInOutMemoryTypes(op_def->type(), context); op_registry_->GetInOutMemoryTypes(op_def->type(), context);
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
int input_size = op_def->input_size();
// if op is memory-unused op, no transformation // if op is memory-unused op, no transformation
if (TransformRequiredOp(op_def->type())) { if (TransformRequiredOp(op_def->type())) {
int input_size = op_def->input_size();
for (int i = 0; i < input_size; ++i) { for (int i = 0; i < input_size; ++i) {
if (output_map->count(op_def->input(i)) == 0) { if (output_map->count(op_def->input(i)) == 0) {
MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr
...@@ -498,14 +499,14 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -498,14 +499,14 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
continue; continue;
} }
auto &input_info = output_map->at(op_def->input(i)); auto &input_info = output_map->at(op_def->input(i));
if (input_info.data_format == DataFormat::DF_NONE
|| input_info.shape.size() != 4) {
continue;
}
// check whether to do transform // check whether to do transform
MemoryType src_mem_type = input_info.mem_type; MemoryType src_mem_type = input_info.mem_type;
MemoryType dst_mem_type = context->GetInputMemType(i); MemoryType dst_mem_type = context->GetInputMemType(i);
if (src_mem_type != dst_mem_type) { auto wanted_input_dtype = context->GetInputDataType(i);
if (src_mem_type != dst_mem_type ||
(input_info.dtype != wanted_input_dtype &&
(src_mem_type != MemoryType::CPU_BUFFER
|| dst_mem_type != MemoryType::CPU_BUFFER))) {
auto transformed_name = TransformedName(op_def->input(i), auto transformed_name = TransformedName(op_def->input(i),
"mem_type", "mem_type",
dst_mem_type); dst_mem_type);
...@@ -521,7 +522,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -521,7 +522,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
op_def->input(i), op_def->input(i),
input_info.shape, input_info.shape,
transformed_name, transformed_name,
context->GetInputDataType(i), wanted_input_dtype,
context->GetInputOpenCLBufferType(i), context->GetInputOpenCLBufferType(i),
dst_mem_type, dst_mem_type,
input_info.data_format, input_info.data_format,
...@@ -529,7 +530,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -529,7 +530,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
// set data format arg // set data format arg
SetProtoArg<int>(transformed_op_def, SetProtoArg<int>(transformed_op_def,
"data_format", "data_format",
input_info.data_format); static_cast<int>(input_info.data_format));
// set output memory type argument // set output memory type argument
SetProtoArg<int>(transformed_op_def, SetProtoArg<int>(transformed_op_def,
OutputMemoryTypeTagName(), OutputMemoryTypeTagName(),
...@@ -564,7 +565,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType( ...@@ -564,7 +565,7 @@ MaceStatus NetDefAdapter::AdaptMemoryType(
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
std::string NetDefAdapter::DebugString(const mace::NetDef *net_def) { std::string NetDefAdapter::DebugString(const NetDef *net_def) {
std::stringstream sstream; std::stringstream sstream;
auto DeviceTypeToStrFunc = [](DeviceType device_type) -> std::string { auto DeviceTypeToStrFunc = [](DeviceType device_type) -> std::string {
if (device_type == DeviceType::CPU) { if (device_type == DeviceType::CPU) {
...@@ -591,10 +592,10 @@ std::string NetDefAdapter::DebugString(const mace::NetDef *net_def) { ...@@ -591,10 +592,10 @@ std::string NetDefAdapter::DebugString(const mace::NetDef *net_def) {
return "NHWC"; return "NHWC";
} else if (type == DataFormat::NCHW) { } else if (type == DataFormat::NCHW) {
return "NCHW"; return "NCHW";
} else if (type == DataFormat::DF_NONE) { } else if (type == DataFormat::NONE) {
return "DF_NONE"; return "NONE";
} else if (type == DataFormat::DF_AUTO) { } else if (type == DataFormat::AUTO) {
return "DT_AUTO"; return "AUTO";
} else if (type == DataFormat::OIHW) { } else if (type == DataFormat::OIHW) {
return "OIHW"; return "OIHW";
} else { } else {
...@@ -615,7 +616,7 @@ std::string NetDefAdapter::DebugString(const mace::NetDef *net_def) { ...@@ -615,7 +616,7 @@ std::string NetDefAdapter::DebugString(const mace::NetDef *net_def) {
std::string data_format = DataFormatToStrFunc( std::string data_format = DataFormatToStrFunc(
static_cast<DataFormat>( static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "data_format", 0))); op, "data_format", static_cast<int>(DataFormat::NONE))));
sstream << std::endl; sstream << std::endl;
sstream << "{" << std::endl; sstream << "{" << std::endl;
......
...@@ -32,16 +32,22 @@ class OpRegistryBase; ...@@ -32,16 +32,22 @@ class OpRegistryBase;
class Workspace; class Workspace;
class Device; class Device;
/** /// Conventions:
* Conventions: /// 1. DataFormat::AUTO stands for formatted (NHWC or NCHW)
* 1. DataFormat::DT_AUTO stands for formatted (NHWC or NCHW) /// 2. if Op with DataFormat::AUTO, the arguments of this op
* 2. if Op with DataFormat::DT_AUTO, the arguments of this op /// is formatted to NHWC
* is formatted to NHWC
*/
class NetDefAdapter { class NetDefAdapter {
public: public:
NetDefAdapter(const OpRegistryBase *op_registry, NetDefAdapter(const OpRegistryBase *op_registry,
const Workspace *ws); const Workspace *ws);
// Adapt original net_def to a better net.
// 1. Adapt device: choose best device for every op in the net.
// 2. Adapt data type: Add data type related transform ops
// for mixing precision.
// 3. Adapt data format: confirm data format of every op
// and add transpose if necessary.
// 4. Adapt memory type: Add BufferTransform if necessary
// for transforming memory type between ops.
MaceStatus AdaptNetDef( MaceStatus AdaptNetDef(
const NetDef *net_def, const NetDef *net_def,
Device *target_device, Device *target_device,
...@@ -91,12 +97,12 @@ class NetDefAdapter { ...@@ -91,12 +97,12 @@ class NetDefAdapter {
NetDef *target_net_def); NetDef *target_net_def);
MaceStatus AdaptMemoryType( MaceStatus AdaptMemoryType(
mace::OpConditionContext *context, OpConditionContext *context,
mace::OperatorDef *op_def, OperatorDef *op_def,
TensorInfoMap *output_map, TensorInfoMap *output_map,
std::unordered_set<std::string> *transformed_set, std::unordered_set<std::string> *transformed_set,
MemoryType *op_output_mem_types, MemoryType *op_output_mem_types,
mace::NetDef *target_net_def); NetDef *target_net_def);
std::string DebugString(const NetDef *net_def); std::string DebugString(const NetDef *net_def);
......
...@@ -19,10 +19,10 @@ ...@@ -19,10 +19,10 @@
namespace mace { namespace mace {
DeviceType NetOptimizer::SelectBestDevice( DeviceType NetOptimizer::SelectBestDevice(
const mace::OperatorDef *op_def, const OperatorDef *op_def,
DeviceType target_device_type, DeviceType target_device_type,
const std::set<mace::DeviceType> &available_devices, const std::set<DeviceType> &available_devices,
const std::vector<mace::DeviceType> &inputs_op_devices) { const std::vector<DeviceType> &inputs_op_devices) {
static const std::set<std::string> kComputeIntensiveOps = { static const std::set<std::string> kComputeIntensiveOps = {
"Conv2D", "DepthwiseConv2d", "Deconv2D", "DepthwiseDeconv2d", "Conv2D", "DepthwiseConv2d", "Deconv2D", "DepthwiseDeconv2d",
"FullyConnected" "FullyConnected"
......
...@@ -23,8 +23,21 @@ ...@@ -23,8 +23,21 @@
namespace mace { namespace mace {
/// Any optimization for Net could be put in here in the future.
class NetOptimizer { class NetOptimizer {
public: public:
/// Select best device for the op to support mixing usage of CPU and GPU.
/// Greedy strategy: one way to the end. If the op fallback to CPU, then
/// the follow-up ops will run on CPU too util meet
/// some compute-intensive ops(Convolution) to
/// reduce the memory copy between CPU and GPU.
/// Simple but effective.
///
/// \param op_def the op
/// \param target_device target device to run on
/// \param available_devices available devices of the op
/// \param inputs_op_devices devices of father ops run on
/// \return Best device for the op_def
DeviceType SelectBestDevice(const OperatorDef *op_def, DeviceType SelectBestDevice(const OperatorDef *op_def,
DeviceType target_device, DeviceType target_device,
const std::set<DeviceType> &available_devices, const std::set<DeviceType> &available_devices,
......
...@@ -21,22 +21,22 @@ ...@@ -21,22 +21,22 @@
namespace mace { namespace mace {
OpConditionContext::OpConditionContext( OpConditionContext::OpConditionContext(
const mace::Workspace *ws, const Workspace *ws,
mace::OpConditionContext::TensorShapeMap *info) OpConditionContext::TensorShapeMap *info)
: operator_def_(nullptr), : operator_def_(nullptr),
ws_(ws), ws_(ws),
device_(nullptr), device_(nullptr),
tensor_shape_info_(info) {} tensor_shape_info_(info) {}
void OpConditionContext::set_operator_def( void OpConditionContext::set_operator_def(
const mace::OperatorDef *operator_def) { const OperatorDef *operator_def) {
operator_def_ = operator_def; operator_def_ = operator_def;
input_data_types_.clear(); input_data_types_.clear();
} }
void OpConditionContext::SetInputInfo(size_t idx, void OpConditionContext::SetInputInfo(size_t idx,
mace::MemoryType mem_type, MemoryType mem_type,
mace::DataType dt) { DataType dt) {
if (input_mem_types_.empty()) { if (input_mem_types_.empty()) {
// the default inputs' memory types are same as output memory type. // the default inputs' memory types are same as output memory type.
input_mem_types_.resize(operator_def_->input_size(), output_mem_type_); input_mem_types_.resize(operator_def_->input_size(), output_mem_type_);
...@@ -53,7 +53,7 @@ void OpConditionContext::SetInputInfo(size_t idx, ...@@ -53,7 +53,7 @@ void OpConditionContext::SetInputInfo(size_t idx,
input_data_types_[idx] = dt; input_data_types_[idx] = dt;
} }
void OpConditionContext::set_output_mem_type(mace::MemoryType type) { void OpConditionContext::set_output_mem_type(MemoryType type) {
MACE_CHECK(operator_def_ != nullptr); MACE_CHECK(operator_def_ != nullptr);
output_mem_type_ = type; output_mem_type_ = type;
input_mem_types_.clear(); input_mem_types_.clear();
...@@ -106,7 +106,7 @@ OpConstructContext::OpConstructContext(Workspace *ws) ...@@ -106,7 +106,7 @@ OpConstructContext::OpConstructContext(Workspace *ws)
device_(nullptr) {} device_(nullptr) {}
void OpConstructContext::set_operator_def( void OpConstructContext::set_operator_def(
std::shared_ptr<mace::OperatorDef> operator_def) { std::shared_ptr<OperatorDef> operator_def) {
operator_def_ = operator_def; operator_def_ = operator_def;
} }
...@@ -225,9 +225,20 @@ OpRegistrationInfo::OpRegistrationInfo() { ...@@ -225,9 +225,20 @@ OpRegistrationInfo::OpRegistrationInfo() {
context->set_output_mem_type(MemoryType::CPU_BUFFER); context->set_output_mem_type(MemoryType::CPU_BUFFER);
} }
}; };
data_format_selector = [](OpConditionContext *context)
-> std::vector<DataFormat> {
DataFormat op_data_format =
static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return std::vector<DataFormat>(context->operator_def()->input_size(),
op_data_format);
};
} }
void OpRegistrationInfo::AddDevice(mace::DeviceType device) { void OpRegistrationInfo::AddDevice(DeviceType device) {
devices.insert(device); devices.insert(device);
} }
...@@ -239,9 +250,9 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) { ...@@ -239,9 +250,9 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) {
MaceStatus OpRegistryBase::Register( MaceStatus OpRegistryBase::Register(
const std::string &op_type, const std::string &op_type,
const mace::DeviceType device_type, const DeviceType device_type,
const mace::DataType dt, const DataType dt,
mace::OpRegistrationInfo::OpCreator creator) { OpRegistrationInfo::OpCreator creator) {
if (registry_.count(op_type) == 0) { if (registry_.count(op_type) == 0) {
registry_[op_type] = std::unique_ptr<OpRegistrationInfo>( registry_[op_type] = std::unique_ptr<OpRegistrationInfo>(
new OpRegistrationInfo); new OpRegistrationInfo);
...@@ -277,12 +288,20 @@ const std::set<DeviceType> OpRegistryBase::AvailableDevices( ...@@ -277,12 +288,20 @@ const std::set<DeviceType> OpRegistryBase::AvailableDevices(
void OpRegistryBase::GetInOutMemoryTypes( void OpRegistryBase::GetInOutMemoryTypes(
const std::string &op_type, const std::string &op_type,
mace::OpConditionContext *context) const { OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered."); op_type, " operation is not registered.");
return registry_.at(op_type)->memory_type_setter(context); return registry_.at(op_type)->memory_type_setter(context);
} }
const std::vector<DataFormat> OpRegistryBase::InputsDataFormat(
const std::string &op_type,
OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
return registry_.at(op_type)->data_format_selector(context);
}
std::unique_ptr<Operation> OpRegistryBase::CreateOperation( std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
OpConstructContext *context, OpConstructContext *context,
DeviceType device_type) const { DeviceType device_type) const {
...@@ -321,11 +340,17 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( ...@@ -321,11 +340,17 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
} }
OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter( OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter(
mace::OpRegistrationInfo::MemoryTypeSetter setter) { OpRegistrationInfo::MemoryTypeSetter setter) {
memory_type_setter_ = setter; memory_type_setter_ = setter;
return *this; return *this;
} }
OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector) {
data_format_selector_ = selector;
return *this;
}
void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const { void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
if (info != nullptr) { if (info != nullptr) {
if (placer_) { if (placer_) {
...@@ -334,6 +359,10 @@ void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const { ...@@ -334,6 +359,10 @@ void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
if (memory_type_setter_) { if (memory_type_setter_) {
info->memory_type_setter = memory_type_setter_; info->memory_type_setter = memory_type_setter_;
} }
if (data_format_selector_) {
info->data_format_selector = data_format_selector_;
}
} }
} }
......
...@@ -117,6 +117,14 @@ class OpConstructContext { ...@@ -117,6 +117,14 @@ class OpConstructContext {
inline Device *device() const { inline Device *device() const {
return device_; return device_;
} }
#ifdef MACE_ENABLE_OPENCL
inline MemoryType GetOpMemoryType() const {
return static_cast<MemoryType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, OutputMemoryTypeTagName(),
static_cast<int>(MemoryType::CPU_BUFFER)));
}
#endif // MACE_ENABLE_OPENCL
private: private:
std::shared_ptr<OperatorDef> operator_def_; std::shared_ptr<OperatorDef> operator_def_;
...@@ -270,6 +278,9 @@ class OpConditionBuilder { ...@@ -270,6 +278,9 @@ class OpConditionBuilder {
OpConditionBuilder &SetInputMemoryTypeSetter( OpConditionBuilder &SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter); OpRegistrationInfo::MemoryTypeSetter setter);
OpConditionBuilder &SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector);
void Finalize(OpRegistrationInfo *info) const; void Finalize(OpRegistrationInfo *info) const;
private: private:
...@@ -297,6 +308,9 @@ class OpRegistryBase { ...@@ -297,6 +308,9 @@ class OpRegistryBase {
void GetInOutMemoryTypes( void GetInOutMemoryTypes(
const std::string &op_type, OpConditionContext *context) const; const std::string &op_type, OpConditionContext *context) const;
const std::vector<DataFormat> InputsDataFormat(
const std::string &op_type, OpConditionContext *context) const;
std::unique_ptr<Operation> CreateOperation( std::unique_ptr<Operation> CreateOperation(
OpConstructContext *context, OpConstructContext *context,
DeviceType device_type) const; DeviceType device_type) const;
......
...@@ -173,7 +173,7 @@ void OpenCLUtil::BuildTransformOpDef( ...@@ -173,7 +173,7 @@ void OpenCLUtil::BuildTransformOpDef(
arg->set_i(static_cast<int32_t>(dt)); arg->set_i(static_cast<int32_t>(dt));
arg = op_def->add_arg(); arg = op_def->add_arg();
arg->set_name("data_format"); arg->set_name("data_format");
arg->set_i(data_format); arg->set_i(static_cast<int>(data_format));
if (!input_shape.empty()) { if (!input_shape.empty()) {
OutputShape *shape = op_def->add_output_shape(); OutputShape *shape = op_def->add_output_shape();
for (auto value : input_shape) { for (auto value : input_shape) {
......
...@@ -269,7 +269,7 @@ MaceStatus Workspace::PreallocateOutputTensor( ...@@ -269,7 +269,7 @@ MaceStatus Workspace::PreallocateOutputTensor(
tensor_mem.second.data_type, tensor_mem.second.data_type,
false, tensor_mem.first)); false, tensor_mem.first));
tensor->set_data_format(tensor_mem.second.data_format); tensor->set_data_format(tensor_mem.second.data_format);
if (tensor_mem.second.data_format != DataFormat::DF_NONE) { if (tensor_mem.second.data_format != DataFormat::NONE) {
if (mem_blocks[tensor_mem.second.mem_id].mem_type() if (mem_blocks[tensor_mem.second.mem_id].mem_type()
== MemoryType::GPU_IMAGE) { == MemoryType::GPU_IMAGE) {
VLOG(1) << "Tensor: " << tensor_mem.first VLOG(1) << "Tensor: " << tensor_mem.first
......
...@@ -94,7 +94,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) { ...@@ -94,7 +94,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
} else if (data_format_str == "OIHW") { } else if (data_format_str == "OIHW") {
return DataFormat::OIHW; return DataFormat::OIHW;
} else { } else {
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
......
...@@ -143,7 +143,7 @@ void BMNet::SetUp() { ...@@ -143,7 +143,7 @@ void BMNet::SetUp() {
// Add input and output information // Add input and output information
for (size_t i = 0; i < input_names_.size(); ++i) { for (size_t i = 0; i < input_names_.size(); ++i) {
InputOutputInfo *info = net_.add_input_info(); InputOutputInfo *info = net_.add_input_info();
info->set_data_format(DataFormat::NHWC); info->set_data_format(static_cast<int>(DataFormat::NHWC));
info->set_name(input_names_[i]); info->set_name(input_names_[i]);
for (auto d : input_shapes_[i]) { for (auto d : input_shapes_[i]) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
...@@ -244,7 +244,7 @@ void BMNet::AddConv(const std::string &conv_type, ...@@ -244,7 +244,7 @@ void BMNet::AddConv(const std::string &conv_type,
op_def->add_output(output_name); op_def->add_output(output_name);
AddIntsArg(op_def, "strides", strides); AddIntsArg(op_def, "strides", strides);
AddIntArg(op_def, "padding", padding_type); AddIntArg(op_def, "padding", padding_type);
AddIntArg(op_def, "has_data_format", 1); AddIntArg(op_def, "data_format", static_cast<int>(DataFormat::AUTO));
AddIntArg(op_def, "T", DT_HALF); AddIntArg(op_def, "T", DT_HALF);
if (has_relu6) { if (has_relu6) {
AddStringArg(op_def, "activation", "RELUX"); AddStringArg(op_def, "activation", "RELUX");
...@@ -271,7 +271,7 @@ void BMNet::AddEltwise(const std::string &op_name, ...@@ -271,7 +271,7 @@ void BMNet::AddEltwise(const std::string &op_name,
op_def->add_output(output); op_def->add_output(output);
AddIntArg(op_def, "type", type); AddIntArg(op_def, "type", type);
AddIntArg(op_def, "T", DT_HALF); AddIntArg(op_def, "T", DT_HALF);
AddIntArg(op_def, "has_data_format", 1); AddIntArg(op_def, "data_format", static_cast<int>(DataFormat::AUTO));
OutputShape *shape = op_def->add_output_shape(); OutputShape *shape = op_def->add_output_shape();
for (auto dim : output_shape) { for (auto dim : output_shape) {
shape->add_dims(dim); shape->add_dims(dim);
......
...@@ -283,9 +283,9 @@ MaceTensor::MaceTensor(const std::vector<int64_t> &shape, ...@@ -283,9 +283,9 @@ MaceTensor::MaceTensor(const std::vector<int64_t> &shape,
std::shared_ptr<void> data, std::shared_ptr<void> data,
const DataFormat format) { const DataFormat format) {
MACE_CHECK_NOTNULL(data.get()); MACE_CHECK_NOTNULL(data.get());
MACE_CHECK(format == DataFormat::DF_NONE || format == DataFormat::NHWC MACE_CHECK(format == DataFormat::NONE || format == DataFormat::NHWC
|| format == DataFormat::NCHW || format == OIHW, || format == DataFormat::NCHW || format == DataFormat::OIHW,
"MACE only support DF_NONE, NHWC, NCHW and OIHW " "MACE only support NONE, NHWC, NCHW and OIHW "
"formats of input now."); "formats of input now.");
impl_ = make_unique<MaceTensor::Impl>(); impl_ = make_unique<MaceTensor::Impl>();
impl_->shape = shape; impl_->shape = shape;
...@@ -496,7 +496,7 @@ MaceStatus MaceEngine::Impl::Init( ...@@ -496,7 +496,7 @@ MaceStatus MaceEngine::Impl::Init(
DataType output_dt = output_info_map_[output_name].data_type(); DataType output_dt = output_info_map_[output_name].data_type();
Tensor *output_tensor = Tensor *output_tensor =
ws_->CreateTensor(output_name, device_->allocator(), output_dt); ws_->CreateTensor(output_name, device_->allocator(), output_dt);
output_tensor->set_data_format(NHWC); output_tensor->set_data_format(DataFormat::NHWC);
#endif #endif
} }
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA) #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
...@@ -585,14 +585,14 @@ MaceEngine::Impl::~Impl() { ...@@ -585,14 +585,14 @@ MaceEngine::Impl::~Impl() {
MaceStatus MaceEngine::Impl::TransposeInput( MaceStatus MaceEngine::Impl::TransposeInput(
const std::pair<const std::string, MaceTensor> &input, const std::pair<const std::string, MaceTensor> &input,
Tensor *input_tensor) { Tensor *input_tensor) {
bool has_data_format = input_tensor->data_format() != DataFormat::DF_NONE; bool has_data_format = input_tensor->data_format() != DataFormat::NONE;
DataFormat data_format = DataFormat::DF_NONE; DataFormat data_format = DataFormat::NONE;
DataType input_dt = input_tensor->dtype(); DataType input_dt = input_tensor->dtype();
if (has_data_format) { if (has_data_format) {
std::vector<int> dst_dims; std::vector<int> dst_dims;
if (device_->device_type() == DeviceType::CPU && if (device_->device_type() == DeviceType::CPU &&
input.second.shape().size() == 4 && input.second.shape().size() == 4 &&
input.second.data_format() == NHWC && input.second.data_format() == DataFormat::NHWC &&
!is_quantized_model_) { !is_quantized_model_) {
VLOG(1) << "Transform input " << input.first << " from NHWC to NCHW"; VLOG(1) << "Transform input " << input.first << " from NHWC to NCHW";
input_tensor->set_data_format(DataFormat::NCHW); input_tensor->set_data_format(DataFormat::NCHW);
...@@ -654,28 +654,28 @@ MaceStatus MaceEngine::Impl::TransposeOutput( ...@@ -654,28 +654,28 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
DataType output_dt = output_tensor->dtype(); DataType output_dt = output_tensor->dtype();
// save output // save output
if (output_tensor != nullptr && output->second.data() != nullptr) { if (output_tensor != nullptr && output->second.data() != nullptr) {
if (output_tensor->data_format() != DataFormat::DF_NONE && if (output_tensor->data_format() != DataFormat::NONE &&
output->second.data_format() != DataFormat::DF_NONE && output->second.data_format() != DataFormat::NONE &&
output->second.shape().size() == 4 && output->second.shape().size() == 4 &&
output->second.data_format() != output_tensor->data_format()) { output->second.data_format() != output_tensor->data_format()) {
VLOG(1) << "Transform output " << output->first << " from " VLOG(1) << "Transform output " << output->first << " from "
<< output_tensor->data_format() << " to " << static_cast<int>(output_tensor->data_format()) << " to "
<< output->second.data_format(); << static_cast<int>(output->second.data_format());
std::vector<int> dst_dims; std::vector<int> dst_dims;
if (output_tensor->data_format() == NCHW && if (output_tensor->data_format() == DataFormat::NCHW &&
output->second.data_format() == NHWC) { output->second.data_format() == DataFormat::NHWC) {
dst_dims = {0, 2, 3, 1}; dst_dims = {0, 2, 3, 1};
} else if (output_tensor->data_format() == NHWC && } else if (output_tensor->data_format() == DataFormat::NHWC &&
output->second.data_format() == NCHW) { output->second.data_format() == DataFormat::NCHW) {
dst_dims = {0, 3, 1, 2}; dst_dims = {0, 3, 1, 2};
} else { } else {
LOG(FATAL) << "Not supported output data format: " LOG(FATAL) << "Not supported output data format: "
<< output->second.data_format() << " vs " << static_cast<int>(output->second.data_format()) << " vs "
<< output_tensor->data_format(); << static_cast<int>(output_tensor->data_format());
} }
VLOG(1) << "Transform output " << output->first << " from " VLOG(1) << "Transform output " << output->first << " from "
<< output_tensor->data_format() << " to " << static_cast<int>(output_tensor->data_format()) << " to "
<< output->second.data_format(); << static_cast<int>(output->second.data_format());
std::vector<index_t> shape = std::vector<index_t> shape =
TransposeShape<index_t, index_t>(output_tensor->shape(), TransposeShape<index_t, index_t>(output_tensor->shape(),
dst_dims); dst_dims);
......
...@@ -96,7 +96,7 @@ class ActivationOp<DeviceType::GPU, T> : public Operation { ...@@ -96,7 +96,7 @@ class ActivationOp<DeviceType::GPU, T> : public Operation {
auto leakyrelu_coefficient = static_cast<T>( auto leakyrelu_coefficient = static_cast<T>(
Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f)); Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f));
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::ActivationKernel<T>>( kernel_ = make_unique<opencl::image::ActivationKernel<T>>(
type, relux_max_limit, leakyrelu_coefficient); type, relux_max_limit, leakyrelu_coefficient);
...@@ -140,11 +140,13 @@ void RegisterActivation(OpRegistryBase *op_registry) { ...@@ -140,11 +140,13 @@ void RegisterActivation(OpRegistryBase *op_registry) {
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format = int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0); *op, "has_data_format", 0);
if (!has_data_format || if (!has_data_format ||
(op->output_shape_size() != op->output_size()) ||
op->output_shape(0).dims_size() != 4) { op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} }
......
...@@ -207,7 +207,8 @@ void TestSimplePrelu() { ...@@ -207,7 +207,8 @@ void TestSimplePrelu() {
// Run // Run
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Activation", "PreluTest") OpDefBuilder("Activation", "PreluTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Alpha") .Input("Alpha")
...@@ -217,7 +218,8 @@ void TestSimplePrelu() { ...@@ -217,7 +218,8 @@ void TestSimplePrelu() {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>( auto expected = net.CreateTensor<float>(
......
...@@ -69,7 +69,7 @@ class AddNOp<DeviceType::GPU, T> : public Operation { ...@@ -69,7 +69,7 @@ class AddNOp<DeviceType::GPU, T> : public Operation {
public: public:
explicit AddNOp(OpConstructContext *context) explicit AddNOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::AddNKernel<T>>(); kernel_ = make_unique<opencl::image::AddNKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -109,11 +109,13 @@ void RegisterAddN(OpRegistryBase *op_registry) { ...@@ -109,11 +109,13 @@ void RegisterAddN(OpRegistryBase *op_registry) {
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format = int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0); *op, "has_data_format", 0);
if (!has_data_format || if (!has_data_format ||
(op->output_shape_size() != op->output_size()) ||
op->output_shape(0).dims_size() != 4) { op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} }
......
...@@ -54,7 +54,7 @@ MaceStatus Deconv2dBase::ResizeOutAndPadOut( ...@@ -54,7 +54,7 @@ MaceStatus Deconv2dBase::ResizeOutAndPadOut(
out_pad_size, out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......
...@@ -174,7 +174,7 @@ class BatchNormOp<DeviceType::GPU, T> : public Operation { ...@@ -174,7 +174,7 @@ class BatchNormOp<DeviceType::GPU, T> : public Operation {
float leakyrelu_coefficient = Operation::GetOptionalArg<float>( float leakyrelu_coefficient = Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f); "leakyrelu_coefficient", 0.0f);
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::BatchNormKernel<T>>( kernel_ = make_unique<opencl::image::BatchNormKernel<T>>(
epsilon, activation, relux_max_limit, leakyrelu_coefficient); epsilon, activation, relux_max_limit, leakyrelu_coefficient);
......
...@@ -34,7 +34,8 @@ void Simple() { ...@@ -34,7 +34,8 @@ void Simple() {
net.AddInputFromArray<D, float>("Var", {1}, {11.67f}, true); net.AddInputFromArray<D, float>("Var", {1}, {11.67f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Scale") .Input("Scale")
...@@ -47,7 +48,8 @@ void Simple() { ...@@ -47,7 +48,8 @@ void Simple() {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input") .Input("Input")
...@@ -93,8 +95,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { ...@@ -93,8 +95,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
...@@ -112,8 +114,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { ...@@ -112,8 +114,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -163,8 +165,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -163,8 +165,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -179,8 +181,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -179,8 +181,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -230,8 +232,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -230,8 +232,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -246,8 +248,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -246,8 +248,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -296,8 +298,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -296,8 +298,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -312,8 +314,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -312,8 +314,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -264,7 +264,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, T> : public BatchToSpaceOpBase { ...@@ -264,7 +264,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, T> : public BatchToSpaceOpBase {
public: public:
explicit BatchToSpaceNDOp(OpConstructContext *context) explicit BatchToSpaceNDOp(OpConstructContext *context)
: BatchToSpaceOpBase(context) { : BatchToSpaceOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::BatchToSpaceKernel<T>>(); kernel_ = make_unique<opencl::image::BatchToSpaceKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -103,7 +103,7 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation { ...@@ -103,7 +103,7 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation {
: Operation(context), : Operation(context),
has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 1)) { has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 1)) {
MemoryType mem_type = MemoryType::CPU_BUFFER; MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::BiasAddKernel<T>>(); kernel_ = make_unique<opencl::image::BiasAddKernel<T>>();
} else { } else {
...@@ -151,11 +151,13 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) { ...@@ -151,11 +151,13 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) {
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format = int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0); *op, "has_data_format", 0);
if (!has_data_format || if (!has_data_format ||
(op->output_shape_size() != op->output_size()) ||
op->output_shape(0).dims_size() != 4) { op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} }
......
...@@ -27,9 +27,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { ...@@ -27,9 +27,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
DataFormat data_format = NHWC;
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
data_format = NCHW;
net.AddRandomInput<D, T>("Input", {batch, channels, height, width}); net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels}); net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
......
...@@ -31,8 +31,8 @@ void BiasAddSimple() { ...@@ -31,8 +31,8 @@ void BiasAddSimple() {
net.AddInputFromArray<D, float>("Bias", {1}, {0.5f}, true); net.AddInputFromArray<D, float>("Bias", {1}, {0.5f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Bias") .Input("Bias")
...@@ -41,8 +41,8 @@ void BiasAddSimple() { ...@@ -41,8 +41,8 @@ void BiasAddSimple() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input") .Input("Input")
...@@ -83,8 +83,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { ...@@ -83,8 +83,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
{batch, height, width, channels}); {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true); net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
...@@ -97,8 +97,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { ...@@ -97,8 +97,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -132,8 +132,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { ...@@ -132,8 +132,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
{batch, height, width, channels}); {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true); net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
...@@ -146,8 +146,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { ...@@ -146,8 +146,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
......
...@@ -48,7 +48,6 @@ void FilterBufferToImage(int iters, ...@@ -48,7 +48,6 @@ void FilterBufferToImage(int iters,
OpenCLBufferType::IN_OUT_CHANNEL, OpenCLBufferType::IN_OUT_CHANNEL,
MemoryType::GPU_IMAGE, MemoryType::GPU_IMAGE,
0, 0,
DataFormat::NHWC,
b2i_output); b2i_output);
}; };
......
...@@ -37,14 +37,14 @@ void TestBidirectionTransform(const OpenCLBufferType type, ...@@ -37,14 +37,14 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); type, MemoryType::GPU_IMAGE, 0, b2i_output);
// Inverse Transform // Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor( Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value); "I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value);
OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output, .Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); type, MemoryType::GPU_BUFFER, 0, i2b_output);
// Check // Check
ExpectTensorNear<T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), ExpectTensorNear<T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
...@@ -178,14 +178,14 @@ void TestDiffTypeBidirectionTransform(const OpenCLBufferType type, ...@@ -178,14 +178,14 @@ void TestDiffTypeBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); type, MemoryType::GPU_IMAGE, 0, b2i_output);
// Inverse Transform // Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor( Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DT_FLOAT); "I2BOutput", context.device()->allocator(), DT_FLOAT);
OpenCLBufferTransformer<float>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) OpenCLBufferTransformer<float>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output, .Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); type, MemoryType::GPU_BUFFER, 0, i2b_output);
// Check // Check
ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
...@@ -218,14 +218,14 @@ void TestStringHalfBidirectionTransform(const OpenCLBufferType type, ...@@ -218,14 +218,14 @@ void TestStringHalfBidirectionTransform(const OpenCLBufferType type,
// Transform // Transform
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); type, MemoryType::GPU_IMAGE, 0, b2i_output);
// Inverse Transform // Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor( Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value); "I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value);
OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output, .Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); type, MemoryType::GPU_BUFFER, 0, i2b_output);
// Check // Check
ExpectTensorNear<half>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), ExpectTensorNear<half>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
......
...@@ -39,14 +39,11 @@ class BufferTransformOp<DeviceType::GPU, T> : public Operation { ...@@ -39,14 +39,11 @@ class BufferTransformOp<DeviceType::GPU, T> : public Operation {
auto type = auto type =
static_cast<OpenCLBufferType>(Operation::GetOptionalArg<int>( static_cast<OpenCLBufferType>(Operation::GetOptionalArg<int>(
"buffer_type", static_cast<int>(CONV2D_FILTER))); "buffer_type", static_cast<int>(CONV2D_FILTER)));
DataFormat data_format = static_cast<DataFormat>(
Operation::GetOptionalArg<int>("data_format", DataFormat::DF_NONE));
MemoryType in_mem_type = context->workspace()->GetTensor( MemoryType in_mem_type = context->workspace()->GetTensor(
operator_def_->input(0))->memory_type(); operator_def_->input(0))->memory_type();
return OpenCLBufferTransformer<T>(in_mem_type, out_mem_type_).Transform( return OpenCLBufferTransformer<T>(in_mem_type, out_mem_type_).Transform(
context, input, type, out_mem_type_, wino_blk_size_, context, input, type, out_mem_type_, wino_blk_size_, output);
data_format, output);
} }
private: private:
......
...@@ -48,7 +48,7 @@ void TestBidirectionTransform(const OpenCLBufferType type, ...@@ -48,7 +48,7 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<DstType>(MemoryType::GPU_BUFFER, OpenCLBufferTransformer<DstType>(MemoryType::GPU_BUFFER,
MemoryType::GPU_BUFFER) MemoryType::GPU_BUFFER)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, bt_output); type, MemoryType::GPU_BUFFER, 0, bt_output);
// Inverse Transform // Inverse Transform
Tensor *output = net.ws()->CreateTensor( Tensor *output = net.ws()->CreateTensor(
...@@ -57,7 +57,7 @@ void TestBidirectionTransform(const OpenCLBufferType type, ...@@ -57,7 +57,7 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<OrgType>(MemoryType::GPU_BUFFER, OpenCLBufferTransformer<OrgType>(MemoryType::GPU_BUFFER,
MemoryType::GPU_BUFFER) MemoryType::GPU_BUFFER)
.Transform(&context, bt_output, .Transform(&context, bt_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, output); type, MemoryType::GPU_BUFFER, 0, output);
if (DataTypeToEnum<OrgType>::value == DataTypeToEnum<DstType>::value) { if (DataTypeToEnum<OrgType>::value == DataTypeToEnum<DstType>::value) {
EXPECT_EQ(net.GetOutput("Input")->UnderlyingBuffer(), EXPECT_EQ(net.GetOutput("Input")->UnderlyingBuffer(),
...@@ -94,7 +94,7 @@ void TestArgumentTransform(const index_t input_size) { ...@@ -94,7 +94,7 @@ void TestArgumentTransform(const index_t input_size) {
MemoryType::GPU_BUFFER) MemoryType::GPU_BUFFER)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
OpenCLBufferType::ARGUMENT, MemoryType::GPU_BUFFER, OpenCLBufferType::ARGUMENT, MemoryType::GPU_BUFFER,
0, DataFormat::NHWC, output); 0, output);
index_t expected_size = RoundUp<index_t>(input_size, 4); index_t expected_size = RoundUp<index_t>(input_size, 4);
EXPECT_EQ(expected_size, output->buffer_shape()[0]); EXPECT_EQ(expected_size, output->buffer_shape()[0]);
......
...@@ -82,7 +82,7 @@ class ChannelShuffleOp<DeviceType::GPU, T> : public Operation { ...@@ -82,7 +82,7 @@ class ChannelShuffleOp<DeviceType::GPU, T> : public Operation {
explicit ChannelShuffleOp(OpConstructContext *context) explicit ChannelShuffleOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
const int groups = Operation::GetOptionalArg<int>("group", 1); const int groups = Operation::GetOptionalArg<int>("group", 1);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ChannelShuffleKernel<T>>(groups); kernel_ = make_unique<opencl::image::ChannelShuffleKernel<T>>(groups);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -119,7 +119,7 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) { ...@@ -119,7 +119,7 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) {
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU }; return { DeviceType::CPU, DeviceType::GPU };
} }
int groups = ProtoArgHelper::GetOptionalArg<OperatorDef, int>( int groups = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "group", 1); *op, "group", 1);
......
...@@ -28,8 +28,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) { ...@@ -28,8 +28,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
"Input", {1, 1, 2, 8}, "Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest") OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
...@@ -40,8 +40,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) { ...@@ -40,8 +40,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>( auto expected = net.CreateTensor<float>(
......
...@@ -40,19 +40,19 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, ...@@ -40,19 +40,19 @@ void CalcPaddingAndOutputSize(const index_t *input_shape,
index_t input_height = 0, input_width = 0; index_t input_height = 0, input_width = 0;
index_t kernel_height = 0, kernel_width = 0; index_t kernel_height = 0, kernel_width = 0;
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
input_height = input_shape[2]; input_height = input_shape[2];
input_width = input_shape[3]; input_width = input_shape[3];
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
input_height = input_shape[1]; input_height = input_shape[1];
input_width = input_shape[2]; input_width = input_shape[2];
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
if (filter_format == OIHW) { if (filter_format == DataFormat::OIHW) {
kernel_height = filter_shape[2]; kernel_height = filter_shape[2];
kernel_width = filter_shape[3]; kernel_width = filter_shape[3];
} else if (filter_format == OHWI) { } else if (filter_format == DataFormat::OHWI) {
kernel_height = filter_shape[1]; kernel_height = filter_shape[1];
kernel_width = filter_shape[2]; kernel_width = filter_shape[2];
} else { } else {
...@@ -97,11 +97,11 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, ...@@ -97,11 +97,11 @@ void CalcPaddingAndOutputSize(const index_t *input_shape,
0, (output_width - 1) * strides[1] + k_extent_width - input_width); 0, (output_width - 1) * strides[1] + k_extent_width - input_width);
output_shape[0] = input_shape[0]; output_shape[0] = input_shape[0];
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
output_shape[1] = output_channels; output_shape[1] = output_channels;
output_shape[2] = output_height; output_shape[2] = output_height;
output_shape[3] = output_width; output_shape[3] = output_width;
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
output_shape[1] = output_height; output_shape[1] = output_height;
output_shape[2] = output_width; output_shape[2] = output_width;
output_shape[3] = output_channels; output_shape[3] = output_channels;
...@@ -117,7 +117,8 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW ...@@ -117,7 +117,8 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
Padding padding, Padding padding,
index_t *output_shape, index_t *output_shape,
int *padding_size) { int *padding_size) {
CalcPaddingAndOutputSize(input_shape, NCHW, filter_shape, OIHW, dilations, CalcPaddingAndOutputSize(input_shape, DataFormat::NCHW, filter_shape,
DataFormat::OIHW, dilations,
strides, padding, output_shape, padding_size); strides, padding, output_shape, padding_size);
} }
...@@ -128,7 +129,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC ...@@ -128,7 +129,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
Padding padding, Padding padding,
index_t *output_shape, index_t *output_shape,
int *padding_size) { int *padding_size) {
CalcPaddingAndOutputSize(input_shape, NHWC, filter_shape, OIHW, dilations, CalcPaddingAndOutputSize(input_shape, DataFormat::NHWC, filter_shape,
DataFormat::OIHW, dilations,
strides, padding, output_shape, padding_size); strides, padding, output_shape, padding_size);
} }
...@@ -151,19 +153,19 @@ void CalcOutputSize(const index_t *input_shape, ...@@ -151,19 +153,19 @@ void CalcOutputSize(const index_t *input_shape,
index_t input_height = 0, input_width = 0; index_t input_height = 0, input_width = 0;
index_t kernel_height = 0, kernel_width = 0; index_t kernel_height = 0, kernel_width = 0;
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
input_height = input_shape[2]; input_height = input_shape[2];
input_width = input_shape[3]; input_width = input_shape[3];
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
input_height = input_shape[1]; input_height = input_shape[1];
input_width = input_shape[2]; input_width = input_shape[2];
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
if (filter_format == OIHW) { if (filter_format == DataFormat::OIHW) {
kernel_height = filter_shape[2]; kernel_height = filter_shape[2];
kernel_width = filter_shape[3]; kernel_width = filter_shape[3];
} else if (filter_format == OHWI) { } else if (filter_format == DataFormat::OHWI) {
kernel_height = filter_shape[1]; kernel_height = filter_shape[1];
kernel_width = filter_shape[2]; kernel_width = filter_shape[2];
} else { } else {
...@@ -195,11 +197,11 @@ void CalcOutputSize(const index_t *input_shape, ...@@ -195,11 +197,11 @@ void CalcOutputSize(const index_t *input_shape,
} }
output_shape[0] = input_shape[0]; output_shape[0] = input_shape[0];
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
output_shape[1] = output_channels; output_shape[1] = output_channels;
output_shape[2] = output_height; output_shape[2] = output_height;
output_shape[3] = output_width; output_shape[3] = output_width;
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
output_shape[1] = output_height; output_shape[1] = output_height;
output_shape[2] = output_width; output_shape[2] = output_width;
output_shape[3] = output_channels; output_shape[3] = output_channels;
...@@ -215,7 +217,8 @@ void CalcOutputSize(const index_t *input_shape, // NHWC ...@@ -215,7 +217,8 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
const int *strides, const int *strides,
const RoundType round_type, const RoundType round_type,
index_t *output_shape) { index_t *output_shape) {
CalcOutputSize(input_shape, NHWC, filter_shape, OIHW, padding_size, dilations, CalcOutputSize(input_shape, DataFormat::NHWC, filter_shape,
DataFormat::OIHW, padding_size, dilations,
strides, round_type, output_shape); strides, round_type, output_shape);
} }
...@@ -226,7 +229,8 @@ void CalcNCHWOutputSize(const index_t *input_shape, // NCHW ...@@ -226,7 +229,8 @@ void CalcNCHWOutputSize(const index_t *input_shape, // NCHW
const int *strides, const int *strides,
const RoundType round_type, const RoundType round_type,
index_t *output_shape) { index_t *output_shape) {
CalcOutputSize(input_shape, NCHW, filter_shape, OIHW, padding_size, dilations, CalcOutputSize(input_shape, DataFormat::NCHW, filter_shape,
DataFormat::OIHW, padding_size, dilations,
strides, round_type, output_shape); strides, round_type, output_shape);
} }
...@@ -241,14 +245,18 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape, ...@@ -241,14 +245,18 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape,
std::vector<index_t> *padded_out_shape, std::vector<index_t> *padded_out_shape,
DataFormat data_format) { DataFormat data_format) {
const index_t const index_t
in_height = data_format == NCHW ? input_shape[2] : input_shape[1]; in_height =
data_format == DataFormat::NCHW ? input_shape[2] : input_shape[1];
const index_t const index_t
in_width = data_format == NCHW ? input_shape[3] : input_shape[2]; in_width =
data_format == DataFormat::NCHW ? input_shape[3] : input_shape[2];
const index_t const index_t
out_height = data_format == NCHW ? output_shape[2] : output_shape[1]; out_height =
data_format == DataFormat::NCHW ? output_shape[2] : output_shape[1];
const index_t const index_t
out_width = data_format == NCHW ? output_shape[3] : output_shape[2]; out_width =
data_format == DataFormat::NCHW ? output_shape[3] : output_shape[2];
const index_t extended_in_height = (in_height - 1) * strides[0] + 1; const index_t extended_in_height = (in_height - 1) * strides[0] + 1;
const index_t extended_in_width = (in_width - 1) * strides[1] + 1; const index_t extended_in_width = (in_width - 1) * strides[1] + 1;
...@@ -307,11 +315,11 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape, ...@@ -307,11 +315,11 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape,
padded_out_shape->resize(4); padded_out_shape->resize(4);
(*padded_out_shape)[0] = output_shape[0]; (*padded_out_shape)[0] = output_shape[0];
(*padded_out_shape)[1] = (*padded_out_shape)[1] =
data_format == NCHW ? output_channel : padded_out_height; data_format == DataFormat::NCHW ? output_channel : padded_out_height;
(*padded_out_shape)[2] = (*padded_out_shape)[2] =
data_format == NCHW ? padded_out_height : padded_out_width; data_format == DataFormat::NCHW ? padded_out_height : padded_out_width;
(*padded_out_shape)[3] = (*padded_out_shape)[3] =
data_format == NCHW ? padded_out_width : output_channel; data_format == DataFormat::NCHW ? padded_out_width : output_channel;
} }
} }
...@@ -325,9 +333,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape, ...@@ -325,9 +333,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape,
std::vector<index_t> *padded_out_shape, std::vector<index_t> *padded_out_shape,
DataFormat data_format) { DataFormat data_format) {
const index_t const index_t
in_height = data_format == NCHW ? input_shape[2] : input_shape[1]; in_height =
data_format == DataFormat::NCHW ? input_shape[2] : input_shape[1];
const index_t const index_t
in_width = data_format == NCHW ? input_shape[3] : input_shape[2]; in_width =
data_format == DataFormat::NCHW ? input_shape[3] : input_shape[2];
const index_t output_channel = filter_shape[0] * group; const index_t output_channel = filter_shape[0] * group;
...@@ -351,11 +361,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape, ...@@ -351,11 +361,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape,
padded_out_shape->resize(4); padded_out_shape->resize(4);
(*padded_out_shape)[0] = input_shape[0]; (*padded_out_shape)[0] = input_shape[0];
(*padded_out_shape)[1] = (*padded_out_shape)[1] =
data_format == NCHW ? output_channel : padded_out_height; data_format == DataFormat::NCHW ? output_channel : padded_out_height;
(*padded_out_shape)[2] = (*padded_out_shape)[2] =
data_format == NCHW ? padded_out_height : padded_out_width; data_format == DataFormat::NCHW ? padded_out_height : padded_out_width;
(*padded_out_shape)[3] = (*padded_out_shape)[3] =
data_format == NCHW ? padded_out_width : output_channel; data_format == DataFormat::NCHW ? padded_out_width : output_channel;
} }
if (out_shape != nullptr) { if (out_shape != nullptr) {
...@@ -363,9 +373,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape, ...@@ -363,9 +373,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape,
index_t out_width = padded_out_width - out_pad_size[1]; index_t out_width = padded_out_width - out_pad_size[1];
out_shape->resize(4); out_shape->resize(4);
(*out_shape)[0] = input_shape[0]; (*out_shape)[0] = input_shape[0];
(*out_shape)[1] = data_format == NCHW ? output_channel : out_height; (*out_shape)[1] =
(*out_shape)[2] = data_format == NCHW ? out_height : out_width; data_format == DataFormat::NCHW ? output_channel : out_height;
(*out_shape)[3] = data_format == NCHW ? out_width : output_channel; (*out_shape)[2] = data_format == DataFormat::NCHW ? out_height : out_width;
(*out_shape)[3] =
data_format == DataFormat::NCHW ? out_width : output_channel;
} }
} }
...@@ -385,7 +397,7 @@ void CalDeconvOutputShapeAndPadSize(const std::vector<index_t> &input_shape, ...@@ -385,7 +397,7 @@ void CalDeconvOutputShapeAndPadSize(const std::vector<index_t> &input_shape,
MACE_CHECK(output_shape->size() == 4, MACE_CHECK(output_shape->size() == 4,
"deconv output shape shoud be 4-dims"); "deconv output shape shoud be 4-dims");
std::vector<index_t> &out_shape = *output_shape; std::vector<index_t> &out_shape = *output_shape;
if (data_format == NCHW) { if (data_format == DataFormat::NCHW) {
const index_t t = out_shape[1]; const index_t t = out_shape[1];
out_shape[1] = out_shape[3]; out_shape[1] = out_shape[3];
out_shape[3] = out_shape[2]; out_shape[3] = out_shape[2];
......
...@@ -199,7 +199,7 @@ class ConcatOp<DeviceType::GPU, T> : public ConcatOpBase { ...@@ -199,7 +199,7 @@ class ConcatOp<DeviceType::GPU, T> : public ConcatOpBase {
public: public:
explicit ConcatOp(OpConstructContext *context) explicit ConcatOp(OpConstructContext *context)
: ConcatOpBase(context) { : ConcatOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ConcatKernel<T>>(); kernel_ = make_unique<opencl::image::ConcatKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -243,9 +243,11 @@ void RegisterConcat(OpRegistryBase *op_registry) { ...@@ -243,9 +243,11 @@ void RegisterConcat(OpRegistryBase *op_registry) {
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
auto tensor_shape_info = context->tensor_shape_info(); auto tensor_shape_info = context->tensor_shape_info();
if (op->output_shape_size() != op->output_size() || if (op->output_shape(0).dims_size() != 4) {
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} else { } else {
int has_data_format = int has_data_format =
......
...@@ -231,9 +231,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase { ...@@ -231,9 +231,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
std::vector<int> paddings(2); std::vector<int> paddings(2);
if (paddings_.empty()) { if (paddings_.empty()) {
CalcPaddingAndOutputSize(input->shape().data(), CalcPaddingAndOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
filter->shape().data(), filter->shape().data(),
OHWI, DataFormat::OHWI,
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
padding_type_, padding_type_,
...@@ -242,9 +242,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase { ...@@ -242,9 +242,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), CalcOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
filter->shape().data(), filter->shape().data(),
OHWI, DataFormat::OHWI,
paddings_.data(), paddings_.data(),
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
...@@ -459,7 +459,7 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase { ...@@ -459,7 +459,7 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
"leakyrelu_coefficient", 0.0f)), "leakyrelu_coefficient", 0.0f)),
wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) { wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) {
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::Conv2dKernel<T>>(); kernel_ = make_unique<opencl::image::Conv2dKernel<T>>();
} else { } else {
......
...@@ -47,8 +47,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -47,8 +47,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 1, 1, 1}; const std::vector<index_t> output_shape = {1, 1, 1, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -60,8 +60,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -60,8 +60,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -105,8 +105,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) { ...@@ -105,8 +105,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 3, 3, 1}; const std::vector<index_t> output_shape = {1, 3, 3, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -118,8 +118,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) { ...@@ -118,8 +118,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -189,8 +189,8 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -189,8 +189,8 @@ void TestNHWCSimple3x3WithoutBias() {
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, true); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -203,8 +203,8 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -203,8 +203,8 @@ void TestNHWCSimple3x3WithoutBias() {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -256,8 +256,8 @@ void TestNHWCCombined3x3() { ...@@ -256,8 +256,8 @@ void TestNHWCCombined3x3() {
net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f}, true); net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -270,8 +270,8 @@ void TestNHWCCombined3x3() { ...@@ -270,8 +270,8 @@ void TestNHWCCombined3x3() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -321,8 +321,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -321,8 +321,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 1, 1, 1}; const std::vector<index_t> output_shape = {1, 1, 1, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -336,8 +336,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -336,8 +336,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -376,8 +376,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) { ...@@ -376,8 +376,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 1, 1, 1}; const std::vector<index_t> output_shape = {1, 1, 1, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -391,8 +391,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) { ...@@ -391,8 +391,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -459,8 +459,8 @@ void TestConv1x1() { ...@@ -459,8 +459,8 @@ void TestConv1x1() {
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}, true); net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -472,8 +472,8 @@ void TestConv1x1() { ...@@ -472,8 +472,8 @@ void TestConv1x1() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -532,8 +532,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -532,8 +532,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true, "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true,
false); false);
net.AddRandomInput<D, T>("Bias", {output_channels}, true, false); net.AddRandomInput<D, T>("Bias", {output_channels}, true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -552,8 +552,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -552,8 +552,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -651,8 +651,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape, ...@@ -651,8 +651,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
float_bias_data, float_bias_data,
true); true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -667,8 +667,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape, ...@@ -667,8 +667,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -811,8 +811,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape, ...@@ -811,8 +811,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true); "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
net.AddRandomInput<D, T>("Bias", {output_channels}, true); net.AddRandomInput<D, T>("Bias", {output_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -828,8 +828,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape, ...@@ -828,8 +828,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -900,8 +900,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape, ...@@ -900,8 +900,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true); "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
net.AddRandomInput<D, float>("Bias", {output_channels}, true); net.AddRandomInput<D, float>("Bias", {output_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -916,8 +916,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape, ...@@ -916,8 +916,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -979,8 +979,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, ...@@ -979,8 +979,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true); "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
net.AddRandomInput<D, float>("Bias", {output_channels}, true); net.AddRandomInput<D, float>("Bias", {output_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -994,8 +994,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, ...@@ -994,8 +994,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -1118,12 +1118,12 @@ void TestQuant(const index_t batch, ...@@ -1118,12 +1118,12 @@ void TestQuant(const index_t batch,
net.AddRandomInput<CPU, float>("Filter", {out_channels, k_height, k_width, net.AddRandomInput<CPU, float>("Filter", {out_channels, k_height, k_width,
in_channels}, true); in_channels}, true);
net.AddRandomInput<CPU, float>("Bias", {out_channels}, true); net.AddRandomInput<CPU, float>("Bias", {out_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.TransformFilterDataFormat<DeviceType::CPU, float>("Filter", net.TransformFilterDataFormat<DeviceType::CPU, float>("Filter",
OHWI, DataFormat::OHWI,
"FilterOIHW", "FilterOIHW",
OIHW); DataFormat::OIHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -1136,8 +1136,8 @@ void TestQuant(const index_t batch, ...@@ -1136,8 +1136,8 @@ void TestQuant(const index_t batch,
.AddIntArg("T", static_cast<int>(DT_FLOAT)) .AddIntArg("T", static_cast<int>(DT_FLOAT))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeFilter") OpDefBuilder("Quantize", "QuantizeFilter")
.Input("Filter") .Input("Filter")
......
...@@ -117,7 +117,7 @@ class CropOp<DeviceType::GPU, T> : public Operation { ...@@ -117,7 +117,7 @@ class CropOp<DeviceType::GPU, T> : public Operation {
public: public:
explicit CropOp(OpConstructContext *context) explicit CropOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::CropKernel<T>>( kernel_ = make_unique<opencl::image::CropKernel<T>>(
Operation::GetRepeatedArgs<int>("offset")); Operation::GetRepeatedArgs<int>("offset"));
} else { } else {
...@@ -151,11 +151,13 @@ void RegisterCrop(OpRegistryBase *op_registry) { ...@@ -151,11 +151,13 @@ void RegisterCrop(OpRegistryBase *op_registry) {
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format = int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0); *op, "has_data_format", 0);
if (!has_data_format || if (!has_data_format ||
(op->output_shape_size() != op->output_size()) ||
op->output_shape(0).dims_size() != 4) { op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} }
......
...@@ -42,13 +42,13 @@ void RunCrop(const std::vector<index_t> &input_shape, ...@@ -42,13 +42,13 @@ void RunCrop(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == CPU) { } else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input0", net.TransformDataFormat<DeviceType::CPU, float>("Input0",
NHWC, DataFormat::NHWC,
"InputNCHW0", "InputNCHW0",
NCHW); DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", net.TransformDataFormat<DeviceType::CPU, float>("Input1",
NHWC, DataFormat::NHWC,
"InputNCHW1", "InputNCHW1",
NCHW); DataFormat::NCHW);
OpDefBuilder("Crop", "CropTest") OpDefBuilder("Crop", "CropTest")
.Input("InputNCHW0") .Input("InputNCHW0")
.Input("InputNCHW1") .Input("InputNCHW1")
...@@ -62,8 +62,8 @@ void RunCrop(const std::vector<index_t> &input_shape, ...@@ -62,8 +62,8 @@ void RunCrop(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == CPU) { if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
auto expected = net.CreateTensor<float>(expected_shape, expected_data); auto expected = net.CreateTensor<float>(expected_shape, expected_data);
......
...@@ -32,8 +32,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape, ...@@ -32,8 +32,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape,
OpsTestNet net; OpsTestNet net;
net.AddInputFromArray<CPU, T>("Input", shape, input); net.AddInputFromArray<CPU, T>("Input", shape, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Cumsum", "CumsumTest") OpDefBuilder("Cumsum", "CumsumTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -48,8 +48,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape, ...@@ -48,8 +48,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape,
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, T>("ExpectedOutput", shape, output); net.AddInputFromArray<CPU, T>("ExpectedOutput", shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"), ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
......
...@@ -173,7 +173,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -173,7 +173,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
explicit Deconv2dOp(OpConstructContext *context) explicit Deconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context) { : Deconv2dOpBase(context) {
MemoryType mem_type = MemoryType::GPU_IMAGE; MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::Deconv2dKernel<T>>(); kernel_ = make_unique<opencl::image::Deconv2dKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -240,7 +240,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -240,7 +240,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
&out_paddings, &out_paddings,
nullptr, nullptr,
model_type_, model_type_,
NHWC); DataFormat::NHWC);
return kernel_->Compute(context, input, filter, bias, return kernel_->Compute(context, input, filter, bias,
strides_.data(), in_paddings.data(), activation_, strides_.data(), in_paddings.data(), activation_,
...@@ -276,7 +276,7 @@ void RegisterDeconv2D(OpRegistryBase *op_registry) { ...@@ -276,7 +276,7 @@ void RegisterDeconv2D(OpRegistryBase *op_registry) {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
FrameworkType framework_type = FrameworkType framework_type =
static_cast<ops::FrameworkType>( static_cast<FrameworkType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*(context->operator_def()), "framework_type", *(context->operator_def()), "framework_type",
FrameworkType::TENSORFLOW)); FrameworkType::TENSORFLOW));
......
...@@ -47,7 +47,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape, ...@@ -47,7 +47,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true); net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true);
net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true); net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true);
// TODO(liutuo): remove the unused transform // TODO(liutuo): remove the unused transform
net.TransformFilterDataFormat<D, float>("Filter", HWOI, "FilterOIHW", OIHW); net.TransformFilterDataFormat<D, float>(
"Filter", DataFormat::HWOI, "FilterOIHW", DataFormat::OIHW);
if (D == DeviceType::GPU) { if (D == DeviceType::GPU) {
if (model_type == FrameworkType::CAFFE) { if (model_type == FrameworkType::CAFFE) {
OpDefBuilder("Deconv2D", "Deconv2dTest") OpDefBuilder("Deconv2D", "Deconv2dTest")
...@@ -77,8 +78,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape, ...@@ -77,8 +78,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
} }
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
if (model_type == FrameworkType::CAFFE) { if (model_type == FrameworkType::CAFFE) {
OpDefBuilder("Deconv2D", "Deconv2dTest") OpDefBuilder("Deconv2D", "Deconv2dTest")
...@@ -109,8 +110,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape, ...@@ -109,8 +110,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>(expected_shape, expected_data); auto expected = net.CreateTensor<float>(expected_shape, expected_data);
...@@ -380,8 +381,8 @@ void TestComplexDeconvNxN(const int batch, ...@@ -380,8 +381,8 @@ void TestComplexDeconvNxN(const int batch,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true, "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true,
false); false);
net.AddRandomInput<D, T>("Bias", {output_channels}, true, false); net.AddRandomInput<D, T>("Bias", {output_channels}, true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
int out_h = 0; int out_h = 0;
int out_w = 0; int out_w = 0;
...@@ -440,8 +441,8 @@ void TestComplexDeconvNxN(const int batch, ...@@ -440,8 +441,8 @@ void TestComplexDeconvNxN(const int batch,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -96,7 +96,7 @@ class DepthToSpaceOp<DeviceType::GPU, T> : public Operation { ...@@ -96,7 +96,7 @@ class DepthToSpaceOp<DeviceType::GPU, T> : public Operation {
explicit DepthToSpaceOp(OpConstructContext *context) explicit DepthToSpaceOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
int block_size = Operation::GetOptionalArg<int>("block_size", 1); int block_size = Operation::GetOptionalArg<int>("block_size", 1);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::DepthToSpaceKernel<T>>(block_size); kernel_ = make_unique<opencl::image::DepthToSpaceKernel<T>>(block_size);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -32,8 +32,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape, ...@@ -32,8 +32,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input_data); net.AddInputFromArray<D, float>("Input", input_shape, input_data);
// Construct graph // Construct graph
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -41,8 +41,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape, ...@@ -41,8 +41,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
...@@ -114,8 +114,8 @@ void RandomTest(const int block_size, ...@@ -114,8 +114,8 @@ void RandomTest(const int block_size,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", shape); net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntArg("block_size", block_size) .AddIntArg("block_size", block_size)
...@@ -125,8 +125,8 @@ void RandomTest(const int block_size, ...@@ -125,8 +125,8 @@ void RandomTest(const int block_size,
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("Input") .Input("Input")
......
...@@ -188,9 +188,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t> ...@@ -188,9 +188,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
filter->dim(2) * filter->dim(3), filter->dim(0), filter->dim(1), 1}; filter->dim(2) * filter->dim(3), filter->dim(0), filter->dim(1), 1};
if (paddings_.empty()) { if (paddings_.empty()) {
CalcPaddingAndOutputSize(input->shape().data(), CalcPaddingAndOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
ohwi_shape.data(), ohwi_shape.data(),
OHWI, DataFormat::OHWI,
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
padding_type_, padding_type_,
...@@ -199,9 +199,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t> ...@@ -199,9 +199,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), CalcOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
ohwi_shape.data(), ohwi_shape.data(),
OHWI, DataFormat::OHWI,
paddings_.data(), paddings_.data(),
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
...@@ -375,7 +375,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase { ...@@ -375,7 +375,7 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
explicit DepthwiseConv2dOp(OpConstructContext *context) explicit DepthwiseConv2dOp(OpConstructContext *context)
: DepthwiseConv2dOpBase(context) { : DepthwiseConv2dOpBase(context) {
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::DepthwiseConv2dKernel<T>>(); kernel_ = make_unique<opencl::image::DepthwiseConv2dKernel<T>>();
} else { } else {
...@@ -459,6 +459,18 @@ void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) { ...@@ -459,6 +459,18 @@ void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) {
context->set_output_mem_type(mem_type); context->set_output_mem_type(mem_type);
})); }));
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("DepthwiseConv2d")
.SetInputsDataFormatSelector(
[](OpConditionContext *context) -> std::vector<DataFormat> {
DataFormat op_data_format =
static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return {op_data_format, DataFormat::OIHW, DataFormat::NONE};
}));
} }
} // namespace ops } // namespace ops
......
...@@ -39,8 +39,8 @@ void SimpleValidTest() { ...@@ -39,8 +39,8 @@ void SimpleValidTest() {
true); true);
net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f}, true); net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -52,8 +52,8 @@ void SimpleValidTest() { ...@@ -52,8 +52,8 @@ void SimpleValidTest() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
...@@ -127,8 +127,8 @@ void ComplexValidTest(index_t batch, ...@@ -127,8 +127,8 @@ void ComplexValidTest(index_t batch,
true); true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -141,8 +141,8 @@ void ComplexValidTest(index_t batch, ...@@ -141,8 +141,8 @@ void ComplexValidTest(index_t batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
...@@ -249,8 +249,8 @@ void TestNxNS12(const index_t height, const index_t width) { ...@@ -249,8 +249,8 @@ void TestNxNS12(const index_t height, const index_t width) {
{multiplier * channel}, {multiplier * channel},
true, false); true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -267,8 +267,8 @@ void TestNxNS12(const index_t height, const index_t width) { ...@@ -267,8 +267,8 @@ void TestNxNS12(const index_t height, const index_t width) {
// Run on cpu // Run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -389,9 +389,9 @@ void TestQuant(const index_t batch, ...@@ -389,9 +389,9 @@ void TestQuant(const index_t batch,
"Filter", {k_height, k_width, in_channels, multiplier}, true, false); "Filter", {k_height, k_width, in_channels, multiplier}, true, false);
net.AddRandomInput<CPU, float>("Bias", {out_channels}, true); net.AddRandomInput<CPU, float>("Bias", {out_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.TransformFilterDataFormat<DeviceType::CPU, float>( net.TransformFilterDataFormat<DeviceType::CPU, float>(
"Filter", HWIO, "FilterOIHW", OIHW); "Filter", DataFormat::HWIO, "FilterOIHW", DataFormat::OIHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -405,7 +405,7 @@ void TestQuant(const index_t batch, ...@@ -405,7 +405,7 @@ void TestQuant(const index_t batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeFilter") OpDefBuilder("Quantize", "QuantizeFilter")
.Input("Filter") .Input("Filter")
......
...@@ -190,7 +190,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -190,7 +190,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
explicit DepthwiseDeconv2dOp(OpConstructContext *context) explicit DepthwiseDeconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context) { : Deconv2dOpBase(context) {
MemoryType mem_type = MemoryType::GPU_IMAGE; MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::DepthwiseDeconv2dKernel<T>>(); kernel_ = make_unique<opencl::image::DepthwiseDeconv2dKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -230,7 +230,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -230,7 +230,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
&out_paddings, &out_paddings,
nullptr, nullptr,
CAFFE, CAFFE,
NHWC); DataFormat::NHWC);
return kernel_->Compute(context, return kernel_->Compute(context,
input, input,
......
...@@ -39,7 +39,8 @@ void RunTestSimple(const int group, ...@@ -39,7 +39,8 @@ void RunTestSimple(const int group,
// Add input data // Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input_data); net.AddInputFromArray<D, float>("Input", input_shape, input_data);
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true); net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true);
net.TransformFilterDataFormat<D, float>("Filter", HWOI, "FilterOIHW", OIHW); net.TransformFilterDataFormat<D, float>(
"Filter", DataFormat::HWOI, "FilterOIHW", DataFormat::OIHW);
const index_t out_channels = expected_shape[3]; const index_t out_channels = expected_shape[3];
net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true); net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true);
...@@ -56,8 +57,8 @@ void RunTestSimple(const int group, ...@@ -56,8 +57,8 @@ void RunTestSimple(const int group,
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, net.TransformDataFormat<DeviceType::CPU, float>(
"InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest") OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("FilterOIHW") .Input("FilterOIHW")
...@@ -69,8 +70,8 @@ void RunTestSimple(const int group, ...@@ -69,8 +70,8 @@ void RunTestSimple(const int group,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>(expected_shape, expected_data); auto expected = net.CreateTensor<float>(expected_shape, expected_data);
...@@ -193,8 +194,8 @@ void RandomTest(index_t batch, ...@@ -193,8 +194,8 @@ void RandomTest(index_t batch,
{channel * multiplier}, {channel * multiplier},
bias_data, true, false); bias_data, true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest") OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -210,8 +211,8 @@ void RandomTest(index_t batch, ...@@ -210,8 +211,8 @@ void RandomTest(index_t batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
......
...@@ -1145,7 +1145,7 @@ class EltwiseOp<DeviceType::GPU, T> : public Operation { ...@@ -1145,7 +1145,7 @@ class EltwiseOp<DeviceType::GPU, T> : public Operation {
int32_t scalar_input_index = Operation::GetOptionalArg<int32_t>( int32_t scalar_input_index = Operation::GetOptionalArg<int32_t>(
"scalar_input_index", 1); "scalar_input_index", 1);
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::EltwiseKernel<T>>( kernel_ = make_unique<opencl::image::EltwiseKernel<T>>(
type, coeff, scalar_input, scalar_input_index); type, coeff, scalar_input, scalar_input_index);
......
...@@ -69,7 +69,8 @@ void SimpleTensorScalar(const ops::EltwiseType type, ...@@ -69,7 +69,8 @@ void SimpleTensorScalar(const ops::EltwiseType type,
net.AddInputFromArray<D, T>("Input", shape, input); net.AddInputFromArray<D, T>("Input", shape, input);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, T>("Input", NHWC, "TInput", NCHW); net.TransformDataFormat<D, T>(
"Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("T", DataTypeToEnum<T>::v()) .AddIntArg("T", DataTypeToEnum<T>::v())
...@@ -81,7 +82,8 @@ void SimpleTensorScalar(const ops::EltwiseType type, ...@@ -81,7 +82,8 @@ void SimpleTensorScalar(const ops::EltwiseType type,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, DstType>("TOutput", NCHW, "Output", NHWC); net.TransformDataFormat<D, DstType>(
"TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input") .Input("Input")
...@@ -124,13 +126,15 @@ void SimpleTensorEltwise(const ops::EltwiseType type, ...@@ -124,13 +126,15 @@ void SimpleTensorEltwise(const ops::EltwiseType type,
.OutputType({ops::IsLogicalType(type) ? DT_INT32 : DT_FLOAT}) .OutputType({ops::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
.Output("TOutput"); .Output("TOutput");
if (shape0.size() > 1) { if (shape0.size() > 1) {
net.TransformDataFormat<D, T>("Input0", NHWC, "TInput0", NCHW); net.TransformDataFormat<D, T>(
"Input0", DataFormat::NHWC, "TInput0", DataFormat::NCHW);
op_builder.Input("TInput0"); op_builder.Input("TInput0");
} else { } else {
op_builder.Input("Input0"); op_builder.Input("Input0");
} }
if (shape1.size() > 1) { if (shape1.size() > 1) {
net.TransformDataFormat<D, T>("Input1", NHWC, "TInput1", NCHW); net.TransformDataFormat<D, T>(
"Input1", DataFormat::NHWC, "TInput1", DataFormat::NCHW);
op_builder.Input("TInput1"); op_builder.Input("TInput1");
} else { } else {
op_builder.Input("Input1"); op_builder.Input("Input1");
...@@ -139,7 +143,8 @@ void SimpleTensorEltwise(const ops::EltwiseType type, ...@@ -139,7 +143,8 @@ void SimpleTensorEltwise(const ops::EltwiseType type,
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, DstType>("TOutput", NCHW, "Output", NHWC); net.TransformDataFormat<D, DstType>(
"TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input0") .Input("Input0")
...@@ -560,7 +565,8 @@ void GPUOverflowTest(const ops::EltwiseType type, ...@@ -560,7 +565,8 @@ void GPUOverflowTest(const ops::EltwiseType type,
net.AddInputFromArray<DeviceType::GPU, T>( net.AddInputFromArray<DeviceType::GPU, T>(
"Filter", "Filter",
{output_shape.back(), shape0.back(), 3, 3}, {output_shape.back(), shape0.back(), 3, 3},
std::vector<float>(output_shape.back() * shape0.back() * 9, 1)); std::vector<float>(output_shape.back() * shape0.back() * 9, 1),
true);
OpDefBuilder("Conv2D", "Conv2D") OpDefBuilder("Conv2D", "Conv2D")
.AddIntArg("T", DataTypeToEnum<T>::v()) .AddIntArg("T", DataTypeToEnum<T>::v())
.Input("EltOutput") .Input("EltOutput")
...@@ -636,8 +642,8 @@ void RandomTensorScalar(const ops::EltwiseType type, ...@@ -636,8 +642,8 @@ void RandomTensorScalar(const ops::EltwiseType type,
// Add input data // Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input", shape, false, true, true); net.AddRandomInput<DeviceType::GPU, float>("Input", shape, false, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
...@@ -647,8 +653,8 @@ void RandomTensorScalar(const ops::EltwiseType type, ...@@ -647,8 +653,8 @@ void RandomTensorScalar(const ops::EltwiseType type,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -690,10 +696,10 @@ void RandomTensorEltwise(const ops::EltwiseType type, ...@@ -690,10 +696,10 @@ void RandomTensorEltwise(const ops::EltwiseType type,
true, true,
true); true);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", NHWC, "TInput0", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input0", DataFormat::NHWC, "TInput0", DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", NHWC, "TInput1", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input1", DataFormat::NHWC, "TInput1", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput0") .Input("TInput0")
.Input("TInput1") .Input("TInput1")
...@@ -705,8 +711,8 @@ void RandomTensorEltwise(const ops::EltwiseType type, ...@@ -705,8 +711,8 @@ void RandomTensorEltwise(const ops::EltwiseType type,
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -746,10 +752,10 @@ void Quantized(const std::vector<index_t> &shape, ...@@ -746,10 +752,10 @@ void Quantized(const std::vector<index_t> &shape,
true, true,
true); true);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", NHWC, "TInput0", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input0", DataFormat::NHWC, "TInput0", DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", NHWC, "TInput1", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input1", DataFormat::NHWC, "TInput1", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput0") .Input("TInput0")
...@@ -761,8 +767,8 @@ void Quantized(const std::vector<index_t> &shape, ...@@ -761,8 +767,8 @@ void Quantized(const std::vector<index_t> &shape,
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeInput0") OpDefBuilder("Quantize", "QuantizeInput0")
.Input("Input0") .Input("Input0")
......
...@@ -49,7 +49,8 @@ void Simple() { ...@@ -49,7 +49,8 @@ void Simple() {
net.AddInputFromArray<D, float>("Offset", {1}, offset, true); net.AddInputFromArray<D, float>("Offset", {1}, offset, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Scale") .Input("Scale")
...@@ -58,7 +59,8 @@ void Simple() { ...@@ -58,7 +59,8 @@ void Simple() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("Input") .Input("Input")
...@@ -100,8 +102,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { ...@@ -100,8 +102,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -113,8 +115,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { ...@@ -113,8 +115,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -151,8 +153,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -151,8 +153,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -164,8 +166,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -164,8 +166,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -205,8 +207,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { ...@@ -205,8 +207,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -218,8 +220,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { ...@@ -218,8 +220,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -254,11 +256,11 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -254,11 +256,11 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
// Add input data // Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input", net.AddRandomInput<DeviceType::GPU, float>("Input",
{batch, height, width, channels}); {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -270,8 +272,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -270,8 +272,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -190,7 +190,7 @@ class FullyConnectedOp<DeviceType::GPU, T> : public FullyConnectedOpBase { ...@@ -190,7 +190,7 @@ class FullyConnectedOp<DeviceType::GPU, T> : public FullyConnectedOpBase {
explicit FullyConnectedOp(OpConstructContext *context) explicit FullyConnectedOp(OpConstructContext *context)
: FullyConnectedOpBase(context) { : FullyConnectedOpBase(context) {
MemoryType mem_type = MemoryType::CPU_BUFFER; MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::FullyConnectedKernel<T>>(); kernel_ = make_unique<opencl::image::FullyConnectedKernel<T>>();
} else { } else {
......
...@@ -48,7 +48,8 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -48,7 +48,8 @@ void Simple(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("FullyConnected", "FullyConnectedTest") OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input") .Input("Input")
...@@ -129,8 +130,8 @@ void Random(const index_t batch, ...@@ -129,8 +130,8 @@ void Random(const index_t batch,
net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel}, true, net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel}, true,
false); false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("FullyConnected", "FullyConnectedTest") OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Weight") .Input("Weight")
...@@ -143,7 +144,8 @@ void Random(const index_t batch, ...@@ -143,7 +144,8 @@ void Random(const index_t batch,
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -215,8 +217,10 @@ void QuantRandom(const index_t batch, ...@@ -215,8 +217,10 @@ void QuantRandom(const index_t batch,
net.AddRandomInput<CPU, float>( net.AddRandomInput<CPU, float>(
"Weight", {out_channel, height, width, channels}, true); "Weight", {out_channel, height, width, channels}, true);
net.AddRandomInput<CPU, float>("Bias", {out_channel}, true); net.AddRandomInput<CPU, float>("Bias", {out_channel}, true);
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<CPU, float>(
net.TransformFilterDataFormat<CPU, float>("Weight", OHWI, "WeightOIHW", OIHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.TransformFilterDataFormat<CPU, float>(
"Weight", DataFormat::OHWI, "WeightOIHW", DataFormat::OIHW);
OpDefBuilder("FullyConnected", "FullyConnectedTest") OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -226,7 +230,8 @@ void QuantRandom(const index_t batch, ...@@ -226,7 +230,8 @@ void QuantRandom(const index_t batch,
.AddIntArg("T", DT_FLOAT) .AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(); net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeWeight") OpDefBuilder("Quantize", "QuantizeWeight")
.Input("Weight") .Input("Weight")
......
...@@ -29,7 +29,8 @@ void Simple() { ...@@ -29,7 +29,8 @@ void Simple() {
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest") OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -41,7 +42,8 @@ void Simple() { ...@@ -41,7 +42,8 @@ void Simple() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
......
...@@ -36,7 +36,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation { ...@@ -36,7 +36,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation {
Operation::GetOptionalArg<float>("scalar_input", Operation::GetOptionalArg<float>("scalar_input",
0.0)); 0.0));
MemoryType mem_type = MemoryType::GPU_IMAGE; MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::LSTMCellKernel<T>>(forget_bias); kernel_ = make_unique<opencl::image::LSTMCellKernel<T>>(forget_bias);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -47,7 +47,6 @@ class OpenCLBufferTransformer { ...@@ -47,7 +47,6 @@ class OpenCLBufferTransformer {
const OpenCLBufferType type, const OpenCLBufferType type,
const MemoryType out_mem_type, const MemoryType out_mem_type,
const int wino_blk_size, const int wino_blk_size,
DataFormat data_format,
Tensor *output) { Tensor *output) {
Workspace *ws = context->workspace(); Workspace *ws = context->workspace();
DataType dt = DataTypeToEnum<T>::value; DataType dt = DataTypeToEnum<T>::value;
...@@ -66,7 +65,6 @@ class OpenCLBufferTransformer { ...@@ -66,7 +65,6 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform CPU Buffer " << input->name() VLOG(2) << "Transform CPU Buffer " << input->name()
<< " to GPU Buffer " << internal_tensor->name() << " to GPU Buffer " << internal_tensor->name()
<< " with data type " << dt; << " with data type " << dt;
MACE_CHECK(data_format == DataFormat::NHWC);
internal_tensor->Resize(input->shape()); internal_tensor->Resize(input->shape());
const uint8_t *input_ptr = input->data<uint8_t>(); const uint8_t *input_ptr = input->data<uint8_t>();
Tensor::MappingGuard guard(internal_tensor); Tensor::MappingGuard guard(internal_tensor);
...@@ -88,7 +86,6 @@ class OpenCLBufferTransformer { ...@@ -88,7 +86,6 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform GPU Buffer " << internal_tensor.name() VLOG(2) << "Transform GPU Buffer " << internal_tensor.name()
<< " to CPU Buffer " << output->name() << " to CPU Buffer " << output->name()
<< " with data type " << dt; << " with data type " << dt;
MACE_CHECK(data_format == DataFormat::NHWC);
Tensor::MappingGuard guard(&internal_tensor); Tensor::MappingGuard guard(&internal_tensor);
const T *internal_ptr = internal_tensor.data<T>(); const T *internal_ptr = internal_tensor.data<T>();
output->Resize(internal_tensor.shape()); output->Resize(internal_tensor.shape());
...@@ -135,7 +132,7 @@ MaceStatus TransformFilter( ...@@ -135,7 +132,7 @@ MaceStatus TransformFilter(
input->MarkUnused(); input->MarkUnused();
return OpenCLBufferTransformer<T>(input->memory_type(), mem_type). return OpenCLBufferTransformer<T>(input->memory_type(), mem_type).
Transform(&op_context, input, buffer_type, mem_type, wino_blk_size, Transform(&op_context, input, buffer_type, mem_type, wino_blk_size,
DataFormat::DF_NONE, output); output);
} }
} // namespace ops } // namespace ops
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/core/memory_optimizer.h" #include "mace/core/memory_optimizer.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/core/net_def_adapter.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -164,26 +165,27 @@ void OpTestContext::SetOCLImageAndBufferTestFlag() { ...@@ -164,26 +165,27 @@ void OpTestContext::SetOCLImageAndBufferTestFlag() {
bool OpsTestNet::Setup(mace::DeviceType device) { bool OpsTestNet::Setup(mace::DeviceType device) {
NetDef net_def; NetDef net_def;
for (auto &op_def : op_defs_) { for (auto &op_def : op_defs_) {
net_def.add_op()->CopyFrom(op_def); auto target_op = net_def.add_op();
target_op->CopyFrom(op_def);
auto has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "has_data_format", 0);
auto is_quantized_op = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DT_FLOAT))
== static_cast<int>(DT_UINT8);
for (auto input : op_def.input()) { for (auto input : op_def.input()) {
if (ws_.GetTensor(input) != nullptr && if (ws_.GetTensor(input) != nullptr &&
!ws_.GetTensor(input)->is_weight()) { !ws_.GetTensor(input)->is_weight()) {
auto input_info = net_def.add_input_info(); auto input_info = net_def.add_input_info();
input_info->set_name(input); input_info->set_name(input);
auto has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "has_data_format", 1);
auto is_quantized_op = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DT_FLOAT))
== static_cast<int>(DT_UINT8);
if (has_data_format) { if (has_data_format) {
if (is_quantized_op || device == DeviceType::GPU) { if (is_quantized_op || device == DeviceType::GPU) {
input_info->set_data_format(NHWC); input_info->set_data_format(static_cast<int>(DataFormat::NHWC));
} else { } else {
input_info->set_data_format(NCHW); input_info->set_data_format(static_cast<int>(DataFormat::NCHW));
} }
} else { } else {
input_info->set_data_format(DataFormat::DF_NONE); input_info->set_data_format(static_cast<int>(DataFormat::NONE));
} }
auto &shape = ws_.GetTensor(input)->shape(); auto &shape = ws_.GetTensor(input)->shape();
for (auto d : shape) { for (auto d : shape) {
...@@ -191,6 +193,10 @@ bool OpsTestNet::Setup(mace::DeviceType device) { ...@@ -191,6 +193,10 @@ bool OpsTestNet::Setup(mace::DeviceType device) {
} }
} }
} }
if (has_data_format) {
SetProtoArg<int>(target_op, "data_format",
static_cast<int>(DataFormat::AUTO));
}
} }
if (!op_defs_.empty()) { if (!op_defs_.empty()) {
auto op_def = op_defs_.back(); auto op_def = op_defs_.back();
...@@ -205,15 +211,21 @@ bool OpsTestNet::Setup(mace::DeviceType device) { ...@@ -205,15 +211,21 @@ bool OpsTestNet::Setup(mace::DeviceType device) {
} }
} }
} }
NetDef adapted_net_def;
NetDefAdapter net_def_adapter(op_registry_.get(), &ws_);
net_def_adapter.AdaptNetDef(&net_def,
OpTestContext::Get()->GetDevice(device),
&adapted_net_def);
MemoryOptimizer mem_optimizer; MemoryOptimizer mem_optimizer;
net_ = make_unique<SerialNet>( net_ = make_unique<SerialNet>(
op_registry_.get(), op_registry_.get(),
&net_def, &adapted_net_def,
&ws_, &ws_,
OpTestContext::Get()->GetDevice(device), OpTestContext::Get()->GetDevice(device),
&mem_optimizer); &mem_optimizer);
MaceStatus status = (ws_.PreallocateOutputTensor( MaceStatus status = (ws_.PreallocateOutputTensor(
net_def, adapted_net_def,
&mem_optimizer, &mem_optimizer,
OpTestContext::Get()->GetDevice(device))); OpTestContext::Get()->GetDevice(device)));
if (status != MaceStatus::MACE_SUCCESS) return false; if (status != MaceStatus::MACE_SUCCESS) return false;
...@@ -252,15 +264,20 @@ MaceStatus OpsTestNet::RunOp() { ...@@ -252,15 +264,20 @@ MaceStatus OpsTestNet::RunOp() {
MaceStatus OpsTestNet::RunNet(const mace::NetDef &net_def, MaceStatus OpsTestNet::RunNet(const mace::NetDef &net_def,
const mace::DeviceType device) { const mace::DeviceType device) {
device_type_ = device; device_type_ = device;
NetDef adapted_net_def;
NetDefAdapter net_def_adapter(op_registry_.get(), &ws_);
net_def_adapter.AdaptNetDef(&net_def,
OpTestContext::Get()->GetDevice(device),
&adapted_net_def);
MemoryOptimizer mem_optimizer; MemoryOptimizer mem_optimizer;
net_ = make_unique<SerialNet>( net_ = make_unique<SerialNet>(
op_registry_.get(), op_registry_.get(),
&net_def, &adapted_net_def,
&ws_, &ws_,
OpTestContext::Get()->GetDevice(device), OpTestContext::Get()->GetDevice(device),
&mem_optimizer); &mem_optimizer);
MACE_RETURN_IF_ERROR(ws_.PreallocateOutputTensor( MACE_RETURN_IF_ERROR(ws_.PreallocateOutputTensor(
net_def, adapted_net_def,
&mem_optimizer, &mem_optimizer,
OpTestContext::Get()->GetDevice(device))); OpTestContext::Get()->GetDevice(device)));
MACE_RETURN_IF_ERROR(net_->Init()); MACE_RETURN_IF_ERROR(net_->Init());
......
...@@ -216,7 +216,7 @@ class OpsTestNet { ...@@ -216,7 +216,7 @@ class OpsTestNet {
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
MACE_CHECK(input_shape.size() == 4, "input shape != 4"); MACE_CHECK(input_shape.size() == 4, "input shape != 4");
if (src_format == NHWC && dst_format == NCHW) { if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
index_t height = input_shape[1]; index_t height = input_shape[1];
index_t width = input_shape[2]; index_t width = input_shape[2];
...@@ -236,7 +236,8 @@ class OpsTestNet { ...@@ -236,7 +236,8 @@ class OpsTestNet {
} }
} }
} }
} else if (src_format == NCHW && dst_format == NHWC) { } else if (src_format == DataFormat::NCHW &&
dst_format == DataFormat::NHWC) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
index_t channels = input_shape[1]; index_t channels = input_shape[1];
index_t height = input_shape[2]; index_t height = input_shape[2];
...@@ -274,7 +275,7 @@ class OpsTestNet { ...@@ -274,7 +275,7 @@ class OpsTestNet {
input->is_weight()); input->is_weight());
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
MACE_CHECK(input_shape.size() == 4, "input shape != 4"); MACE_CHECK(input_shape.size() == 4, "input shape != 4");
if (src_format == HWOI && dst_format == OIHW) { if (src_format == DataFormat::HWOI && dst_format == DataFormat::OIHW) {
index_t height = input_shape[0]; index_t height = input_shape[0];
index_t width = input_shape[1]; index_t width = input_shape[1];
index_t out_channels = input_shape[2]; index_t out_channels = input_shape[2];
...@@ -292,7 +293,8 @@ class OpsTestNet { ...@@ -292,7 +293,8 @@ class OpsTestNet {
input_data[j * out_channels * in_channels + i]; input_data[j * out_channels * in_channels + i];
} }
} }
} else if (src_format == OIHW && dst_format == HWOI) { } else if (src_format == DataFormat::OIHW &&
dst_format == DataFormat::HWOI) {
index_t out_channels = input_shape[0]; index_t out_channels = input_shape[0];
index_t in_channels = input_shape[1]; index_t in_channels = input_shape[1];
index_t height = input_shape[2]; index_t height = input_shape[2];
...@@ -310,7 +312,8 @@ class OpsTestNet { ...@@ -310,7 +312,8 @@ class OpsTestNet {
input_data[j * height * width + i]; input_data[j * height * width + i];
} }
} }
} else if (src_format == HWIO && dst_format == OIHW) { } else if (src_format == DataFormat::HWIO &&
dst_format == DataFormat::OIHW) {
index_t height = input_shape[0]; index_t height = input_shape[0];
index_t width = input_shape[1]; index_t width = input_shape[1];
index_t in_channels = input_shape[2]; index_t in_channels = input_shape[2];
...@@ -330,7 +333,8 @@ class OpsTestNet { ...@@ -330,7 +333,8 @@ class OpsTestNet {
} }
} }
} }
} else if (src_format == OHWI && dst_format == OIHW) { } else if (src_format == DataFormat::OHWI &&
dst_format == DataFormat::OIHW) {
index_t out_channels = input_shape[0]; index_t out_channels = input_shape[0];
index_t height = input_shape[1]; index_t height = input_shape[1];
index_t width = input_shape[2]; index_t width = input_shape[2];
......
...@@ -179,7 +179,7 @@ class PadOp<DeviceType::GPU, T> : public Operation { ...@@ -179,7 +179,7 @@ class PadOp<DeviceType::GPU, T> : public Operation {
std::vector<int> paddings = Operation::GetRepeatedArgs<int>("paddings"); std::vector<int> paddings = Operation::GetRepeatedArgs<int>("paddings");
float constant_value = Operation::GetOptionalArg<float>( float constant_value = Operation::GetOptionalArg<float>(
"constant_value", 0.0); "constant_value", 0.0);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::PadKernel<T>>( kernel_ = make_unique<opencl::image::PadKernel<T>>(
type, paddings, constant_value); type, paddings, constant_value);
} else { } else {
......
...@@ -45,8 +45,8 @@ void SimpleConstant() { ...@@ -45,8 +45,8 @@ void SimpleConstant() {
// Run // Run
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
.Input("TInput") .Input("TInput")
.Output("TOutput") .Output("TOutput")
...@@ -58,8 +58,8 @@ void SimpleConstant() { ...@@ -58,8 +58,8 @@ void SimpleConstant() {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto output = net.GetTensor("Output"); auto output = net.GetTensor("Output");
...@@ -93,7 +93,8 @@ void Result(const std::vector<index_t> &input_shape, ...@@ -93,7 +93,8 @@ void Result(const std::vector<index_t> &input_shape,
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
t_input = "TInput"; t_input = "TInput";
t_output = "TOutput"; t_output = "TOutput";
net.TransformDataFormat<DeviceType::CPU, T>(input, NHWC, t_input, NCHW); net.TransformDataFormat<DeviceType::CPU, T>(
input, DataFormat::NHWC, t_input, DataFormat::NCHW);
} }
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
...@@ -108,7 +109,8 @@ void Result(const std::vector<index_t> &input_shape, ...@@ -108,7 +109,8 @@ void Result(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, T>(t_output, NCHW, output, NHWC); net.TransformDataFormat<DeviceType::CPU, T>(
t_output, DataFormat::NCHW, output, DataFormat::NHWC);
} }
auto actual = net.GetTensor(output.c_str()); auto actual = net.GetTensor(output.c_str());
...@@ -172,8 +174,8 @@ TEST_F(PadTest, ComplexCPU) { ...@@ -172,8 +174,8 @@ TEST_F(PadTest, ComplexCPU) {
// Add input data // Add input data
net.AddRepeatedInput<DeviceType::CPU, float>("Input", {1, 1, 1, 2}, 2); net.AddRepeatedInput<DeviceType::CPU, float>("Input", {1, 1, 1, 2}, 2);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
.Input("TInput") .Input("TInput")
.Output("TOutput") .Output("TOutput")
...@@ -184,8 +186,8 @@ TEST_F(PadTest, ComplexCPU) { ...@@ -184,8 +186,8 @@ TEST_F(PadTest, ComplexCPU) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto output = net.GetTensor("Output"); auto output = net.GetTensor("Output");
...@@ -209,8 +211,8 @@ void Complex(const std::vector<index_t> &input_shape, ...@@ -209,8 +211,8 @@ void Complex(const std::vector<index_t> &input_shape,
// Add input data // Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input", input_shape); net.AddRandomInput<DeviceType::GPU, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
.Input("TInput") .Input("TInput")
.Output("TOutput") .Output("TOutput")
...@@ -222,8 +224,8 @@ void Complex(const std::vector<index_t> &input_shape, ...@@ -222,8 +224,8 @@ void Complex(const std::vector<index_t> &input_shape,
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
......
...@@ -270,9 +270,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -270,9 +270,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
std::vector<int> paddings(2); std::vector<int> paddings(2);
if (paddings_.empty()) { if (paddings_.empty()) {
CalcPaddingAndOutputSize(input_tensor->shape().data(), CalcPaddingAndOutputSize(input_tensor->shape().data(),
NHWC, DataFormat::NHWC,
filter_shape.data(), filter_shape.data(),
OHWI, DataFormat::OHWI,
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
padding_type_, padding_type_,
...@@ -281,9 +281,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -281,9 +281,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), CalcOutputSize(input_tensor->shape().data(),
NHWC, DataFormat::NHWC,
filter_shape.data(), filter_shape.data(),
OHWI, DataFormat::OHWI,
paddings_.data(), paddings_.data(),
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
...@@ -477,7 +477,7 @@ class PoolingOp<DeviceType::GPU, T> : public PoolingOpBase { ...@@ -477,7 +477,7 @@ class PoolingOp<DeviceType::GPU, T> : public PoolingOpBase {
public: public:
explicit PoolingOp(OpConstructContext *context) explicit PoolingOp(OpConstructContext *context)
: PoolingOpBase(context) { : PoolingOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::PoolingKernel<T>>(); kernel_ = make_unique<opencl::image::PoolingKernel<T>>();
} else { } else {
kernel_ = make_unique<opencl::buffer::PoolingKernel<T>>(); kernel_ = make_unique<opencl::buffer::PoolingKernel<T>>();
......
...@@ -34,8 +34,8 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -34,8 +34,8 @@ TEST_F(PoolingOpTest, MAX_VALID) {
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}); 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -50,8 +50,8 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -50,8 +50,8 @@ TEST_F(PoolingOpTest, MAX_VALID) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = auto expected =
...@@ -68,8 +68,8 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -68,8 +68,8 @@ TEST_F(PoolingOpTest, MAX_SAME) {
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1}, net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8}); {0, 1, 2, 3, 4, 5, 6, 7, 8});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -84,8 +84,8 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -84,8 +84,8 @@ TEST_F(PoolingOpTest, MAX_SAME) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8}); auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8});
...@@ -102,8 +102,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -102,8 +102,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
"Input", {1, 4, 4, 1}, "Input", {1, 4, 4, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -118,8 +118,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -118,8 +118,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {10, 11, 14, 15}); auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {10, 11, 14, 15});
...@@ -136,8 +136,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -136,8 +136,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
"Input", {1, 2, 9, 1}, "Input", {1, 2, 9, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -152,8 +152,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -152,8 +152,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 5, 1}, {10, 12, 14, 16, 17}); auto expected = net.CreateTensor<float>({1, 1, 5, 1}, {10, 12, 14, 16, 17});
...@@ -174,8 +174,8 @@ void SimpleMaxPooling3S2() { ...@@ -174,8 +174,8 @@ void SimpleMaxPooling3S2() {
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}); 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Run // Run
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -187,8 +187,8 @@ void SimpleMaxPooling3S2() { ...@@ -187,8 +187,8 @@ void SimpleMaxPooling3S2() {
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
...@@ -224,8 +224,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape, ...@@ -224,8 +224,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", input_shape); net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -240,8 +240,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape, ...@@ -240,8 +240,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -304,8 +304,8 @@ TEST_F(PoolingOpTest, AVG_VALID) { ...@@ -304,8 +304,8 @@ TEST_F(PoolingOpTest, AVG_VALID) {
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}); 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -320,8 +320,8 @@ TEST_F(PoolingOpTest, AVG_VALID) { ...@@ -320,8 +320,8 @@ TEST_F(PoolingOpTest, AVG_VALID) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>( auto expected = net.CreateTensor<float>(
...@@ -373,8 +373,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape, ...@@ -373,8 +373,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", shape); net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -389,8 +389,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape, ...@@ -389,8 +389,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -563,7 +563,7 @@ void TestQuant(const index_t batch, ...@@ -563,7 +563,7 @@ void TestQuant(const index_t batch,
net.AddRandomInput<CPU, float>( net.AddRandomInput<CPU, float>(
"Input", input_shape, false, false); "Input", input_shape, false, false);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddRandomInput<DeviceType::CPU, float>( net.AddRandomInput<DeviceType::CPU, float>(
"OutputNCHW", input_shape, false, true, true); "OutputNCHW", input_shape, false, true, true);
...@@ -580,7 +580,7 @@ void TestQuant(const index_t batch, ...@@ -580,7 +580,7 @@ void TestQuant(const index_t batch,
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
.Input("Input") .Input("Input")
......
...@@ -873,7 +873,7 @@ class ReduceOp<DeviceType::GPU, T> : public ReduceOpBase { ...@@ -873,7 +873,7 @@ class ReduceOp<DeviceType::GPU, T> : public ReduceOpBase {
public: public:
explicit ReduceOp(OpConstructContext *context) explicit ReduceOp(OpConstructContext *context)
: ReduceOpBase(context) { : ReduceOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ReduceKernel<T>>(reduce_type_, kernel_ = make_unique<opencl::image::ReduceKernel<T>>(reduce_type_,
axis_, axis_,
keep_dims_); keep_dims_);
...@@ -914,6 +914,9 @@ void RegisterReduce(OpRegistryBase *op_registry) { ...@@ -914,6 +914,9 @@ void RegisterReduce(OpRegistryBase *op_registry) {
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
bool keep_dims = bool keep_dims =
ProtoArgHelper::GetOptionalArg<OperatorDef, bool>( ProtoArgHelper::GetOptionalArg<OperatorDef, bool>(
*op, "keepdims", false); *op, "keepdims", false);
...@@ -923,7 +926,7 @@ void RegisterReduce(OpRegistryBase *op_registry) { ...@@ -923,7 +926,7 @@ void RegisterReduce(OpRegistryBase *op_registry) {
auto axis = auto axis =
ProtoArgHelper::GetRepeatedArgs<OperatorDef, int>( ProtoArgHelper::GetRepeatedArgs<OperatorDef, int>(
*op, "axis"); *op, "axis");
if (axis.size() != 2 || axis[0] != 1 || axis[1] == 2) { if (axis.size() != 2 || axis[0] != 1 || axis[1] != 2) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} }
auto tensor_shape_info = context->tensor_shape_info(); auto tensor_shape_info = context->tensor_shape_info();
......
...@@ -38,7 +38,8 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -38,7 +38,8 @@ void Simple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input); net.AddInputFromArray<D, float>("Input", input_shape, input);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
...@@ -49,7 +50,8 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -49,7 +50,8 @@ void Simple(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("Input") .Input("Input")
...@@ -289,8 +291,8 @@ void RandomTest(const std::vector<index_t> &input_shape, ...@@ -289,8 +291,8 @@ void RandomTest(const std::vector<index_t> &input_shape,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", input_shape); net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
...@@ -301,8 +303,8 @@ void RandomTest(const std::vector<index_t> &input_shape, ...@@ -301,8 +303,8 @@ void RandomTest(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("Input") .Input("Input")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
...@@ -353,7 +355,7 @@ void TestQuant(const std::vector<index_t> &input_shape, ...@@ -353,7 +355,7 @@ void TestQuant(const std::vector<index_t> &input_shape,
net.AddRandomInput<CPU, float>( net.AddRandomInput<CPU, float>(
"Input", input_shape, false, false); "Input", input_shape, false, false);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddRandomInput<DeviceType::CPU, float>( net.AddRandomInput<DeviceType::CPU, float>(
"OutputNCHW", input_shape, false, true, true); "OutputNCHW", input_shape, false, true, true);
...@@ -368,7 +370,7 @@ void TestQuant(const std::vector<index_t> &input_shape, ...@@ -368,7 +370,7 @@ void TestQuant(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
.Input("Input") .Input("Input")
......
...@@ -51,7 +51,7 @@ MaceStatus Deconv2d<float>::Compute(const OpContext *context, ...@@ -51,7 +51,7 @@ MaceStatus Deconv2d<float>::Compute(const OpContext *context,
&out_pad_size, &out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......
...@@ -50,7 +50,7 @@ MaceStatus DepthwiseDeconv2d<float>::Compute(const OpContext *context, ...@@ -50,7 +50,7 @@ MaceStatus DepthwiseDeconv2d<float>::Compute(const OpContext *context,
&out_pad_size, &out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
...@@ -185,7 +185,7 @@ MaceStatus GroupDeconv2d<float>::Compute(const OpContext *context, ...@@ -185,7 +185,7 @@ MaceStatus GroupDeconv2d<float>::Compute(const OpContext *context,
&out_pad_size, &out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......
...@@ -212,7 +212,7 @@ class ResizeBicubicOp<DeviceType::GPU, T> : public Operation { ...@@ -212,7 +212,7 @@ class ResizeBicubicOp<DeviceType::GPU, T> : public Operation {
std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>( std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>(
"size", {-1, -1}); "size", {-1, -1});
MACE_CHECK(size.size() == 2); MACE_CHECK(size.size() == 2);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ResizeBicubicKernel<T>>( kernel_ = make_unique<opencl::image::ResizeBicubicKernel<T>>(
align_corners, size[0], size[1]); align_corners, size[0], size[1]);
} else { } else {
......
...@@ -31,8 +31,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) { ...@@ -31,8 +31,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -42,8 +42,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) { ...@@ -42,8 +42,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
...@@ -60,8 +60,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) { ...@@ -60,8 +60,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
std::vector<float> input(48); std::vector<float> input(48);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 4, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 4, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -71,8 +71,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) { ...@@ -71,8 +71,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 2, 3, 3}, auto expected = net.CreateTensor<float>({1, 2, 3, 3},
...@@ -92,8 +92,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { ...@@ -92,8 +92,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -104,8 +104,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { ...@@ -104,8 +104,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
...@@ -133,8 +133,8 @@ void TestRandomResizeBicubic() { ...@@ -133,8 +133,8 @@ void TestRandomResizeBicubic() {
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
{batch, in_height, in_width, channels}, {batch, in_height, in_width, channels},
false, true, true); false, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -144,8 +144,8 @@ void TestRandomResizeBicubic() { ...@@ -144,8 +144,8 @@ void TestRandomResizeBicubic() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
......
...@@ -346,7 +346,7 @@ class ResizeBilinearOp<DeviceType::GPU, T> : public Operation { ...@@ -346,7 +346,7 @@ class ResizeBilinearOp<DeviceType::GPU, T> : public Operation {
std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>( std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>(
"size", {-1, -1}); "size", {-1, -1});
MACE_CHECK(size.size() == 2); MACE_CHECK(size.size() == 2);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ResizeBilinearKernel<T>>( kernel_ = make_unique<opencl::image::ResizeBilinearKernel<T>>(
align_corners, size[0], size[1]); align_corners, size[0], size[1]);
} else { } else {
......
...@@ -31,8 +31,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { ...@@ -31,8 +31,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -42,8 +42,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { ...@@ -42,8 +42,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
...@@ -60,8 +60,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { ...@@ -60,8 +60,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -72,8 +72,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { ...@@ -72,8 +72,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
...@@ -100,8 +100,8 @@ void TestRandomResizeBilinear() { ...@@ -100,8 +100,8 @@ void TestRandomResizeBilinear() {
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
{batch, in_height, in_width, channels}); {batch, in_height, in_width, channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -111,8 +111,8 @@ void TestRandomResizeBilinear() { ...@@ -111,8 +111,8 @@ void TestRandomResizeBilinear() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -155,8 +155,8 @@ void TestQuantizedResizeBilinear() { ...@@ -155,8 +155,8 @@ void TestQuantizedResizeBilinear() {
true, true,
-1.f, -1.f,
1.f); 1.f);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -166,8 +166,8 @@ void TestQuantizedResizeBilinear() { ...@@ -166,8 +166,8 @@ void TestQuantizedResizeBilinear() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// run quantize // run quantize
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
......
...@@ -149,7 +149,7 @@ class ResizeNearestNeighborOp<DeviceType::GPU, T> : public Operation { ...@@ -149,7 +149,7 @@ class ResizeNearestNeighborOp<DeviceType::GPU, T> : public Operation {
: Operation(context) { : Operation(context) {
bool align_corners = Operation::GetOptionalArg<bool>( bool align_corners = Operation::GetOptionalArg<bool>(
"align_corners", false); "align_corners", false);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ResizeNearestNeighborKernel<T>>( kernel_ = make_unique<opencl::image::ResizeNearestNeighborKernel<T>>(
align_corners); align_corners);
} else { } else {
......
...@@ -32,8 +32,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) { ...@@ -32,8 +32,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) {
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
std::vector<int32_t> size = {1, 2}; std::vector<int32_t> size = {1, 2};
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size); net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest") OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
...@@ -45,8 +45,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) { ...@@ -45,8 +45,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
...@@ -64,8 +64,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) { ...@@ -64,8 +64,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) {
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
std::vector<int32_t> size = {1, 2}; std::vector<int32_t> size = {1, 2};
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size); net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest") OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
...@@ -78,8 +78,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) { ...@@ -78,8 +78,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
...@@ -105,8 +105,8 @@ void TestRandomResizeNearestNeighbor() { ...@@ -105,8 +105,8 @@ void TestRandomResizeNearestNeighbor() {
std::vector<int32_t> size = {20, 40}; std::vector<int32_t> size = {20, 40};
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
{batch, in_height, in_width, channels}); {batch, in_height, in_width, channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddInputFromArray<D, int32_t>("Size", {2}, size); net.AddInputFromArray<D, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest") OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -116,8 +116,8 @@ void TestRandomResizeNearestNeighbor() { ...@@ -116,8 +116,8 @@ void TestRandomResizeNearestNeighbor() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
......
...@@ -414,7 +414,7 @@ class SoftmaxOp<DeviceType::GPU, T> : public Operation { ...@@ -414,7 +414,7 @@ class SoftmaxOp<DeviceType::GPU, T> : public Operation {
: Operation(context) { : Operation(context) {
bool use_log = ( bool use_log = (
Operation::GetOptionalArg<bool>("use_log", false)); Operation::GetOptionalArg<bool>("use_log", false));
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SoftmaxKernel<T>>(use_log); kernel_ = make_unique<opencl::image::SoftmaxKernel<T>>(use_log);
} else { } else {
kernel_ = make_unique<opencl::buffer::SoftmaxKernel<T>>(use_log); kernel_ = make_unique<opencl::buffer::SoftmaxKernel<T>>(use_log);
......
...@@ -50,7 +50,8 @@ void Simple(bool use_log = false) { ...@@ -50,7 +50,8 @@ void Simple(bool use_log = false) {
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
// test 4d softmax // test 4d softmax
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Softmax", "SoftmaxTest") OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -59,7 +60,8 @@ void Simple(bool use_log = false) { ...@@ -59,7 +60,8 @@ void Simple(bool use_log = false) {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
...@@ -109,7 +111,8 @@ void Complex(const std::vector<index_t> &logits_shape, ...@@ -109,7 +111,8 @@ void Complex(const std::vector<index_t> &logits_shape,
net.AddRandomInput<D, float>("Input", logits_shape); net.AddRandomInput<D, float>("Input", logits_shape);
if (logits_shape.size() == 4) { if (logits_shape.size() == 4) {
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Softmax", "SoftmaxTest") OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -127,7 +130,8 @@ void Complex(const std::vector<index_t> &logits_shape, ...@@ -127,7 +130,8 @@ void Complex(const std::vector<index_t> &logits_shape,
net.RunOp(); net.RunOp();
if (logits_shape.size() == 4) { if (logits_shape.size() == 4) {
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -307,7 +307,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, T> : public SpaceToBatchOpBase { ...@@ -307,7 +307,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, T> : public SpaceToBatchOpBase {
public: public:
explicit SpaceToBatchNDOp(OpConstructContext *context) explicit SpaceToBatchNDOp(OpConstructContext *context)
: SpaceToBatchOpBase(context) { : SpaceToBatchOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SpaceToBatchKernel<T>>(); kernel_ = make_unique<opencl::image::SpaceToBatchKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -39,8 +39,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape, ...@@ -39,8 +39,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == CPU) { } else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -53,8 +53,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape, ...@@ -53,8 +53,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == CPU) { if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
ExpectTensorNear<float>(*expected, *net.GetOutput("Output")); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
...@@ -78,8 +78,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape, ...@@ -78,8 +78,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == CPU) { } else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -92,8 +92,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape, ...@@ -92,8 +92,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == CPU) { if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
ExpectTensorNear<float>(*expected, *net.GetOutput("Output")); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
...@@ -155,8 +155,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape, ...@@ -155,8 +155,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape,
net.RunOp(GPU); net.RunOp(GPU);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -164,8 +164,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape, ...@@ -164,8 +164,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// Check // Check
ExpectTensorNear<float>(*net.GetOutput("OutputCPU"), ExpectTensorNear<float>(*net.GetOutput("OutputCPU"),
...@@ -188,8 +188,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape, ...@@ -188,8 +188,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape,
net.RunOp(GPU); net.RunOp(GPU);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -197,8 +197,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape, ...@@ -197,8 +197,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// Check // Check
ExpectTensorNear<float>(*net.GetOutput("OutputCPU"), ExpectTensorNear<float>(*net.GetOutput("OutputCPU"),
...@@ -218,8 +218,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape, ...@@ -218,8 +218,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape,
1.f); 1.f);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -227,8 +227,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape, ...@@ -227,8 +227,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// run quantize // run quantize
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
...@@ -279,8 +279,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape, ...@@ -279,8 +279,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape,
1.f); 1.f);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -288,8 +288,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape, ...@@ -288,8 +288,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// run quantize // run quantize
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
......
...@@ -94,7 +94,7 @@ class SpaceToDepthOp<DeviceType::GPU, T> : public Operation { ...@@ -94,7 +94,7 @@ class SpaceToDepthOp<DeviceType::GPU, T> : public Operation {
explicit SpaceToDepthOp(OpConstructContext *context) explicit SpaceToDepthOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
int block_size = Operation::GetOptionalArg<int>("block_size", 1); int block_size = Operation::GetOptionalArg<int>("block_size", 1);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SpaceToDepthKernel<T>>(block_size); kernel_ = make_unique<opencl::image::SpaceToDepthKernel<T>>(block_size);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -32,8 +32,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape, ...@@ -32,8 +32,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input_data); net.AddInputFromArray<D, float>("Input", input_shape, input_data);
// Construct graph // Construct graph
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -41,8 +41,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape, ...@@ -41,8 +41,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
...@@ -107,8 +107,8 @@ void RandomTest(const int block_size, ...@@ -107,8 +107,8 @@ void RandomTest(const int block_size,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", shape); net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntArg("block_size", block_size) .AddIntArg("block_size", block_size)
...@@ -118,8 +118,8 @@ void RandomTest(const int block_size, ...@@ -118,8 +118,8 @@ void RandomTest(const int block_size,
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("Input") .Input("Input")
......
...@@ -106,7 +106,7 @@ class SplitOp<DeviceType::GPU, T> : public Operation { ...@@ -106,7 +106,7 @@ class SplitOp<DeviceType::GPU, T> : public Operation {
explicit SplitOp(OpConstructContext *context) explicit SplitOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
int32_t axis = Operation::GetOptionalArg<int>("axis", 3); int32_t axis = Operation::GetOptionalArg<int>("axis", 3);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SplitKernel<T>>(axis); kernel_ = make_unique<opencl::image::SplitKernel<T>>(axis);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -147,7 +147,7 @@ void RegisterSplit(OpRegistryBase *op_registry) { ...@@ -147,7 +147,7 @@ void RegisterSplit(OpRegistryBase *op_registry) {
[](OpConditionContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU }; return {DeviceType::CPU, DeviceType::GPU};
} }
int axis = ProtoArgHelper::GetOptionalArg<OperatorDef, int>( int axis = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "axis", 3); *op, "axis", 3);
......
...@@ -83,7 +83,7 @@ class SqrDiffMeanOp<DeviceType::GPU, T> : public Operation { ...@@ -83,7 +83,7 @@ class SqrDiffMeanOp<DeviceType::GPU, T> : public Operation {
public: public:
explicit SqrDiffMeanOp(OpConstructContext *context) explicit SqrDiffMeanOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SqrDiffMeanKernel<T>>(); kernel_ = make_unique<opencl::image::SqrDiffMeanKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -36,13 +36,13 @@ void Simple(const std::vector<index_t> &input_shape0, ...@@ -36,13 +36,13 @@ void Simple(const std::vector<index_t> &input_shape0,
net.AddInputFromArray<D, float>("Input1", input_shape1, input1); net.AddInputFromArray<D, float>("Input1", input_shape1, input1);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", net.TransformDataFormat<DeviceType::CPU, float>("Input0",
NHWC, DataFormat::NHWC,
"InputNCHW0", "InputNCHW0",
NCHW); DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", net.TransformDataFormat<DeviceType::CPU, float>("Input1",
NHWC, DataFormat::NHWC,
"InputNCHW1", "InputNCHW1",
NCHW); DataFormat::NCHW);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
...@@ -54,9 +54,9 @@ void Simple(const std::vector<index_t> &input_shape0, ...@@ -54,9 +54,9 @@ void Simple(const std::vector<index_t> &input_shape0,
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW, DataFormat::NCHW,
"Output", "Output",
NHWC); DataFormat::NHWC);
} else { } else {
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
.Input("Input0") .Input("Input0")
...@@ -107,10 +107,10 @@ void RandomTest(const std::vector<index_t> &input_shape0, ...@@ -107,10 +107,10 @@ void RandomTest(const std::vector<index_t> &input_shape0,
net.AddRandomInput<D, float>("Input0", input_shape0); net.AddRandomInput<D, float>("Input0", input_shape0);
net.AddRandomInput<D, float>("Input1", input_shape1); net.AddRandomInput<D, float>("Input1", input_shape1);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", NHWC, "InputNCHW0", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input0", DataFormat::NHWC, "InputNCHW0", DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", NHWC, "InputNCHW1", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input1", DataFormat::NHWC, "InputNCHW1", DataFormat::NCHW);
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
.Input("InputNCHW0") .Input("InputNCHW0")
.Input("InputNCHW1") .Input("InputNCHW1")
...@@ -118,8 +118,8 @@ void RandomTest(const std::vector<index_t> &input_shape0, ...@@ -118,8 +118,8 @@ void RandomTest(const std::vector<index_t> &input_shape0,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
.Input("Input0") .Input("Input0")
.Input("Input1") .Input("Input1")
......
...@@ -86,8 +86,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -86,8 +86,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"Strides", {static_cast<int32_t>(strides.size())}, strides); "Strides", {static_cast<int32_t>(strides.size())}, strides);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("StridedSlice", "StridedSliceOpTest") OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -105,8 +105,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -105,8 +105,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output); net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"), ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output")); *net.GetOutput("Output"));
...@@ -154,8 +154,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -154,8 +154,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size); "IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("StridedSlice", "StridedSliceOpTest") OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -168,8 +168,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -168,8 +168,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output); net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"), ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output")); *net.GetOutput("Output"));
......
...@@ -34,10 +34,10 @@ class NetDef; ...@@ -34,10 +34,10 @@ class NetDef;
enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4 }; enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4 };
enum DataFormat { enum class DataFormat {
DF_NONE = 0, NHWC = 1, NCHW = 2, NONE = 0, NHWC = 1, NCHW = 2,
HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103, HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103,
DF_AUTO = 1000, AUTO = 1000,
}; };
enum GPUPerfHint { enum GPUPerfHint {
......
...@@ -41,7 +41,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value, ...@@ -41,7 +41,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value,
'cpu+gpu': cvt.DeviceType.CPU.value} 'cpu+gpu': cvt.DeviceType.CPU.value}
data_format_map = { data_format_map = {
'NONE': cvt.DataFormat.DF_NONE, 'NONE': cvt.DataFormat.NONE,
'NHWC': cvt.DataFormat.NHWC, 'NHWC': cvt.DataFormat.NHWC,
'NCHW': cvt.DataFormat.NCHW, 'NCHW': cvt.DataFormat.NCHW,
'OIHW': cvt.DataFormat.OIHW, 'OIHW': cvt.DataFormat.OIHW,
......
...@@ -26,14 +26,14 @@ class DeviceType(Enum): ...@@ -26,14 +26,14 @@ class DeviceType(Enum):
class DataFormat(Enum): class DataFormat(Enum):
DF_NONE = 0 NONE = 0
NHWC = 1 NHWC = 1
NCHW = 2 NCHW = 2
HWIO = 100 HWIO = 100
OIHW = 101 OIHW = 101
HWOI = 102 HWOI = 102
OHWI = 103 OHWI = 103
DF_AUTO = 1000 AUTO = 1000
# SAME_LOWER: if the amount of paddings to be added is odd, # SAME_LOWER: if the amount of paddings to be added is odd,
...@@ -598,8 +598,8 @@ class ConverterUtil(object): ...@@ -598,8 +598,8 @@ class ConverterUtil(object):
return DataFormat.NHWC return DataFormat.NHWC
elif arg.i == DataFormat.NCHW.value: elif arg.i == DataFormat.NCHW.value:
return DataFormat.NCHW return DataFormat.NCHW
elif arg.i == DataFormat.DF_AUTO.value: elif arg.i == DataFormat.AUTO.value:
return DataFormat.DF_AUTO return DataFormat.AUTO
else: else:
return None return None
......
...@@ -387,7 +387,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -387,7 +387,8 @@ class OnnxConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
self._data_format = DataFormat.NCHW self._data_format = DataFormat.NCHW
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, self._data_format) ConverterUtil.add_data_format_arg(self._mace_net_def,
self._data_format)
onnx_model = onnx.load(src_model_file) onnx_model = onnx.load(src_model_file)
ir_version = onnx_model.ir_version ir_version = onnx_model.ir_version
...@@ -403,7 +404,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -403,7 +404,7 @@ class OnnxConverter(base_converter.ConverterInterface):
print("constains ops domain: ", domain, "version:", version) print("constains ops domain: ", domain, "version:", version)
if 'kaldi2onnx' in domain: if 'kaldi2onnx' in domain:
polish_available = False polish_available = False
self._data_format = DataFormat.DF_NONE self._data_format = DataFormat.NONE
self._isKaldi = True self._isKaldi = True
if polish_available: if polish_available:
onnx_model = onnx.utils.polish_model(onnx_model) onnx_model = onnx.utils.polish_model(onnx_model)
......
...@@ -27,7 +27,7 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType ...@@ -27,7 +27,7 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import FrameworkType from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import MaceKeyword from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceHasDataFormatOps from mace.python.tools.converter_tool.base_converter import MaceHasDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import MaceMayHasDataFormatOps # noqa from mace.python.tools.converter_tool.base_converter import MaceMayHasDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import PaddingMode from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ReduceType from mace.python.tools.converter_tool.base_converter import ReduceType
...@@ -200,15 +200,15 @@ class Transformer(base_converter.ConverterInterface): ...@@ -200,15 +200,15 @@ class Transformer(base_converter.ConverterInterface):
op.output.extend([input_node.name]) op.output.extend([input_node.name])
output_shape = op.output_shape.add() output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape) output_shape.dims.extend(input_node.shape)
if input_node.data_format != DataFormat.DF_NONE: if input_node.data_format != DataFormat.NONE:
if input_node.data_format == DataFormat.NCHW: if input_node.data_format == DataFormat.NCHW:
self.transpose_shape(output_shape.dims, self.transpose_shape(output_shape.dims,
[0, 3, 1, 2]) [0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, ConverterUtil.add_data_format_arg(op,
DataFormat.DF_AUTO) DataFormat.AUTO)
else: else:
ConverterUtil.add_data_format_arg(op, ConverterUtil.add_data_format_arg(op,
DataFormat.DF_NONE) DataFormat.NONE)
self._producer[op.output[0]] = op self._producer[op.output[0]] = op
@staticmethod @staticmethod
...@@ -261,7 +261,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -261,7 +261,7 @@ class Transformer(base_converter.ConverterInterface):
producer = self._producer[tensor] producer = self._producer[tensor]
return ConverterUtil.data_format(producer) return ConverterUtil.data_format(producer)
else: else:
return DataFormat.DF_NONE return DataFormat.NONE
def consumer_count(self, tensor_name): def consumer_count(self, tensor_name):
return len(self._consumers.get(tensor_name, [])) return len(self._consumers.get(tensor_name, []))
...@@ -1021,7 +1021,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1021,7 +1021,6 @@ class Transformer(base_converter.ConverterInterface):
filter_format.name) filter_format.name)
return False return False
def add_winograd_arg(self): def add_winograd_arg(self):
if self._wino_arg == 0: if self._wino_arg == 0:
return False return False
...@@ -1350,20 +1349,21 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1350,20 +1349,21 @@ class Transformer(base_converter.ConverterInterface):
df_arg = op.arg.add() df_arg = op.arg.add()
df_arg.name = MaceKeyword.mace_data_format_str df_arg.name = MaceKeyword.mace_data_format_str
if op.type in MaceHasDataFormatOps: if op.type in MaceHasDataFormatOps:
df_arg.i = DataFormat.DF_AUTO.value df_arg.i = DataFormat.AUTO.value
elif op.type in MaceMayHasDataFormatOps: elif op.type in MaceMayHasDataFormatOps:
input_df = DataFormat.DF_AUTO.value input_df = DataFormat.AUTO.value
for input_tensor in op.input: for input_tensor in op.input:
if input_tensor in self._consts: if input_tensor in self._consts:
continue continue
mace_check(input_tensor in self._producer, mace_check(
"Input tensor %s not in producer" % input_tensor) input_tensor in self._producer,
"Input tensor %s not in producer" % input_tensor)
father_op = self._producer[input_tensor] father_op = self._producer[input_tensor]
temp_input_df = ConverterUtil.get_arg( temp_input_df = ConverterUtil.get_arg(
father_op, MaceKeyword.mace_data_format_str) father_op, MaceKeyword.mace_data_format_str)
if temp_input_df.i != DataFormat.DF_AUTO.value: if temp_input_df.i != DataFormat.AUTO.value:
input_df = temp_input_df.i input_df = temp_input_df.i
if input_df == DataFormat.DF_AUTO.value: if input_df == DataFormat.AUTO.value:
df_arg.i = input_df df_arg.i = input_df
# add flag to mark the ops may has data format # add flag to mark the ops may has data format
has_data_format_arg = op.arg.add() has_data_format_arg = op.arg.add()
...@@ -1379,7 +1379,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1379,7 +1379,7 @@ class Transformer(base_converter.ConverterInterface):
src_data_format = ConverterUtil.data_format(net) src_data_format = ConverterUtil.data_format(net)
for op in net.op: for op in net.op:
has_data_format = ConverterUtil.data_format(op) == \ has_data_format = ConverterUtil.data_format(op) == \
DataFormat.DF_AUTO DataFormat.AUTO
# transpose args # transpose args
if op.type == MaceOp.Pad.name: if op.type == MaceOp.Pad.name:
for arg in op.arg: for arg in op.arg:
......
...@@ -80,7 +80,7 @@ void CreateInputInfo(NetDef *net_def) { ...@@ -80,7 +80,7 @@ void CreateInputInfo(NetDef *net_def) {
input_info = net_def->add_input_info(); input_info = net_def->add_input_info();
input_info->set_name({{ net.input_info[idx].name|tojson }}); input_info->set_name({{ net.input_info[idx].name|tojson }});
input_info->set_data_type(static_cast<DataType>({{ net.input_info[idx].data_type }})); input_info->set_data_type(static_cast<DataType>({{ net.input_info[idx].data_type }}));
input_info->set_data_format(static_cast<DataFormat>({{ net.input_info[idx].data_format }})); input_info->set_data_format({{ net.input_info[idx].data_format }});
input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }}); input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }});
{% for dim in net.input_info[idx].dims %} {% for dim in net.input_info[idx].dims %}
input_info->add_dims({{ dim }}); input_info->add_dims({{ dim }});
...@@ -97,7 +97,7 @@ void CreateOutputInfo(NetDef *net_def) { ...@@ -97,7 +97,7 @@ void CreateOutputInfo(NetDef *net_def) {
output_info = net_def->add_output_info(); output_info = net_def->add_output_info();
output_info->set_name({{ net.output_info[idx].name|tojson }}); output_info->set_name({{ net.output_info[idx].name|tojson }});
output_info->set_data_type(static_cast<DataType>({{ net.output_info[idx].data_type }})); output_info->set_data_type(static_cast<DataType>({{ net.output_info[idx].data_type }}));
output_info->set_data_format(static_cast<DataFormat>({{ net.output_info[idx].data_format }})); output_info->set_data_format({{ net.output_info[idx].data_format }});
output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }}); output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }});
{% for dim in net.output_info[idx].dims %} {% for dim in net.output_info[idx].dims %}
output_info->add_dims({{dim}}); output_info->add_dims({{dim}});
......
...@@ -48,7 +48,7 @@ void MaceRunFunc(const int in_out_size) { ...@@ -48,7 +48,7 @@ void MaceRunFunc(const int in_out_size) {
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
InputOutputInfo *info = net_def->add_input_info(); InputOutputInfo *info = net_def->add_input_info();
info->set_data_format(DataFormat::NHWC); info->set_data_format(static_cast<int>(DataFormat::NHWC));
info->set_name(input_names[i]); info->set_name(input_names[i]);
for (auto d : input_shapes[0]) { for (auto d : input_shapes[0]) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
......
...@@ -45,7 +45,7 @@ void MaceRun(const int in_out_size, ...@@ -45,7 +45,7 @@ void MaceRun(const int in_out_size,
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
InputOutputInfo *info = net_def->add_input_info(); InputOutputInfo *info = net_def->add_input_info();
info->set_data_format(DataFormat::NHWC); info->set_data_format(static_cast<int>(DataFormat::NHWC));
info->set_name(input_names[i]); info->set_name(input_names[i]);
for (auto d : max_shape) { for (auto d : max_shape) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
......
...@@ -76,7 +76,7 @@ void Conv3x3(const std::string &input_name, ...@@ -76,7 +76,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", 1) .AddIntArg("data_format", static_cast<int>(DataFormat::AUTO))
.Finalize(&operator_def); .Finalize(&operator_def);
OutputShape *shape = operator_def.add_output_shape(); OutputShape *shape = operator_def.add_output_shape();
...@@ -99,7 +99,7 @@ void Relu(const std::string &input_name, ...@@ -99,7 +99,7 @@ void Relu(const std::string &input_name,
.AddStringArg("activation", "RELU") .AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type)) .AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("has_data_format", 1) .AddIntArg("data_format", static_cast<int>(DataFormat::AUTO))
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
...@@ -139,7 +139,8 @@ void CheckOutputs(const NetDef &net_def, ...@@ -139,7 +139,8 @@ void CheckOutputs(const NetDef &net_def,
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
std::string input_name = input.first + "NHWC"; std::string input_name = input.first + "NHWC";
net.AddInputFromArray<D, float>(input_name, input_shape, input_data); net.AddInputFromArray<D, float>(input_name, input_shape, input_data);
net.TransformDataFormat<D, float>(input_name, NHWC, input.first, NCHW); net.TransformDataFormat<D, float>(
input_name, DataFormat::NHWC, input.first, DataFormat::NCHW);
} else { } else {
net.AddInputFromArray<D, float>(input.first, input_shape, input_data); net.AddInputFromArray<D, float>(input.first, input_shape, input_data);
} }
...@@ -154,7 +155,7 @@ void CheckOutputs(const NetDef &net_def, ...@@ -154,7 +155,7 @@ void CheckOutputs(const NetDef &net_def,
memcpy(data.data(), memcpy(data.data(),
reinterpret_cast<const T *>(tensor_data.data()) + tensor.offset(), reinterpret_cast<const T *>(tensor_data.data()) + tensor.offset(),
tensor.data_size() * sizeof(T)); tensor.data_size() * sizeof(T));
net.AddInputFromArray<D, T>(tensor.name(), shape, data); net.AddInputFromArray<D, T>(tensor.name(), shape, data, true);
} }
net.RunNet(net_def, D); net.RunNet(net_def, D);
...@@ -175,9 +176,9 @@ void CheckOutputs(const NetDef &net_def, ...@@ -175,9 +176,9 @@ void CheckOutputs(const NetDef &net_def,
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
output_name = output.first + "NHWC"; output_name = output.first + "NHWC";
net.TransformDataFormat<CPU, float>(output.first, net.TransformDataFormat<CPU, float>(output.first,
NCHW, DataFormat::NCHW,
output_name, output_name,
NHWC); DataFormat::NHWC);
} }
ops::test::ExpectTensorNear<float>(*tmp_tensor, ops::test::ExpectTensorNear<float>(*tmp_tensor,
*net.GetOutput(output_name.data()), *net.GetOutput(output_name.data()),
......
...@@ -91,7 +91,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) { ...@@ -91,7 +91,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
} else if (data_format_str == "OIHW") { } else if (data_format_str == "OIHW") {
return DataFormat::OIHW; return DataFormat::OIHW;
} else { } else {
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册