提交 50cf1737 编写于 作者: 李寅

Merge branch 'refactor-data-format' into 'master'

Refactor data format

See merge request !1069
...@@ -83,7 +83,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) { ...@@ -83,7 +83,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
} else if (data_format_str == "OIHW") { } else if (data_format_str == "OIHW") {
return DataFormat::OIHW; return DataFormat::OIHW;
} else { } else {
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
......
...@@ -96,6 +96,43 @@ MACE_GET_REPEATED_ARGUMENT_FUNC(int, ints, true) ...@@ -96,6 +96,43 @@ MACE_GET_REPEATED_ARGUMENT_FUNC(int, ints, true)
MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true) MACE_GET_REPEATED_ARGUMENT_FUNC(int64_t, ints, true)
#undef MACE_GET_REPEATED_ARGUMENT_FUNC #undef MACE_GET_REPEATED_ARGUMENT_FUNC
// Generates an explicit specialization of SetProtoArg<T> for proto message
// type `Def` (OperatorDef / NetDef).  The generated function updates the
// argument named `arg_name` in place if it already exists; otherwise it
// appends a new argument.  `fieldname` selects the proto value field the
// argument is stored in (`f` for float, `i` for integers).
#define MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, T, fieldname)     \
template<>                                                     \
void SetProtoArg<T>(Def *def,                                  \
                    const std::string &arg_name,               \
                    const T &value) {                          \
  int size = def->arg_size();                                  \
  for (int i = 0; i < size; ++i) {                             \
    auto arg = def->mutable_arg(i);                            \
    if (arg->name() == arg_name) {                             \
      VLOG(3) << "Update old argument value from "             \
              << arg->fieldname() << " to "                    \
              << value << " for " << arg_name;                 \
      arg->set_##fieldname(value);                             \
      return;                                                  \
    }                                                          \
  }                                                            \
  VLOG(3) << "Add new argument " << arg_name << "(name: "      \
          << arg_name << ", value: " << value << ")";          \
  auto arg = def->add_arg();                                   \
  arg->set_name(arg_name);                                     \
  arg->set_##fieldname(value);                                 \
}

// Instantiates the supported value types for one proto message type.
// bool/int/int64_t all share the proto integer field `i`; float uses `f`.
#define MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(Def)   \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, float, f)       \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, bool, i)        \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int, i)         \
MACE_SET_OPTIONAL_ARGUMENT_FUNC(Def, int64_t, i)

MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(OperatorDef)
MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO(NetDef)
#undef MACE_SET_OPTIONAL_ARGUMENT_FUNC
// Fix: also undefine the helper macro so it does not leak past this point
// (previously only the inner macro was undefined).
#undef MACE_SET_OPTIONAL_ARGUMENT_FUNC_MACRO
// Canonical argument name used to tag an operator's output memory type.
const std::string OutputMemoryTypeTagName() {
  return std::string("output_mem_type");
}
bool IsQuantizedModel(const NetDef &net_def) { bool IsQuantizedModel(const NetDef &net_def) {
return return
......
...@@ -55,6 +55,18 @@ class ProtoArgHelper { ...@@ -55,6 +55,18 @@ class ProtoArgHelper {
std::map<std::string, Argument> arg_map_; std::map<std::string, Argument> arg_map_;
}; };
template <typename T>
void SetProtoArg(OperatorDef *op_def,
const std::string &arg_name,
const T&value);
template <typename T>
void SetProtoArg(NetDef *op_def,
const std::string &arg_name,
const T&value);
const std::string OutputMemoryTypeTagName();
bool IsQuantizedModel(const NetDef &def); bool IsQuantizedModel(const NetDef &def);
} // namespace mace } // namespace mace
......
...@@ -33,7 +33,7 @@ namespace mace { ...@@ -33,7 +33,7 @@ namespace mace {
bool MemoryOptimizer::IsMemoryReuseOp(const std::string &op_type) { bool MemoryOptimizer::IsMemoryReuseOp(const std::string &op_type) {
static const std::unordered_set<std::string> kReuseOp = { static const std::unordered_set<std::string> kReuseOp = {
"Reshape", "Identity", "Squeeze" "Reshape", "Identity", "Squeeze", "ExpandDims"
}; };
return kReuseOp.count(op_type) == 1; return kReuseOp.count(op_type) == 1;
} }
...@@ -124,8 +124,10 @@ void MemoryOptimizer::Optimize( ...@@ -124,8 +124,10 @@ void MemoryOptimizer::Optimize(
op_def->output_type_size()); op_def->output_type_size());
DataType dt; DataType dt;
bool has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>( DataFormat data_format = static_cast<DataFormat>(
*op_def, "has_data_format", 0) != 0; ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op_def, "data_format",
static_cast<int>(DataFormat::NONE)));
int output_size = op_def->output_size(); int output_size = op_def->output_size();
for (int i = 0; i < output_size; ++i) { for (int i = 0; i < output_size; ++i) {
if (i < op_def->output_type_size()) { if (i < op_def->output_type_size()) {
...@@ -209,7 +211,7 @@ void MemoryOptimizer::Optimize( ...@@ -209,7 +211,7 @@ void MemoryOptimizer::Optimize(
mem_ref_count_[best_mem_id] = 1; mem_ref_count_[best_mem_id] = 1;
} }
tensor_mem_map_.emplace(op_def->output(i), TensorMemInfo(best_mem_id, tensor_mem_map_.emplace(op_def->output(i), TensorMemInfo(best_mem_id,
dt, has_data_format)); dt, data_format));
} }
} }
......
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include <vector> #include <vector>
#include "mace/proto/mace.pb.h" #include "mace/proto/mace.pb.h"
#include "mace/port/port.h"
#include "mace/core/types.h" #include "mace/core/types.h"
namespace mace { namespace mace {
...@@ -81,10 +82,10 @@ class MemoryOptimizer { ...@@ -81,10 +82,10 @@ class MemoryOptimizer {
struct TensorMemInfo { struct TensorMemInfo {
int mem_id; int mem_id;
DataType data_type; DataType data_type;
bool has_data_format; DataFormat data_format;
TensorMemInfo(int mem_id, DataType data_type, bool has_data_format) : TensorMemInfo(int mem_id, DataType data_type, DataFormat data_format) :
mem_id(mem_id), data_type(data_type), has_data_format(has_data_format) mem_id(mem_id), data_type(data_type), data_format(data_format)
{} {}
}; };
......
...@@ -31,99 +31,8 @@ ...@@ -31,99 +31,8 @@
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/utils/timer.h" #include "mace/utils/timer.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h"
#endif // MACE_ENABLE_OPENCL
namespace mace { namespace mace {
namespace {
// Book-keeping record for a tensor produced inside the net: where it lives
// (memory type), its element type, its layout, its shape, and the index of
// the operation that produced it (-1 for graph inputs).
struct InternalOutputInfo {
  InternalOutputInfo(const MemoryType mem_type,
                     const DataType dtype,
                     const DataFormat data_format,
                     const std::vector<index_t> &shape,
                     int op_idx)
      : mem_type(mem_type), dtype(dtype), data_format(data_format),
        shape(shape), op_idx(op_idx) {}
  MemoryType mem_type;  // transformed memory type
  DataType dtype;
  DataFormat data_format;
  std::vector<index_t> shape;  // tensor shape
  int op_idx;  // operation which generate the tensor
};
#ifdef MACE_ENABLE_OPENCL
// Builds the name of the intermediate tensor produced when `input_name`
// is transformed into memory type `mem_type`.
std::string TransformedName(const std::string &input_name,
                            const mace::MemoryType mem_type) {
  std::ostringstream name_builder;
  name_builder << input_name << "_mem_type_" << mem_type;
  return name_builder.str();
}
// Returns true unless `op_type` is one of the few ops whose inputs never
// need a memory-type transform (they only inspect tensor metadata).
bool TransformRequiredOp(const std::string &op_type) {
  static const std::unordered_set<std::string> kNoTransformOp = {
      "Shape", "InferConv2dShape"
  };
  return kNoTransformOp.find(op_type) == kNoTransformOp.end();
}
#endif // MACE_ENABLE_OPENCL
} // namespace
// Creates a single Operation from `op_def`.
// Picks the target device if the op supports it, otherwise falls back to
// CPU.  For non-quantized CPU ops carrying a data format, the recorded
// 4-D output shapes are rewritten from NHWC to NCHW before construction.
std::unique_ptr<Operation> SerialNet::CreateOperation(
    const OpRegistryBase *op_registry,
    OpConstructContext *construct_context,
    std::shared_ptr<OperatorDef> op_def,
    bool has_data_format,
    bool is_quantize_model) {
  // Create the Operation
  DeviceType target_device_type = target_device_->device_type();
  DeviceType device_type = DeviceType::CPU;
  // Default assumption: run on CPU with CPU buffers; overridden below if
  // the op is available on the target device.
  construct_context->set_device(cpu_device_.get());
  construct_context->set_operator_def(op_def);
  construct_context->set_output_mem_type(MemoryType::CPU_BUFFER);
  // Get available devices
  auto available_devices =
      op_registry->AvailableDevices(op_def->type(), construct_context);
  // Find the device type to run the op.
  // If the target_device_type in available devices, use target_device_type,
  // otherwise, fallback to CPU device.
  for (auto device : available_devices) {
    if (device == target_device_type) {
      device_type = target_device_type;
      construct_context->set_device(target_device_);
      if (target_device_->device_type() == DeviceType::GPU) {
        // GPU ops produce image-backed outputs by default.
        construct_context->set_output_mem_type(MemoryType::GPU_IMAGE);
      }
      break;
    }
  }
  op_def->set_device_type(device_type);
  // transpose output shape if run on CPU (default format is NHWC)
  if (!is_quantize_model && device_type == DeviceType::CPU &&
      op_def->output_shape_size() == op_def->output_size()) {
    for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) {
      if (has_data_format && op_def->output_shape(out_idx).dims_size() == 4) {
        // NHWC -> NCHW
        std::vector<index_t> output_shape =
            TransposeShape<index_t, index_t>(
                std::vector<index_t>(
                    op_def->output_shape(out_idx).dims().begin(),
                    op_def->output_shape(out_idx).dims().end()),
                {0, 3, 1, 2});
        for (int i = 0; i < 4; ++i) {
          op_def->mutable_output_shape(out_idx)->set_dims(i, output_shape[i]);
        }
      }
    }
  }
  return op_registry->CreateOperation(construct_context, device_type);
}
SerialNet::SerialNet(const OpRegistryBase *op_registry, SerialNet::SerialNet(const OpRegistryBase *op_registry,
const NetDef *net_def, const NetDef *net_def,
Workspace *ws, Workspace *ws,
...@@ -138,237 +47,47 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry, ...@@ -138,237 +47,47 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
target_device->cpu_runtime()->policy(), target_device->cpu_runtime()->policy(),
&target_device->cpu_runtime()->thread_pool())) { &target_device->cpu_runtime()->thread_pool())) {
MACE_LATENCY_LOGGER(1, "Constructing SerialNet"); MACE_LATENCY_LOGGER(1, "Constructing SerialNet");
// quantize model flag
bool is_quantize_model = IsQuantizedModel(*net_def);
// Tensor Shape map
std::unordered_map<std::string, std::vector<index_t>> tensor_shape_map;
for (auto &op : net_def->op()) {
if (op.output_size() != op.output_shape_size()) {
continue;
}
for (int i = 0; i < op.output_size(); ++i) {
tensor_shape_map[op.output(i)] = std::vector<index_t>(
op.output_shape(i).dims().begin(),
op.output_shape(i).dims().end());
}
}
for (auto &tensor : net_def->tensors()) {
tensor_shape_map[tensor.name()] =
std::vector<index_t>(tensor.dims().begin(), tensor.dims().end());
}
bool has_data_format = false;
if (target_device_->device_type() == DeviceType::CPU) {
for (auto &input_info : net_def->input_info()) {
std::vector<index_t> input_shape =
std::vector<index_t>(input_info.dims().begin(),
input_info.dims().end());
// update tensor shape map
tensor_shape_map[input_info.name()] = input_shape;
// Only could be NONE or NHWC
DataFormat input_data_format = static_cast<DataFormat>(
input_info.data_format());
has_data_format = has_data_format ||
(input_data_format != DataFormat::DF_NONE);
if (!is_quantize_model && input_data_format == DataFormat::NHWC &&
input_info.dims_size() == 4) {
// NHWC -> NCHW
input_shape =
TransposeShape<index_t, index_t>(input_shape, {0, 3, 1, 2});
}
}
}
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
// output tensor : related information
std::unordered_map<std::string, InternalOutputInfo> output_map;
// used for memory optimization // used for memory optimization
std::unordered_map<std::string, MemoryType> output_mem_map; std::unordered_map<std::string, MemoryType> output_mem_map;
std::unordered_set<std::string> transformed_set;
// add input information
MemoryType target_mem_type;
// default data format of output tensor
DataFormat default_output_df = DataFormat::DF_NONE;
if (target_device_->device_type() == DeviceType::GPU) {
target_mem_type = MemoryType::GPU_BUFFER;
for (auto &input_info : net_def->input_info()) {
DataFormat input_data_format = static_cast<DataFormat>(
input_info.data_format());
has_data_format = input_data_format != DataFormat::DF_NONE;
std::vector<index_t> input_shape =
std::vector<index_t>(input_info.dims().begin(),
input_info.dims().end());
// update tensor shape map
tensor_shape_map[input_info.name()] = input_shape;
output_map.emplace(input_info.name(), InternalOutputInfo(
target_mem_type, DataType::DT_FLOAT, input_data_format,
input_shape, -1));
}
default_output_df =
has_data_format ? DataFormat::NHWC : DataFormat::DF_NONE;
}
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
OpConstructContext construct_context(ws_, &tensor_shape_map); OpConstructContext construct_context(ws_);
for (int idx = 0; idx < net_def->op_size(); ++idx) { for (int idx = 0; idx < net_def->op_size(); ++idx) {
std::shared_ptr<OperatorDef> op_def(new OperatorDef(net_def->op(idx))); std::shared_ptr<OperatorDef> op_def(new OperatorDef(net_def->op(idx)));
// Create operation // Create operation
auto op = CreateOperation(op_registry, auto op_device_type = static_cast<DeviceType>(op_def->device_type());
&construct_context, if (op_device_type == target_device_->device_type()) {
op_def, construct_context.set_device(target_device_);
has_data_format, } else if (op_device_type == DeviceType::CPU) {
is_quantize_model); construct_context.set_device(cpu_device_.get());
#ifdef MACE_ENABLE_OPENCL } else {
// Add input transform operation if necessary LOG(FATAL) << "Encounter unexpected error: "
if (target_device_->device_type() == DeviceType::GPU) { << op_device_type << " vs " << target_device_->device_type();
// the outputs' memory type of the operation
MemoryType out_mem_type = construct_context.output_mem_type();
int input_size = op_def->input_size();
// if op is memory-unused op, no transformation
if (TransformRequiredOp(op_def->type())) {
for (int i = 0; i < input_size; ++i) {
if (output_map.count(op_def->input(i)) == 1) {
// if op is memory-reuse op, no transformation
if (MemoryOptimizer::IsMemoryReuseOp(op_def->type())) {
out_mem_type = output_map.at(op_def->input(i)).mem_type;
break;
}
// check whether to do transform
MemoryType wanted_in_mem_type =
construct_context.GetInputMemType(i);
DataType wanted_in_dt = construct_context.GetInputDataType(i);
if (output_map.at(op_def->input(i)).mem_type != wanted_in_mem_type
|| output_map.at(op_def->input(i)).dtype != wanted_in_dt) {
auto t_input_name = TransformedName(op_def->input(i),
wanted_in_mem_type);
auto &output_info = output_map.at(op_def->input(i));
// check whether the tensor has been transformed
if (transformed_set.count(t_input_name) == 0) {
VLOG(1) << "Add Transform operation " << op_def->name()
<< " to transform tensor "
<< op_def->input(i) << "', from memory type "
<< output_info.mem_type << " to "
<< wanted_in_mem_type
<< ", from Data Type " << output_info.dtype << " to "
<< wanted_in_dt << ". with data format "
<< output_info.data_format;
std::string input_name = op_def->input(i);
op_def->set_input(i, t_input_name);
auto input_shape = output_info.shape;
if (output_info.mem_type == MemoryType::CPU_BUFFER &&
output_info.data_format == DataFormat::NCHW &&
input_shape.size() == 4) {
// NCHW -> NHWC
input_shape =
TransposeShape<index_t, index_t>(input_shape,
{0, 2, 3, 1});
}
auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
input_name, input_shape, t_input_name, wanted_in_dt,
construct_context.GetInputOpenCLBufferType(i),
wanted_in_mem_type, has_data_format);
OpConstructContext t_construct_context(ws_);
auto transform_op = CreateOperation(
op_registry,
&t_construct_context,
transform_op_def,
has_data_format);
operators_.emplace_back(std::move(transform_op));
transformed_set.insert(t_input_name);
output_mem_map[t_input_name] = wanted_in_mem_type;
// where to do graph reference count.
mem_optimizer->UpdateTensorRef(transform_op_def.get());
} else {
op_def->set_input(i, t_input_name);
}
}
} else {
MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr
&& ws_->GetTensor(op_def->input(i))->is_weight(),
"Tensor ", op_def->input(i), " of ",
op_def->name(), " not allocated");
}
}
}
// update the map : output_tensor -> Operation
for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) {
DataType dt;
if (op_def->output_type_size() == op_def->output_size()) {
dt = op_def->output_type(out_idx);
} else {
dt = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op_def, "T", static_cast<int>(DataType::DT_FLOAT)));
}
output_mem_map[op_def->output(out_idx)] = out_mem_type;
output_map.emplace(
op_def->output(out_idx),
InternalOutputInfo(
out_mem_type,
dt,
default_output_df,
op_def->output_shape().empty() ?
std::vector<index_t>() :
std::vector<index_t>(
op_def->output_shape(out_idx).dims().begin(),
op_def->output_shape(out_idx).dims().end()),
static_cast<int>(operators_.size())));
}
} }
#endif // MACE_ENABLE_OPENCL construct_context.set_operator_def(op_def);
auto op = op_registry->CreateOperation(&construct_context,
op_device_type);
operators_.emplace_back(std::move(op)); operators_.emplace_back(std::move(op));
// where to do graph reference count. // where to do graph reference count.
mem_optimizer->UpdateTensorRef(op_def.get()); mem_optimizer->UpdateTensorRef(op_def.get());
}
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
// Transform the output tensor if necessary if (target_device_->device_type() == DeviceType::GPU) {
if (target_device_->device_type() == DeviceType::GPU) { // update the map : output_tensor -> MemoryType
for (auto &output_info : net_def->output_info()) { MemoryType out_mem_type =
auto &internal_output_info = output_map.at(output_info.name()); static_cast<MemoryType>(
if ((internal_output_info.mem_type != target_mem_type && ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
internal_output_info.mem_type != MemoryType::CPU_BUFFER) || net_def->op(idx), OutputMemoryTypeTagName(),
internal_output_info.dtype != output_info.data_type()) { static_cast<int>(MemoryType::CPU_BUFFER)));
VLOG(1) << "Add Transform operation to transform output tensor '" for (int out_idx = 0; out_idx < op_def->output_size(); ++out_idx) {
<< output_info.name() << "', from memory type " output_mem_map[op_def->output(out_idx)] = out_mem_type;
<< internal_output_info.mem_type
<< " to " << target_mem_type
<< ", from Data Type " << internal_output_info.dtype
<< " to " << output_info.data_type();
std::string t_output_name = TransformedName(output_info.name(),
target_mem_type);
auto output_op_def =
operators_[internal_output_info.op_idx]->operator_def();
int output_size = output_op_def->output_size();
for (int i = 0; i < output_size; ++i) {
if (output_op_def->output(i) == output_info.name()) {
output_op_def->set_output(i, t_output_name);
// update the output : mem_type map
output_mem_map[t_output_name] = output_mem_map[output_info.name()];
output_mem_map[output_info.name()] = target_mem_type;
}
}
bool output_has_data_format =
static_cast<DataFormat>(output_info.data_format());
auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
t_output_name,
internal_output_info.shape,
output_info.name(),
output_info.data_type(),
OpenCLBufferType::IN_OUT_CHANNEL,
target_mem_type,
output_has_data_format);
auto transform_op = CreateOperation(
op_registry,
&construct_context,
transform_op_def,
output_has_data_format);
operators_.emplace_back(std::move(transform_op));
// where to do graph reference count.
mem_optimizer->UpdateTensorRef(transform_op_def.get());
} }
} }
}
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
}
// Update output tensor reference // Update output tensor reference
for (auto &output_info : net_def->output_info()) { for (auto &output_info : net_def->output_info()) {
mem_optimizer->UpdateTensorRef(output_info.name()); mem_optimizer->UpdateTensorRef(output_info.name());
......
...@@ -54,14 +54,6 @@ class SerialNet : public NetBase { ...@@ -54,14 +54,6 @@ class SerialNet : public NetBase {
MaceStatus Run(RunMetadata *run_metadata = nullptr) override; MaceStatus Run(RunMetadata *run_metadata = nullptr) override;
private:
std::unique_ptr<Operation> CreateOperation(
const OpRegistryBase *op_registry,
OpConstructContext *construct_context,
std::shared_ptr<OperatorDef> op_def,
bool has_data_format,
bool is_quantize_model = false);
protected: protected:
Workspace *ws_; Workspace *ws_;
Device *target_device_; Device *target_device_;
......
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/net_def_adapter.h"
#include <string>
#include <vector>
#include "mace/core/operator.h"
#include "mace/utils/math.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/opencl_util.h"
#endif // MACE_ENABLE_OPENCL
namespace mace {
namespace {
// Returns the data format an op should use when its format is AUTO:
// CPU float models run NCHW, quantized CPU models run NHWC, GPU runs NHWC.
// Any other device is a fatal error.
DataFormat GetDefaultDataFormat(DeviceType device_type,
                                bool is_quantized_model) {
  switch (device_type) {
    case CPU:
      return is_quantized_model ? DataFormat::NHWC : DataFormat::NCHW;
    case GPU:
      return DataFormat::NHWC;
    default:
      LOG(FATAL) << "MACE do not support the device " << device_type;
      return DataFormat::NONE;
  }
}
// Builds a unique tensor name of the form "<input>_<tag>_<value>" for the
// output of an inserted transform/transpose op.
template<typename T>
std::string TransformedName(const std::string &input_name,
                            const std::string &tag,
                            const T value) {
  std::ostringstream name_builder;
  name_builder << input_name << "_" << tag << "_" << value;
  return name_builder.str();
}
#ifdef MACE_ENABLE_OPENCL
// True unless `op_type` is a metadata-only op whose inputs never need a
// memory-type transform.
bool TransformRequiredOp(const std::string &op_type) {
  static const std::unordered_set<std::string> kNoTransformOp = {
      "Shape", "InferConv2dShape"
  };
  return kNoTransformOp.find(op_type) == kNoTransformOp.end();
}
#endif // MACE_ENABLE_OPENCL
// Fills `op_def` with a Transpose operation that permutes `input_name`
// into `output_name` using the permutation `dst_dims`, running on
// `device_type` with element type `dt`.  `output_shape` is recorded on the
// op when the caller knows it (may be empty).
//
// Fix: `dst_dims` is now taken by const reference instead of by value,
// avoiding a vector copy on every call.
void BuildTransposeOpDef(
    const std::string &input_name,
    const std::string &output_name,
    const std::vector<index_t> &output_shape,
    const std::vector<int> &dst_dims,
    const DataType dt,
    DeviceType device_type,
    OperatorDef *op_def) {
  std::string op_name = "mace_node_" + output_name;
  op_def->set_name(op_name);
  op_def->set_type("Transpose");
  op_def->add_input(input_name);
  op_def->add_output(output_name);
  op_def->set_device_type(device_type);
  // "dims" argument carries the permutation.
  Argument *arg = op_def->add_arg();
  arg->set_name("dims");
  for (auto dim : dst_dims) {
    arg->add_ints(dim);
  }
  // "T" argument carries the element data type.
  arg = op_def->add_arg();
  arg->set_name("T");
  arg->set_i(static_cast<int32_t>(dt));
  // Record the output shape when known.
  if (!output_shape.empty()) {
    OutputShape *shape = op_def->add_output_shape();
    for (auto value : output_shape) {
      shape->add_dims(value);
    }
  }
}
} // namespace
// Keeps non-owning pointers to the op registry and workspace; both must
// outlive this adapter.
NetDefAdapter::NetDefAdapter(const OpRegistryBase *op_registry,
                             const Workspace *ws)
    : op_registry_(op_registry), ws_(ws) {}
// Rewrites `net_def` into `target_net_def` for execution on `target_device`:
// per op it selects a device, adjusts the data type, and inserts any
// transpose / memory-type transform ops needed between producers and
// consumers.  Graph inputs may have their recorded shape/layout transposed
// to the device's preferred format, and (under OpenCL) extra transform ops
// are appended so graph outputs land in GPU_BUFFER memory.
MaceStatus NetDefAdapter::AdaptNetDef(
    const NetDef *net_def,
    Device *target_device,
    NetDef *target_net_def) {
  MACE_LATENCY_LOGGER(1, "Adapting original NetDef");
  // Copy from original op_def, leave ops alone.
  target_net_def->mutable_arg()->CopyFrom(net_def->arg());
  target_net_def->mutable_tensors()->CopyFrom(net_def->tensors());
  target_net_def->mutable_input_info()->CopyFrom(net_def->input_info());
  target_net_def->mutable_output_info()->CopyFrom(net_def->output_info());
  // CPU fallback device for ops the target device cannot run.
  std::unique_ptr<CPUDevice> cpu_device = make_unique<CPUDevice>(
      target_device->cpu_runtime()->num_threads(),
      target_device->cpu_runtime()->policy(),
      &(target_device->cpu_runtime()->thread_pool()));
  // quantize model flag
  bool is_quantized_model = IsQuantizedModel(*net_def);
  // Const tensors(filter) -> shape
  std::unordered_map<std::string, std::vector<index_t>> tensor_shape_map;
  // Output tensors -> information
  TensorInfoMap output_map;
  // output tensor : related information
  std::unordered_set<std::string> transformed_set;
  for (auto &tensor : net_def->tensors()) {
    tensor_shape_map[tensor.name()] =
        std::vector<index_t>(tensor.dims().begin(), tensor.dims().end());
  }
  // Memory type graph inputs are delivered in, per target device.
  MemoryType mem_type = MemoryType::CPU_BUFFER;
  if (target_device->device_type() == DeviceType::CPU) {
    mem_type = MemoryType::CPU_BUFFER;
  } else if (target_device->device_type() == DeviceType::GPU) {
    mem_type = MemoryType::GPU_BUFFER;
  } else {
    LOG(FATAL) << "MACE do not support the device type: "
               << target_device->device_type();
  }
  // Transpose the recorded layout/shape of 4-D graph inputs to the device's
  // expected format, then seed output_map with the inputs (op_idx = -1).
  int input_size = target_net_def->input_info_size();
  for (int i = 0; i < input_size; ++i) {
    auto input_info = target_net_def->mutable_input_info(i);
    auto input_data_format = static_cast<DataFormat>(
        input_info->data_format());
    DataFormat expected_data_format = GetDefaultDataFormat(
        target_device->device_type(), is_quantized_model);
    std::vector<index_t> input_shape(input_info->dims().begin(),
                                     input_info->dims().end());
    if (input_data_format != DataFormat::NONE
        && input_data_format != expected_data_format
        && input_shape.size() == 4) {
      if (input_data_format == DataFormat::NHWC
          && expected_data_format == DataFormat::NCHW) {
        std::vector<int> dst_dims{0, 3, 1, 2};
        input_data_format = DataFormat::NCHW;
        input_shape = TransposeShape<index_t, index_t>(input_shape, dst_dims);
      } else if (input_data_format == DataFormat::NCHW
          && expected_data_format == DataFormat::NHWC) {
        std::vector<int> dst_dims{0, 2, 3, 1};
        input_data_format = DataFormat::NHWC;
        input_shape = TransposeShape<index_t, index_t>(input_shape, dst_dims);
      }
      input_info->set_data_format(static_cast<int>(input_data_format));
      int input_shape_size = input_shape.size();
      for (int j = 0; j < input_shape_size; ++j) {
        input_info->set_dims(j, input_shape[j]);
      }
    }
    output_map.emplace(input_info->name(), InternalOutputInfo(
        mem_type, input_info->data_type(),
        input_data_format, input_shape, -1));
  }
  OpConditionContext context(ws_, &tensor_shape_map);
  DataFormat op_output_data_format;
  MemoryType op_output_mem_type;
  // Per-op adaptation pipeline.  Note the ordering: for GPU ops data format
  // is adapted before memory type; for CPU ops memory type comes first.
  for (int idx = 0; idx < net_def->op_size(); ++idx) {
    OperatorDef op_def(net_def->op(idx));
    context.set_operator_def(&op_def);
    // Select device
    MACE_RETURN_IF_ERROR(this->AdaptDevice(&context,
                                           target_device,
                                           cpu_device.get(),
                                           output_map,
                                           target_net_def,
                                           &op_def));
    // Adapt data type
    MACE_RETURN_IF_ERROR(this->AdaptDataType(&context,
                                             &op_def));
    if (op_def.device_type() == DeviceType::GPU) {
      MACE_RETURN_IF_ERROR(this->AdaptDataFormat(&context,
                                                 &op_def,
                                                 is_quantized_model,
                                                 &output_map,
                                                 &transformed_set,
                                                 &op_output_data_format,
                                                 target_net_def));
      MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
                                                 &op_def,
                                                 &output_map,
                                                 &transformed_set,
                                                 &op_output_mem_type,
                                                 target_net_def));
    } else {
      MACE_RETURN_IF_ERROR(this->AdaptMemoryType(&context,
                                                 &op_def,
                                                 &output_map,
                                                 &transformed_set,
                                                 &op_output_mem_type,
                                                 target_net_def));
      MACE_RETURN_IF_ERROR(this->AdaptDataFormat(&context,
                                                 &op_def,
                                                 is_quantized_model,
                                                 &output_map,
                                                 &transformed_set,
                                                 &op_output_data_format,
                                                 target_net_def));
    }
    // Record every output of this op so later consumers can look up its
    // memory type / data type / layout / shape.
    int output_size = op_def.output_size();
    for (int out_idx = 0; out_idx < output_size; ++out_idx) {
      DataType dt;
      if (op_def.output_type_size() == op_def.output_size()) {
        dt = op_def.output_type(out_idx);
      } else {
        dt = static_cast<DataType>(
            ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
                op_def, "T", static_cast<int>(DataType::DT_FLOAT)));
      }
      output_map.emplace(
          op_def.output(out_idx),
          InternalOutputInfo(
              op_output_mem_type,
              dt,
              op_output_data_format,
              op_def.output_shape().empty() ?
              std::vector<index_t>() :
              std::vector<index_t>(
                  op_def.output_shape(out_idx).dims().begin(),
                  op_def.output_shape(out_idx).dims().end()),
              target_net_def->op_size()));
    }
    // Add op to target net
    target_net_def->add_op()->CopyFrom(op_def);
  }
#ifdef MACE_ENABLE_OPENCL
  if (target_device->device_type() == DeviceType::GPU) {
    // Add buffer transform for GPU if necessary
    MemoryType target_mem_type = MemoryType::GPU_BUFFER;
    for (auto &output_info : net_def->output_info()) {
      auto &internal_output_info = output_map.at(output_info.name());
      if ((internal_output_info.mem_type != target_mem_type &&
          internal_output_info.mem_type != MemoryType::CPU_BUFFER) ||
          internal_output_info.dtype != output_info.data_type()) {
        VLOG(1) << "Add Transform operation to transform output tensor '"
                << output_info.name() << "', from memory type "
                << internal_output_info.mem_type
                << " to " << target_mem_type
                << ", from Data Type " << internal_output_info.dtype
                << " to " << output_info.data_type();
        std::string t_output_name = TransformedName(output_info.name(),
                                                    "mem_type",
                                                    target_mem_type);
        // Redirect the producer op to write the intermediate tensor, then
        // append a transform op that converts it into the graph output.
        auto output_op_def = target_net_def->mutable_op(
            internal_output_info.op_idx);
        int output_size = output_op_def->output_size();
        for (int i = 0; i < output_size; ++i) {
          if (output_op_def->output(i) == output_info.name()) {
            output_op_def->set_output(i, t_output_name);
          }
        }
        auto transformed_op_def = target_net_def->add_op();
        OpenCLUtil::BuildTransformOpDef(
            t_output_name,
            internal_output_info.shape,
            output_info.name(),
            output_info.data_type(),
            OpenCLBufferType::IN_OUT_CHANNEL,
            target_mem_type,
            internal_output_info.data_format,
            transformed_op_def);
        // set data format arg
        SetProtoArg<int>(
            transformed_op_def,
            "data_format",
            static_cast<int>(internal_output_info.data_format));
        // set output memory type argument
        SetProtoArg<int>(transformed_op_def,
                         OutputMemoryTypeTagName(),
                         target_mem_type);
      }
    }
  }
#endif  // MACE_ENABLE_OPENCL
  VLOG(1) << DebugString(target_net_def);
  return MaceStatus::MACE_SUCCESS;
}
// Chooses the device `op_def` will run on.  On a pure-CPU target it is
// always CPU; otherwise the net optimizer picks between the target device
// and CPU, considering where each of the op's producers ran.  Sets the
// chosen device both on `op_def` and on `context`.
MaceStatus NetDefAdapter::AdaptDevice(OpConditionContext *context,
                                      Device *target_device,
                                      Device *cpu_device,
                                      const TensorInfoMap &output_map,
                                      const NetDef *net_def,
                                      OperatorDef *op_def) {
  VLOG(3) << "Adapt device for op " << op_def->name();
  DeviceType target_device_type = target_device->device_type();
  DeviceType device_type = DeviceType::CPU;
  context->set_device(cpu_device);
  if (target_device_type != DeviceType::CPU) {
    // Collect the device each producer op ran on; op_idx < 0 marks a graph
    // input, which is treated as produced on the target device.
    std::vector<DeviceType> producer_devices;
    for (auto input : op_def->input()) {
      if (output_map.count(input) == 1) {
        if (output_map.at(input).op_idx < 0) {
          producer_devices.push_back(target_device_type);
        } else {
          producer_devices.push_back(
              static_cast<DeviceType>(
                  net_def->op(output_map.at(input).op_idx).device_type()));
        }
      }
    }
    // Get available devices
    auto available_devices =
        op_registry_->AvailableDevices(op_def->type(), context);
    device_type = net_optimizer_.SelectBestDevice(op_def,
                                                  target_device_type,
                                                  available_devices,
                                                  producer_devices);
    if (device_type == target_device_type) {
      context->set_device(target_device);
    } else {
      LOG(INFO) << "Op " << op_def->name() << " fall back to CPU";
    }
  }
  op_def->set_device_type(device_type);
  return MaceStatus::MACE_SUCCESS;
}
// Adjusts the data type of an op for its chosen device: CPU kernels do not
// run half precision, so a CPU op tagged DT_HALF is promoted to DT_FLOAT.
MaceStatus NetDefAdapter::AdaptDataType(OpConditionContext *context,
                                        OperatorDef *op_def) {
  MACE_UNUSED(context);
  // Where to add logic to support mixing precision
  const auto dtype = static_cast<DataType>(
      ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
          *op_def, "T", static_cast<int>(DT_FLOAT)));
  const bool runs_on_cpu = op_def->device_type() == DeviceType::CPU;
  if (runs_on_cpu && dtype == DT_HALF) {
    SetProtoArg<int>(op_def, "T", static_cast<int>(DataType::DT_FLOAT));
  }
  return MaceStatus::MACE_SUCCESS;
}
// Resolves the data format of `op_def` and inserts CPU Transpose ops for
// any input whose producer emits a different 4-D layout than this op
// expects.  AUTO-format ops get the device's default format, and their
// recorded output shapes are rewritten NHWC -> NCHW when needed.  Inserted
// transposes are memoized in `transformed_set`; new intermediate tensors
// are registered in `output_map`.  `*op_output_df` receives the format the
// op's outputs will carry.
//
// Fixes: MACE_CHECK message typo ("if the of has" -> "if the op has") and
// the missing opening quote in the transpose VLOG message.
MaceStatus NetDefAdapter::AdaptDataFormat(
    OpConditionContext *context,
    OperatorDef *op_def,
    bool is_quantized_model,
    TensorInfoMap *output_map,
    std::unordered_set<std::string> *transformed_set,
    DataFormat *op_output_df,
    NetDef *target_net_def) {
  VLOG(3) << "Adapt data format for op " << op_def->name();
  DataFormat op_data_format =
      static_cast<DataFormat>(ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
          *op_def, "data_format",
          static_cast<int>(DataFormat::NONE)));
  // adjust the data format of operation
  if (op_data_format == DataFormat::AUTO) {
    op_data_format = GetDefaultDataFormat(
        static_cast<DeviceType>(op_def->device_type()), is_quantized_model);
    SetProtoArg<int>(op_def, "data_format", static_cast<int>(op_data_format));
    if (op_data_format == DataFormat::NCHW) {
      int output_shape_size = op_def->output_shape_size();
      for (int i = 0; i < output_shape_size; ++i) {
        auto output_shape = op_def->mutable_output_shape(i);
        MACE_CHECK(output_shape->dims_size() == 4,
                   "Output shape should be 4D if the op has data format. ",
                   op_def->name());
        // transpose output shape format from NHWC to NCHW
        int64_t height = output_shape->dims(1);
        int64_t width = output_shape->dims(2);
        output_shape->set_dims(1, output_shape->dims(3));
        output_shape->set_dims(2, height);
        output_shape->set_dims(3, width);
      }
    }
  }
  *op_output_df = op_data_format;
  // the output memory type of transpose op is based on the consumer op's device
  MemoryType target_mem_type = MemoryType::CPU_BUFFER;
  if (op_def->device_type() == DeviceType::GPU) {
    target_mem_type = MemoryType::GPU_BUFFER;
  }
  auto inputs_data_format = op_registry_->InputsDataFormat(op_def->type(),
                                                           context);
  DataFormat src_df, dst_df;
  int input_size = op_def->input_size();
  for (int i = 0; i < input_size; ++i) {
    if (output_map->count(op_def->input(i)) == 0) {
      // check this input is const tensor(filter)
      MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr
                     && ws_->GetTensor(op_def->input(i))->is_weight(),
                 "Tensor ", op_def->input(i), " of ",
                 op_def->name(), " is not allocated by Workspace ahead");
      continue;
    }
    src_df = output_map->at(op_def->input(i)).data_format;
    dst_df = inputs_data_format[i];
    // Formatless tensors and non-4D tensors never need a layout transpose.
    if (src_df == DataFormat::NONE
        || dst_df == DataFormat::NONE
        || output_map->at(op_def->input(i)).shape.size() != 4) {
      continue;
    }
    if (src_df != dst_df) {
      std::string transformed_name = TransformedName(op_def->input(i),
          "data_format", static_cast<int>(dst_df));
      if (transformed_set->count(transformed_name) == 0) {
        VLOG(1) << "Add Transpose operation " << op_def->name()
                << " to transpose tensor '"
                << op_def->input(i) << "', from data format "
                << static_cast<int>(src_df) << " to "
                << static_cast<int>(dst_df);
        // Only support transpose between NHWC and NCHW for now.
        std::vector<int> dst_dims;
        if (src_df == DataFormat::NCHW && dst_df == DataFormat::NHWC) {
          dst_dims = {0, 2, 3, 1};
        } else if (src_df == DataFormat::NHWC && dst_df == DataFormat::NCHW) {
          dst_dims = {0, 3, 1, 2};
        } else {
          LOG(FATAL) << "Encounter unsupported data format transpose from "
                     << static_cast<int>(src_df) << " to "
                     << static_cast<int>(dst_df);
        }
        auto &input_info = output_map->at(op_def->input(i));
        auto output_shape = input_info.shape.empty() ?
                            std::vector<index_t>() :
                            TransposeShape<index_t, index_t>(input_info.shape,
                                                             dst_dims);
        OperatorDef *transpose_op_def = target_net_def->add_op();
        BuildTransposeOpDef(
            op_def->input(i),
            transformed_name,
            output_shape,
            dst_dims,
            input_info.dtype,
            DeviceType::CPU,
            transpose_op_def);
        // set data format arg
        SetProtoArg<int>(transpose_op_def,
                         "data_format",
                         static_cast<int>(dst_df));
        // set output memory type argument
        SetProtoArg<int>(transpose_op_def,
                         OutputMemoryTypeTagName(),
                         target_mem_type);
        // update output information map
        output_map->emplace(
            transformed_name,
            InternalOutputInfo(
                target_mem_type,
                input_info.dtype,
                dst_df,
                output_shape,
                target_net_def->op_size() - 1));
        // record transformed tensors
        transformed_set->insert(transformed_name);
      }
      // update original op_def's input
      op_def->set_input(i, transformed_name);
    }
  }
  return MaceStatus::MACE_SUCCESS;
}
// Ensures every input of `op_def` arrives in the memory type (and data type)
// the op expects, inserting BufferTransform ops into `target_net_def` where
// they differ. The expected types come from the registry's memory-type
// setter invoked on `context`.
//
// Side effects:
//  - May append BufferTransform ops to `target_net_def` and rewrite
//    `op_def`'s input names to the transformed tensors.
//  - Records new tensors in `output_map` / `transformed_set` so the same
//    transform is reused by later consumers.
//  - Writes the op's resolved output memory type to `op_output_mem_types`
//    and tags it on `op_def` as the "output_mem_type" argument.
MaceStatus NetDefAdapter::AdaptMemoryType(
    OpConditionContext *context,
    OperatorDef *op_def,
    NetDefAdapter::TensorInfoMap *output_map,
    std::unordered_set<std::string> *transformed_set,
    MemoryType *op_output_mem_types,
    NetDef *target_net_def) {
  VLOG(3) << "Adapt memory type for op " << op_def->name();
  // Get expected output memory type
  // (only support one kind of memory type for multiple outputs)
  op_registry_->GetInOutMemoryTypes(op_def->type(), context);
#ifdef MACE_ENABLE_OPENCL
  // if op is memory-unused op, no transformation
  if (TransformRequiredOp(op_def->type())) {
    int input_size = op_def->input_size();
    for (int i = 0; i < input_size; ++i) {
      // Inputs absent from output_map must be weights already resident in
      // the workspace; they need no runtime transform.
      if (output_map->count(op_def->input(i)) == 0) {
        MACE_CHECK(ws_->GetTensor(op_def->input(i)) != nullptr
                       && ws_->GetTensor(op_def->input(i))->is_weight(),
                   "Tensor ", op_def->input(i), " of ",
                   op_def->name(), " not allocated");
        continue;
      }
      auto &input_info = output_map->at(op_def->input(i));
      // check whether to do transform
      MemoryType src_mem_type = input_info.mem_type;
      MemoryType dst_mem_type = context->GetInputMemType(i);
      auto wanted_input_dtype = context->GetInputDataType(i);
      // Transform when the memory types differ, or when the data types
      // differ and at least one side is not a plain CPU buffer (CPU-to-CPU
      // dtype casts are handled elsewhere).
      if (src_mem_type != dst_mem_type ||
          (input_info.dtype != wanted_input_dtype &&
              (src_mem_type != MemoryType::CPU_BUFFER
                  || dst_mem_type != MemoryType::CPU_BUFFER))) {
        auto transformed_name = TransformedName(op_def->input(i),
                                                "mem_type",
                                                dst_mem_type);
        // check whether the tensor has been transformed
        if (transformed_set->count(transformed_name) == 0) {
          VLOG(1) << "Add Transform operation " << op_def->name()
                  << " to transform tensor "
                  << op_def->input(i) << "', from memory type "
                  << input_info.mem_type << " to "
                  << dst_mem_type;
          OperatorDef *transformed_op_def = target_net_def->add_op();
          OpenCLUtil::BuildTransformOpDef(
              op_def->input(i),
              input_info.shape,
              transformed_name,
              wanted_input_dtype,
              context->GetInputOpenCLBufferType(i),
              dst_mem_type,
              input_info.data_format,
              transformed_op_def);
          // set data format arg
          SetProtoArg<int>(transformed_op_def,
                           "data_format",
                           static_cast<int>(input_info.data_format));
          // set output memory type argument
          SetProtoArg<int>(transformed_op_def,
                           OutputMemoryTypeTagName(),
                           dst_mem_type);
          // update output information map
          output_map->emplace(
              transformed_name,
              InternalOutputInfo(
                  dst_mem_type,
                  context->GetInputDataType(i),
                  input_info.data_format,
                  input_info.shape,
                  target_net_def->op_size() - 1));
          // record transformed tensors
          transformed_set->insert(transformed_name);
        }
        // update original op_def's input
        op_def->set_input(i, transformed_name);
      }
    }
  }
#else
  MACE_UNUSED(output_map);
  MACE_UNUSED(transformed_set);
  MACE_UNUSED(target_net_def);
#endif  // MACE_ENABLE_OPENCL
  *op_output_mem_types = context->output_mem_type();
  SetProtoArg<int>(op_def,
                   OutputMemoryTypeTagName(),
                   context->output_mem_type());
  return MaceStatus::MACE_SUCCESS;
}
// Renders every op of the (adapted) net as a human-readable record:
// name, type, device, data type, data format, memory type, inputs,
// outputs and output shapes. Intended for debug logging only.
std::string NetDefAdapter::DebugString(const NetDef *net_def) {
  auto DeviceToStr = [](DeviceType device) -> std::string {
    switch (device) {
      case DeviceType::CPU: return "CPU";
      case DeviceType::GPU: return "GPU";
      default: return "Unknown";
    }
  };
  auto MemTypeToStr = [](MemoryType mem) -> std::string {
    switch (mem) {
      case MemoryType::CPU_BUFFER: return "CPU_BUFFER";
      case MemoryType::GPU_BUFFER: return "GPU_BUFFER";
      case MemoryType::GPU_IMAGE: return "GPU_IMAGE";
      default: return "Unknown";
    }
  };
  auto DataFormatToStr = [](DataFormat df) -> std::string {
    switch (df) {
      case DataFormat::NHWC: return "NHWC";
      case DataFormat::NCHW: return "NCHW";
      case DataFormat::NONE: return "NONE";
      case DataFormat::AUTO: return "AUTO";
      case DataFormat::OIHW: return "OIHW";
      default: return "Unknown";
    }
  };
  std::stringstream ss;
  for (const auto &op : net_def->op()) {
    // Arguments fall back to the same defaults the runtime assumes:
    // T=float, memory=CPU_BUFFER, data_format=NONE.
    const std::string device_str =
        DeviceToStr(static_cast<DeviceType>(op.device_type()));
    const std::string dtype_str = DataTypeToString(static_cast<DataType>(
        ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
            op, "T", static_cast<int>(DT_FLOAT))));
    const std::string mem_str = MemTypeToStr(static_cast<MemoryType>(
        ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
            op, OutputMemoryTypeTagName(),
            static_cast<int>(MemoryType::CPU_BUFFER))));
    const std::string df_str = DataFormatToStr(static_cast<DataFormat>(
        ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
            op, "data_format", static_cast<int>(DataFormat::NONE))));
    ss << "\n";
    ss << "{" << "\n";
    ss << "  name: " << op.name() << "\n";
    ss << "  type: " << op.type() << "\n";
    ss << "  device: " << device_str << "\n";
    ss << "  data type: " << dtype_str << "\n";
    ss << "  data format: " << df_str << "\n";
    ss << "  memory type: " << mem_str << "\n";
    ss << "  inputs: [";
    for (const auto &input : op.input()) {
      ss << input << ", ";
    }
    ss << "]" << "\n";
    ss << "  outputs: [";
    for (const auto &output : op.output()) {
      ss << output << ", ";
    }
    ss << "]" << "\n";
    ss << "  output shapes: [";
    for (const auto &output_shape : op.output_shape()) {
      ss << "(";
      for (auto dim : output_shape.dims()) {
        ss << dim << ",";
      }
      ss << ") ";
    }
    ss << "]" << "\n";
    ss << "}";
  }
  return ss.str();
}
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_NET_DEF_ADAPTER_H_
#define MACE_CORE_NET_DEF_ADAPTER_H_
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "mace/core/types.h"
#include "mace/proto/mace.pb.h"
#include "mace/port/port.h"
#include "mace/core/operator.h"
#include "mace/core/net_optimizer.h"
namespace mace {
class OpRegistryBase;
class Workspace;
class Device;
/// Conventions:
/// 1. DataFormat::AUTO stands for formatted (NHWC or NCHW)
/// 2. if Op with DataFormat::AUTO, the arguments of this op
/// is formatted to NHWC
// Rewrites a model's NetDef into a runnable form for a concrete target
// device: assigns devices, fixes data types/formats, and inserts the
// transpose/transform ops needed between incompatible neighbors.
class NetDefAdapter {
 public:
  // `op_registry` and `ws` are borrowed; both must outlive this adapter.
  NetDefAdapter(const OpRegistryBase *op_registry,
                const Workspace *ws);
  // Adapt original net_def to a better net.
  // 1. Adapt device: choose best device for every op in the net.
  // 2. Adapt data type: Add data type related transform ops
  //    for mixing precision.
  // 3. Adapt data format: confirm data format of every op
  //    and add transpose if necessary.
  // 4. Adapt memory type: Add BufferTransform if necessary
  //    for transforming memory type between ops.
  MaceStatus AdaptNetDef(
      const NetDef *net_def,
      Device *target_device,
      NetDef *target_net_def);

 public:
  // Non-copyable, non-movable: holds borrowed registry/workspace pointers.
  NetDefAdapter(const NetDefAdapter&) = delete;
  NetDefAdapter(const NetDefAdapter&&) = delete;
  NetDefAdapter &operator=(const NetDefAdapter &) = delete;
  NetDefAdapter &operator=(const NetDefAdapter &&) = delete;

 private:
  // Bookkeeping for one tensor produced while adapting the net; lets later
  // consumers know where (memory type/format) a tensor already lives.
  struct InternalOutputInfo {
    InternalOutputInfo(const MemoryType mem_type,
                       const DataType dtype,
                       const DataFormat data_format,
                       const std::vector<index_t> &shape,
                       int op_idx)
        : mem_type(mem_type), dtype(dtype), data_format(data_format),
          shape(shape), op_idx(op_idx) {}
    MemoryType mem_type;
    DataType dtype;
    DataFormat data_format;
    std::vector<index_t> shape;  // tensor shape
    int op_idx;  // operation which generate the tensor
  };

  // Maps tensor name -> info about the op output that produces it.
  typedef std::unordered_map<std::string, InternalOutputInfo> TensorInfoMap;

 private:
  // One pass per concern; each may append transform ops to target_net_def
  // and rewrite `op`'s inputs accordingly.
  MaceStatus AdaptDevice(OpConditionContext *context,
                         Device *target_device,
                         Device *cpu_device,
                         const TensorInfoMap &output_map,
                         const NetDef *net_def,
                         OperatorDef *op);
  MaceStatus AdaptDataType(OpConditionContext *context,
                           OperatorDef *op);
  MaceStatus AdaptDataFormat(
      OpConditionContext *context,
      OperatorDef *op,
      bool is_quantized_model,
      TensorInfoMap *output_map,
      std::unordered_set<std::string> *transformed_set,
      DataFormat *op_output_df,
      NetDef *target_net_def);
  MaceStatus AdaptMemoryType(
      OpConditionContext *context,
      OperatorDef *op_def,
      TensorInfoMap *output_map,
      std::unordered_set<std::string> *transformed_set,
      MemoryType *op_output_mem_types,
      NetDef *target_net_def);

  // Human-readable dump of the adapted net, for debug logging.
  std::string DebugString(const NetDef *net_def);

 private:
  const OpRegistryBase *op_registry_;  // not owned
  const Workspace *ws_;  // not owned
  NetOptimizer net_optimizer_;
};
} // namespace mace
#endif // MACE_CORE_NET_DEF_ADAPTER_H_
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/net_optimizer.h"
#include <string>
namespace mace {
// Chooses the device an op should run on, balancing the requested target
// device against where the op's producers run. CPU is the universal
// fallback; compute-intensive ops always stay on the target device.
DeviceType NetOptimizer::SelectBestDevice(
    const OperatorDef *op_def,
    DeviceType target_device_type,
    const std::set<DeviceType> &available_devices,
    const std::vector<DeviceType> &inputs_op_devices) {
  // Ops heavy enough to justify running on the target device even when
  // their producers ended up elsewhere.
  static const std::set<std::string> kComputeIntensiveOps = {
      "Conv2D", "DepthwiseConv2d", "Deconv2D", "DepthwiseDeconv2d",
      "FullyConnected"
  };
  // Fall back to CPU unless the op actually supports the target device.
  DeviceType chosen = DeviceType::CPU;
  if (available_devices.count(target_device_type) == 1) {
    chosen = target_device_type;
  }
  if (chosen == DeviceType::CPU) {
    return chosen;
  }
  // Keep compute-intensive ops on the target device regardless of inputs.
  if (kComputeIntensiveOps.count(op_def->type()) == 1) {
    return chosen;
  }
  // Greedy: follow the producers' placement. NOTE(review): as written, the
  // last input device that differs from the current choice wins — confirm
  // this matches the intended "fall back with the producers" strategy.
  for (DeviceType producer_device : inputs_op_devices) {
    if (producer_device != chosen) {
      chosen = producer_device;
    }
  }
  return chosen;
}
} // namespace mace
// Copyright 2019 The MACE Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_CORE_NET_OPTIMIZER_H_
#define MACE_CORE_NET_OPTIMIZER_H_
#include <set>
#include <vector>
#include "mace/port/port.h"
#include "mace/proto/mace.pb.h"
namespace mace {
/// Any optimization for Net could be put in here in the future.
// Placement heuristics for ops in a net; stateless, safe to share.
class NetOptimizer {
 public:
  /// Select best device for the op to support mixing usage of CPU and GPU.
  /// Greedy strategy: one way to the end. If the op falls back to CPU, then
  /// the follow-up ops will run on CPU too, until it meets
  /// some compute-intensive ops (Convolution) to
  /// reduce the memory copy between CPU and GPU.
  /// Simple but effective.
  ///
  /// \param op_def the op
  /// \param target_device target device to run on
  /// \param available_devices available devices of the op
  /// \param inputs_op_devices devices of father ops run on
  /// \return Best device for the op_def
  DeviceType SelectBestDevice(const OperatorDef *op_def,
                              DeviceType target_device,
                              const std::set<DeviceType> &available_devices,
                              const std::vector<DeviceType> &inputs_op_devices);
};
} // namespace mace
#endif // MACE_CORE_NET_OPTIMIZER_H_
...@@ -20,36 +20,23 @@ ...@@ -20,36 +20,23 @@
#include "mace/core/operator.h" #include "mace/core/operator.h"
namespace mace { namespace mace {
OpConditionContext::OpConditionContext(
OpConstructContext::OpConstructContext(Workspace *ws) const Workspace *ws,
: operator_def_(nullptr), OpConditionContext::TensorShapeMap *info)
ws_(ws),
device_(nullptr),
tensor_shape_info_(nullptr) {}
OpConstructContext::OpConstructContext(
mace::Workspace *ws,
mace::OpConstructContext::TensorShapeMap *info)
: operator_def_(nullptr), : operator_def_(nullptr),
ws_(ws), ws_(ws),
device_(nullptr), device_(nullptr),
tensor_shape_info_(info) {} tensor_shape_info_(info) {}
void OpConstructContext::set_operator_def( void OpConditionContext::set_operator_def(
std::shared_ptr<mace::OperatorDef> operator_def) { const OperatorDef *operator_def) {
operator_def_ = operator_def; operator_def_ = operator_def;
input_data_types_.clear(); input_data_types_.clear();
} }
void OpConstructContext::set_output_mem_type(mace::MemoryType type) { void OpConditionContext::SetInputInfo(size_t idx,
MACE_CHECK(operator_def_ != nullptr); MemoryType mem_type,
output_mem_type_ = type; DataType dt) {
input_mem_types_.clear();
}
void OpConstructContext::SetInputInfo(size_t idx,
mace::MemoryType mem_type,
mace::DataType dt) {
if (input_mem_types_.empty()) { if (input_mem_types_.empty()) {
// the default inputs' memory types are same as output memory type. // the default inputs' memory types are same as output memory type.
input_mem_types_.resize(operator_def_->input_size(), output_mem_type_); input_mem_types_.resize(operator_def_->input_size(), output_mem_type_);
...@@ -66,7 +53,13 @@ void OpConstructContext::SetInputInfo(size_t idx, ...@@ -66,7 +53,13 @@ void OpConstructContext::SetInputInfo(size_t idx,
input_data_types_[idx] = dt; input_data_types_[idx] = dt;
} }
MemoryType OpConstructContext::GetInputMemType(size_t idx) const { void OpConditionContext::set_output_mem_type(MemoryType type) {
MACE_CHECK(operator_def_ != nullptr);
output_mem_type_ = type;
input_mem_types_.clear();
}
MemoryType OpConditionContext::GetInputMemType(size_t idx) const {
if (input_mem_types_.empty()) { if (input_mem_types_.empty()) {
return output_mem_type_; return output_mem_type_;
} }
...@@ -75,7 +68,7 @@ MemoryType OpConstructContext::GetInputMemType(size_t idx) const { ...@@ -75,7 +68,7 @@ MemoryType OpConstructContext::GetInputMemType(size_t idx) const {
return input_mem_types_[idx]; return input_mem_types_[idx];
} }
DataType OpConstructContext::GetInputDataType(size_t idx) const { DataType OpConditionContext::GetInputDataType(size_t idx) const {
if (input_data_types_.empty()) { if (input_data_types_.empty()) {
// the default inputs' data types are same as operation's data type. // the default inputs' data types are same as operation's data type.
return static_cast<DataType>( return static_cast<DataType>(
...@@ -87,17 +80,17 @@ DataType OpConstructContext::GetInputDataType(size_t idx) const { ...@@ -87,17 +80,17 @@ DataType OpConstructContext::GetInputDataType(size_t idx) const {
} }
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
void OpConstructContext::SetInputOpenCLBufferType( void OpConditionContext::SetInputOpenCLBufferType(
size_t idx, OpenCLBufferType buffer_type) { size_t idx, OpenCLBufferType buffer_type) {
if (input_opencl_buffer_types_.empty()) { if (input_opencl_buffer_types_.empty()) {
// the default inputs' memory types are same as output memory type. // the default inputs' memory types are same as output memory type.
input_opencl_buffer_types_.resize(operator_def_->input_size(), input_opencl_buffer_types_.resize(operator_def_->input_size(),
OpenCLBufferType::IN_OUT_CHANNEL); OpenCLBufferType::IN_OUT_CHANNEL);
} }
MACE_CHECK(idx < input_opencl_buffer_types_.size()); MACE_CHECK(idx < input_opencl_buffer_types_.size());
input_opencl_buffer_types_[idx] = buffer_type; input_opencl_buffer_types_[idx] = buffer_type;
} }
OpenCLBufferType OpConstructContext::GetInputOpenCLBufferType( OpenCLBufferType OpConditionContext::GetInputOpenCLBufferType(
size_t idx) const { size_t idx) const {
if (input_opencl_buffer_types_.empty()) { if (input_opencl_buffer_types_.empty()) {
return OpenCLBufferType::IN_OUT_CHANNEL; return OpenCLBufferType::IN_OUT_CHANNEL;
...@@ -107,6 +100,16 @@ OpenCLBufferType OpConstructContext::GetInputOpenCLBufferType( ...@@ -107,6 +100,16 @@ OpenCLBufferType OpConstructContext::GetInputOpenCLBufferType(
} }
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
OpConstructContext::OpConstructContext(Workspace *ws)
: operator_def_(nullptr),
ws_(ws),
device_(nullptr) {}
void OpConstructContext::set_operator_def(
std::shared_ptr<OperatorDef> operator_def) {
operator_def_ = operator_def;
}
OpInitContext::OpInitContext(Workspace *ws, Device *device) OpInitContext::OpInitContext(Workspace *ws, Device *device)
: ws_(ws), device_(device) {} : ws_(ws), device_(device) {}
...@@ -202,19 +205,40 @@ const std::string OpKeyBuilder::Build() { ...@@ -202,19 +205,40 @@ const std::string OpKeyBuilder::Build() {
} // namespace } // namespace
OpRegistrationInfo::OpRegistrationInfo() { OpRegistrationInfo::OpRegistrationInfo() {
device_placer = [this](OpConstructContext *context) -> std::set<DeviceType> { // default device type placer
auto op = context->operator_def(); device_placer = [this](OpConditionContext *context) -> std::set<DeviceType> {
// The GPU ops only support 4D In/Out tensor by default MACE_UNUSED(context);
if (this->devices.count(DeviceType::CPU) == 1 &&
op->output_shape_size() == op->output_size() &&
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return this->devices; return this->devices;
}; };
// default input and output memory type setter
memory_type_setter = [](OpConditionContext *context) -> void {
if (context->device()->device_type() == DeviceType::GPU) {
#ifdef MACE_ENABLE_OPENCL
if (context->device()->gpu_runtime()->UseImageMemory()) {
context->set_output_mem_type(MemoryType::GPU_IMAGE);
} else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
}
#endif // MACE_ENABLE_OPENCL
} else {
context->set_output_mem_type(MemoryType::CPU_BUFFER);
}
};
data_format_selector = [](OpConditionContext *context)
-> std::vector<DataFormat> {
DataFormat op_data_format =
static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return std::vector<DataFormat>(context->operator_def()->input_size(),
op_data_format);
};
} }
void OpRegistrationInfo::AddDevice(mace::DeviceType device) { void OpRegistrationInfo::AddDevice(DeviceType device) {
devices.insert(device); devices.insert(device);
} }
...@@ -226,9 +250,9 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) { ...@@ -226,9 +250,9 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) {
MaceStatus OpRegistryBase::Register( MaceStatus OpRegistryBase::Register(
const std::string &op_type, const std::string &op_type,
const mace::DeviceType device_type, const DeviceType device_type,
const mace::DataType dt, const DataType dt,
mace::OpRegistrationInfo::OpCreator creator) { OpRegistrationInfo::OpCreator creator) {
if (registry_.count(op_type) == 0) { if (registry_.count(op_type) == 0) {
registry_[op_type] = std::unique_ptr<OpRegistrationInfo>( registry_[op_type] = std::unique_ptr<OpRegistrationInfo>(
new OpRegistrationInfo); new OpRegistrationInfo);
...@@ -255,13 +279,29 @@ MaceStatus OpRegistryBase::Register( ...@@ -255,13 +279,29 @@ MaceStatus OpRegistryBase::Register(
} }
const std::set<DeviceType> OpRegistryBase::AvailableDevices( const std::set<DeviceType> OpRegistryBase::AvailableDevices(
const std::string &op_type, OpConstructContext *context) const { const std::string &op_type, OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0, MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered."); op_type, " operation is not registered.");
return registry_.at(op_type)->device_placer(context); return registry_.at(op_type)->device_placer(context);
} }
void OpRegistryBase::GetInOutMemoryTypes(
const std::string &op_type,
OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
return registry_.at(op_type)->memory_type_setter(context);
}
const std::vector<DataFormat> OpRegistryBase::InputsDataFormat(
const std::string &op_type,
OpConditionContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
return registry_.at(op_type)->data_format_selector(context);
}
std::unique_ptr<Operation> OpRegistryBase::CreateOperation( std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
OpConstructContext *context, OpConstructContext *context,
DeviceType device_type) const { DeviceType device_type) const {
...@@ -269,15 +309,6 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation( ...@@ -269,15 +309,6 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
DataType dtype = static_cast<DataType>( DataType dtype = static_cast<DataType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>( ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def, "T", static_cast<int>(DT_FLOAT))); *operator_def, "T", static_cast<int>(DT_FLOAT)));
if (device_type == DeviceType::CPU && dtype == DT_HALF) {
int arg_size = operator_def->arg_size();
for (int i = 0; i < arg_size; ++i) {
if (operator_def->arg(i).name() == "T") {
operator_def->mutable_arg(i)->set_i(DT_FLOAT);
}
}
dtype = DT_FLOAT;
}
VLOG(1) << "Creating operator " << operator_def->name() << "(" VLOG(1) << "Creating operator " << operator_def->name() << "("
<< operator_def->type() << "<" << dtype << ">" << ") on " << operator_def->type() << "<" << dtype << ">" << ") on "
<< device_type; << device_type;
...@@ -308,9 +339,30 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc( ...@@ -308,9 +339,30 @@ OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
return *this; return *this;
} }
OpConditionBuilder& OpConditionBuilder::SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter) {
memory_type_setter_ = setter;
return *this;
}
OpConditionBuilder& OpConditionBuilder::SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector) {
data_format_selector_ = selector;
return *this;
}
void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const { void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
if (info != nullptr && placer_) { if (info != nullptr) {
info->device_placer = placer_; if (placer_) {
info->device_placer = placer_;
}
if (memory_type_setter_) {
info->memory_type_setter = memory_type_setter_;
}
if (data_format_selector_) {
info->data_format_selector = data_format_selector_;
}
} }
} }
......
...@@ -32,22 +32,20 @@ ...@@ -32,22 +32,20 @@
namespace mace { namespace mace {
// memory_optimizer, device // OpConditionContext has all information used for choosing proper Op
class OpConstructContext { class OpConditionContext {
public:
typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap; typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
OpConditionContext(const Workspace *ws, TensorShapeMap *info);
~OpConditionContext() = default;
public: void set_operator_def(const OperatorDef* operator_def);
explicit OpConstructContext(Workspace *ws);
OpConstructContext(Workspace *ws, TensorShapeMap *info);
~OpConstructContext() = default;
void set_operator_def(std::shared_ptr<OperatorDef> operator_def); inline const OperatorDef *operator_def() const {
inline std::shared_ptr<OperatorDef> operator_def() const {
return operator_def_; return operator_def_;
} }
inline Workspace *workspace() const { inline const Workspace *workspace() const {
return ws_; return ws_;
} }
...@@ -81,8 +79,8 @@ class OpConstructContext { ...@@ -81,8 +79,8 @@ class OpConstructContext {
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
private: private:
std::shared_ptr<OperatorDef> operator_def_; const OperatorDef *operator_def_;
Workspace *ws_; const Workspace *ws_;
Device *device_; Device *device_;
TensorShapeMap *tensor_shape_info_; TensorShapeMap *tensor_shape_info_;
// used for memory transform // used for memory transform
...@@ -94,6 +92,46 @@ class OpConstructContext { ...@@ -94,6 +92,46 @@ class OpConstructContext {
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
}; };
// memory_optimizer, device
class OpConstructContext {
typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
public:
explicit OpConstructContext(Workspace *ws);
~OpConstructContext() = default;
void set_operator_def(std::shared_ptr<OperatorDef> operator_def);
inline std::shared_ptr<OperatorDef> operator_def() const {
return operator_def_;
}
inline Workspace *workspace() const {
return ws_;
}
inline void set_device(Device* device) {
device_ = device;
}
inline Device *device() const {
return device_;
}
#ifdef MACE_ENABLE_OPENCL
inline MemoryType GetOpMemoryType() const {
return static_cast<MemoryType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*operator_def_, OutputMemoryTypeTagName(),
static_cast<int>(MemoryType::CPU_BUFFER)));
}
#endif // MACE_ENABLE_OPENCL
private:
std::shared_ptr<OperatorDef> operator_def_;
Workspace *ws_;
Device *device_;
};
// memory_optimizer, device // memory_optimizer, device
class OpInitContext { class OpInitContext {
public: public:
...@@ -207,8 +245,11 @@ struct OpRegistrationInfo { ...@@ -207,8 +245,11 @@ struct OpRegistrationInfo {
public: public:
typedef std::function<std::unique_ptr<Operation>(OpConstructContext *)> typedef std::function<std::unique_ptr<Operation>(OpConstructContext *)>
OpCreator; OpCreator;
typedef std::function<std::set<DeviceType>(OpConstructContext *)> typedef std::function<std::set<DeviceType>(OpConditionContext *)>
DevicePlacer; DevicePlacer;
typedef std::function<void(OpConditionContext *)> MemoryTypeSetter;
typedef std::function<std::vector<DataFormat>(OpConditionContext *)>
DataFormatSelector;
OpRegistrationInfo(); OpRegistrationInfo();
...@@ -219,6 +260,8 @@ struct OpRegistrationInfo { ...@@ -219,6 +260,8 @@ struct OpRegistrationInfo {
std::set<DeviceType> devices; std::set<DeviceType> devices;
std::unordered_map<std::string, OpCreator> creators; std::unordered_map<std::string, OpCreator> creators;
DevicePlacer device_placer; DevicePlacer device_placer;
MemoryTypeSetter memory_type_setter;
DataFormatSelector data_format_selector;
}; };
class OpConditionBuilder { class OpConditionBuilder {
...@@ -230,11 +273,21 @@ class OpConditionBuilder { ...@@ -230,11 +273,21 @@ class OpConditionBuilder {
OpConditionBuilder &SetDevicePlacerFunc( OpConditionBuilder &SetDevicePlacerFunc(
OpRegistrationInfo::DevicePlacer placer); OpRegistrationInfo::DevicePlacer placer);
// If you set input memory type for specified Op,
// you must call OpConditionContext::set_output_mem_type
OpConditionBuilder &SetInputMemoryTypeSetter(
OpRegistrationInfo::MemoryTypeSetter setter);
OpConditionBuilder &SetInputsDataFormatSelector(
OpRegistrationInfo::DataFormatSelector selector);
void Finalize(OpRegistrationInfo *info) const; void Finalize(OpRegistrationInfo *info) const;
private: private:
std::string type_; std::string type_;
OpRegistrationInfo::DevicePlacer placer_; OpRegistrationInfo::DevicePlacer placer_;
OpRegistrationInfo::MemoryTypeSetter memory_type_setter_;
OpRegistrationInfo::DataFormatSelector data_format_selector_;
}; };
...@@ -250,7 +303,13 @@ class OpRegistryBase { ...@@ -250,7 +303,13 @@ class OpRegistryBase {
MaceStatus Register(const OpConditionBuilder &builder); MaceStatus Register(const OpConditionBuilder &builder);
const std::set<DeviceType> AvailableDevices( const std::set<DeviceType> AvailableDevices(
const std::string &op_type, OpConstructContext *context) const; const std::string &op_type, OpConditionContext *context) const;
void GetInOutMemoryTypes(
const std::string &op_type, OpConditionContext *context) const;
const std::vector<DataFormat> InputsDataFormat(
const std::string &op_type, OpConditionContext *context) const;
std::unique_ptr<Operation> CreateOperation( std::unique_ptr<Operation> CreateOperation(
OpConstructContext *context, OpConstructContext *context,
......
...@@ -147,38 +147,38 @@ void OpenCLUtil::CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */ ...@@ -147,38 +147,38 @@ void OpenCLUtil::CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
} }
} }
std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef( void OpenCLUtil::BuildTransformOpDef(
const std::string &input_name, const std::string &input_name,
const std::vector<mace::index_t> &input_shape, const std::vector<mace::index_t> &input_shape,
const std::string &output_name, const std::string &output_name,
const mace::DataType dt, const mace::DataType dt,
const OpenCLBufferType buffer_type, const OpenCLBufferType buffer_type,
const mace::MemoryType mem_type, const mace::MemoryType mem_type,
bool has_data_format) { DataFormat data_format,
std::unique_ptr<OperatorDef> op(new OperatorDef); OperatorDef *op_def) {
std::string op_name = "mace_node_" + output_name; std::string op_name = "mace_node_" + output_name;
op->set_name(op_name); op_def->set_name(op_name);
op->set_type("BufferTransform"); op_def->set_type("BufferTransform");
op->add_input(input_name); op_def->add_input(input_name);
op->add_output(output_name); op_def->add_output(output_name);
Argument *arg = op->add_arg(); op_def->set_device_type(DeviceType::GPU);
Argument *arg = op_def->add_arg();
arg->set_name("buffer_type"); arg->set_name("buffer_type");
arg->set_i(static_cast<int32_t>(buffer_type)); arg->set_i(static_cast<int32_t>(buffer_type));
arg = op->add_arg(); arg = op_def->add_arg();
arg->set_name("mem_type"); arg->set_name("mem_type");
arg->set_i(static_cast<int32_t>(mem_type)); arg->set_i(static_cast<int32_t>(mem_type));
arg = op->add_arg(); arg = op_def->add_arg();
arg->set_name("T"); arg->set_name("T");
arg->set_i(static_cast<int32_t>(dt)); arg->set_i(static_cast<int32_t>(dt));
arg = op->add_arg(); arg = op_def->add_arg();
arg->set_name("has_data_format"); arg->set_name("data_format");
arg->set_i(has_data_format); arg->set_i(static_cast<int>(data_format));
if (!input_shape.empty()) { if (!input_shape.empty()) {
OutputShape *shape = op->add_output_shape(); OutputShape *shape = op_def->add_output_shape();
for (auto value : input_shape) { for (auto value : input_shape) {
shape->add_dims(value); shape->add_dims(value);
} }
} }
return std::move(op);
} }
} // namespace mace } // namespace mace
...@@ -43,14 +43,15 @@ class OpenCLUtil { ...@@ -43,14 +43,15 @@ class OpenCLUtil {
std::vector<size_t> *image_shape, std::vector<size_t> *image_shape,
const int wino_blk_size = 2); const int wino_blk_size = 2);
static std::shared_ptr<OperatorDef> CreateTransformOpDef( static void BuildTransformOpDef(
const std::string &input_name, const std::string &input_name,
const std::vector<mace::index_t> &input_shape, const std::vector<mace::index_t> &input_shape,
const std::string &output_name, const std::string &output_name,
const mace::DataType dt, const mace::DataType dt,
const OpenCLBufferType buffer_type, const OpenCLBufferType buffer_type,
const MemoryType mem_type, const MemoryType mem_type,
bool has_data_format); DataFormat data_format,
OperatorDef *op_def);
}; };
} // namespace mace } // namespace mace
......
...@@ -263,13 +263,13 @@ MaceStatus Workspace::PreallocateOutputTensor( ...@@ -263,13 +263,13 @@ MaceStatus Workspace::PreallocateOutputTensor(
} }
} }
VLOG(1) << "Preallocate buffer to tensors"; VLOG(1) << "Preallocate buffer to tensors";
bool is_quantize_model = IsQuantizedModel(net_def);
for (auto &tensor_mem : mem_optimizer->tensor_mem_map()) { for (auto &tensor_mem : mem_optimizer->tensor_mem_map()) {
std::unique_ptr<Tensor> tensor std::unique_ptr<Tensor> tensor
(new Tensor(preallocated_allocator_.GetBuffer(tensor_mem.second.mem_id), (new Tensor(preallocated_allocator_.GetBuffer(tensor_mem.second.mem_id),
tensor_mem.second.data_type, tensor_mem.second.data_type,
false, tensor_mem.first)); false, tensor_mem.first));
if (tensor_mem.second.has_data_format) { tensor->set_data_format(tensor_mem.second.data_format);
if (tensor_mem.second.data_format != DataFormat::NONE) {
if (mem_blocks[tensor_mem.second.mem_id].mem_type() if (mem_blocks[tensor_mem.second.mem_id].mem_type()
== MemoryType::GPU_IMAGE) { == MemoryType::GPU_IMAGE) {
VLOG(1) << "Tensor: " << tensor_mem.first VLOG(1) << "Tensor: " << tensor_mem.first
...@@ -279,22 +279,12 @@ MaceStatus Workspace::PreallocateOutputTensor( ...@@ -279,22 +279,12 @@ MaceStatus Workspace::PreallocateOutputTensor(
<< tensor->UnderlyingBuffer()->shape()[0] << tensor->UnderlyingBuffer()->shape()[0]
<< ", " << ", "
<< tensor->UnderlyingBuffer()->shape()[1]; << tensor->UnderlyingBuffer()->shape()[1];
tensor->set_data_format(DataFormat::NHWC);
} else { } else {
VLOG(1) << "Tensor: " << tensor_mem.first VLOG(1) << "Tensor: " << tensor_mem.first
<< " Mem: " << tensor_mem.second.mem_id << " Mem: " << tensor_mem.second.mem_id
<< " Data type: " << tensor->dtype() << " Data type: " << tensor->dtype()
<< ", Buffer size: " << tensor->UnderlyingBuffer()->size(); << ", Buffer size: " << tensor->UnderlyingBuffer()->size();
if (mem_blocks[tensor_mem.second.mem_id].mem_type()
== MemoryType::GPU_BUFFER ||
is_quantize_model) {
tensor->set_data_format(DataFormat::NHWC);
} else {
tensor->set_data_format(DataFormat::NCHW);
}
} }
} else {
tensor->set_data_format(DataFormat::DF_NONE);
} }
tensor_map_[tensor_mem.first] = std::move(tensor); tensor_map_[tensor_mem.first] = std::move(tensor);
} }
......
...@@ -94,7 +94,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) { ...@@ -94,7 +94,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
} else if (data_format_str == "OIHW") { } else if (data_format_str == "OIHW") {
return DataFormat::OIHW; return DataFormat::OIHW;
} else { } else {
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
......
...@@ -143,7 +143,7 @@ void BMNet::SetUp() { ...@@ -143,7 +143,7 @@ void BMNet::SetUp() {
// Add input and output information // Add input and output information
for (size_t i = 0; i < input_names_.size(); ++i) { for (size_t i = 0; i < input_names_.size(); ++i) {
InputOutputInfo *info = net_.add_input_info(); InputOutputInfo *info = net_.add_input_info();
info->set_data_format(DataFormat::NHWC); info->set_data_format(static_cast<int>(DataFormat::NHWC));
info->set_name(input_names_[i]); info->set_name(input_names_[i]);
for (auto d : input_shapes_[i]) { for (auto d : input_shapes_[i]) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
...@@ -244,7 +244,7 @@ void BMNet::AddConv(const std::string &conv_type, ...@@ -244,7 +244,7 @@ void BMNet::AddConv(const std::string &conv_type,
op_def->add_output(output_name); op_def->add_output(output_name);
AddIntsArg(op_def, "strides", strides); AddIntsArg(op_def, "strides", strides);
AddIntArg(op_def, "padding", padding_type); AddIntArg(op_def, "padding", padding_type);
AddIntArg(op_def, "has_data_format", 1); AddIntArg(op_def, "data_format", static_cast<int>(DataFormat::AUTO));
AddIntArg(op_def, "T", DT_HALF); AddIntArg(op_def, "T", DT_HALF);
if (has_relu6) { if (has_relu6) {
AddStringArg(op_def, "activation", "RELUX"); AddStringArg(op_def, "activation", "RELUX");
...@@ -271,7 +271,7 @@ void BMNet::AddEltwise(const std::string &op_name, ...@@ -271,7 +271,7 @@ void BMNet::AddEltwise(const std::string &op_name,
op_def->add_output(output); op_def->add_output(output);
AddIntArg(op_def, "type", type); AddIntArg(op_def, "type", type);
AddIntArg(op_def, "T", DT_HALF); AddIntArg(op_def, "T", DT_HALF);
AddIntArg(op_def, "has_data_format", 1); AddIntArg(op_def, "data_format", static_cast<int>(DataFormat::AUTO));
OutputShape *shape = op_def->add_output_shape(); OutputShape *shape = op_def->add_output_shape();
for (auto dim : output_shape) { for (auto dim : output_shape) {
shape->add_dims(dim); shape->add_dims(dim);
......
...@@ -27,6 +27,7 @@ ...@@ -27,6 +27,7 @@
#include "mace/public/mace.h" #include "mace/public/mace.h"
#include "mace/port/env.h" #include "mace/port/env.h"
#include "mace/port/file_system.h" #include "mace/port/file_system.h"
#include "mace/core/net_def_adapter.h"
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/gpu_device.h" #include "mace/core/runtime/opencl/gpu_device.h"
...@@ -282,9 +283,9 @@ MaceTensor::MaceTensor(const std::vector<int64_t> &shape, ...@@ -282,9 +283,9 @@ MaceTensor::MaceTensor(const std::vector<int64_t> &shape,
std::shared_ptr<void> data, std::shared_ptr<void> data,
const DataFormat format) { const DataFormat format) {
MACE_CHECK_NOTNULL(data.get()); MACE_CHECK_NOTNULL(data.get());
MACE_CHECK(format == DataFormat::DF_NONE || format == DataFormat::NHWC MACE_CHECK(format == DataFormat::NONE || format == DataFormat::NHWC
|| format == DataFormat::NCHW || format == OIHW, || format == DataFormat::NCHW || format == DataFormat::OIHW,
"MACE only support DF_NONE, NHWC, NCHW and OIHW " "MACE only support NONE, NHWC, NCHW and OIHW "
"formats of input now."); "formats of input now.");
impl_ = make_unique<MaceTensor::Impl>(); impl_ = make_unique<MaceTensor::Impl>();
impl_->shape = shape; impl_->shape = shape;
...@@ -495,7 +496,7 @@ MaceStatus MaceEngine::Impl::Init( ...@@ -495,7 +496,7 @@ MaceStatus MaceEngine::Impl::Init(
DataType output_dt = output_info_map_[output_name].data_type(); DataType output_dt = output_info_map_[output_name].data_type();
Tensor *output_tensor = Tensor *output_tensor =
ws_->CreateTensor(output_name, device_->allocator(), output_dt); ws_->CreateTensor(output_name, device_->allocator(), output_dt);
output_tensor->set_data_format(NHWC); output_tensor->set_data_format(DataFormat::NHWC);
#endif #endif
} }
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA) #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
...@@ -512,26 +513,32 @@ MaceStatus MaceEngine::Impl::Init( ...@@ -512,26 +513,32 @@ MaceStatus MaceEngine::Impl::Init(
} }
} else { } else {
#endif #endif
MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def, MACE_RETURN_IF_ERROR(ws_->LoadModelTensor(*net_def,
device_.get(), device_.get(),
model_data)); model_data));
MemoryOptimizer mem_optimizer; NetDef adapted_net_def;
// Init model NetDefAdapter net_def_adapter(op_registry_.get(), ws_.get());
net_ = std::unique_ptr<NetBase>(new SerialNet(op_registry_.get(), net_def_adapter.AdaptNetDef(net_def, device_.get(), &adapted_net_def);
net_def,
ws_.get(), MemoryOptimizer mem_optimizer;
device_.get(), // Init model
&mem_optimizer)); net_ = std::unique_ptr<NetBase>(new SerialNet(op_registry_.get(),
&adapted_net_def,
// Preallocate all output tensors of ops ws_.get(),
MACE_RETURN_IF_ERROR(ws_->PreallocateOutputTensor(*net_def, device_.get(),
&mem_optimizer, &mem_optimizer));
device_.get()));
if (device_type_ == DeviceType::GPU) { // Preallocate all output tensors of ops
ws_->RemoveAndReloadBuffer(*net_def, model_data, device_->allocator()); MACE_RETURN_IF_ERROR(ws_->PreallocateOutputTensor(adapted_net_def,
} &mem_optimizer,
MACE_RETURN_IF_ERROR(net_->Init()); device_.get()));
if (device_type_ == DeviceType::GPU) {
ws_->RemoveAndReloadBuffer(adapted_net_def,
model_data,
device_->allocator());
}
MACE_RETURN_IF_ERROR(net_->Init());
#if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA) #if defined(MACE_ENABLE_HEXAGON) || defined(MACE_ENABLE_HTA)
} }
#endif #endif
...@@ -578,14 +585,14 @@ MaceEngine::Impl::~Impl() { ...@@ -578,14 +585,14 @@ MaceEngine::Impl::~Impl() {
MaceStatus MaceEngine::Impl::TransposeInput( MaceStatus MaceEngine::Impl::TransposeInput(
const std::pair<const std::string, MaceTensor> &input, const std::pair<const std::string, MaceTensor> &input,
Tensor *input_tensor) { Tensor *input_tensor) {
bool has_data_format = input_tensor->data_format() != DataFormat::DF_NONE; bool has_data_format = input_tensor->data_format() != DataFormat::NONE;
DataFormat data_format = DataFormat::DF_NONE; DataFormat data_format = DataFormat::NONE;
DataType input_dt = input_tensor->dtype(); DataType input_dt = input_tensor->dtype();
if (has_data_format) { if (has_data_format) {
std::vector<int> dst_dims; std::vector<int> dst_dims;
if (device_->device_type() == DeviceType::CPU && if (device_->device_type() == DeviceType::CPU &&
input.second.shape().size() == 4 && input.second.shape().size() == 4 &&
input.second.data_format() == NHWC && input.second.data_format() == DataFormat::NHWC &&
!is_quantized_model_) { !is_quantized_model_) {
VLOG(1) << "Transform input " << input.first << " from NHWC to NCHW"; VLOG(1) << "Transform input " << input.first << " from NHWC to NCHW";
input_tensor->set_data_format(DataFormat::NCHW); input_tensor->set_data_format(DataFormat::NCHW);
...@@ -647,28 +654,28 @@ MaceStatus MaceEngine::Impl::TransposeOutput( ...@@ -647,28 +654,28 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
DataType output_dt = output_tensor->dtype(); DataType output_dt = output_tensor->dtype();
// save output // save output
if (output_tensor != nullptr && output->second.data() != nullptr) { if (output_tensor != nullptr && output->second.data() != nullptr) {
if (output_tensor->data_format() != DataFormat::DF_NONE && if (output_tensor->data_format() != DataFormat::NONE &&
output->second.data_format() != DataFormat::DF_NONE && output->second.data_format() != DataFormat::NONE &&
output->second.shape().size() == 4 && output->second.shape().size() == 4 &&
output->second.data_format() != output_tensor->data_format()) { output->second.data_format() != output_tensor->data_format()) {
VLOG(1) << "Transform output " << output->first << " from " VLOG(1) << "Transform output " << output->first << " from "
<< output_tensor->data_format() << " to " << static_cast<int>(output_tensor->data_format()) << " to "
<< output->second.data_format(); << static_cast<int>(output->second.data_format());
std::vector<int> dst_dims; std::vector<int> dst_dims;
if (output_tensor->data_format() == NCHW && if (output_tensor->data_format() == DataFormat::NCHW &&
output->second.data_format() == NHWC) { output->second.data_format() == DataFormat::NHWC) {
dst_dims = {0, 2, 3, 1}; dst_dims = {0, 2, 3, 1};
} else if (output_tensor->data_format() == NHWC && } else if (output_tensor->data_format() == DataFormat::NHWC &&
output->second.data_format() == NCHW) { output->second.data_format() == DataFormat::NCHW) {
dst_dims = {0, 3, 1, 2}; dst_dims = {0, 3, 1, 2};
} else { } else {
LOG(FATAL) << "Not supported output data format: " LOG(FATAL) << "Not supported output data format: "
<< output->second.data_format() << " vs " << static_cast<int>(output->second.data_format()) << " vs "
<< output_tensor->data_format(); << static_cast<int>(output_tensor->data_format());
} }
VLOG(1) << "Transform output " << output->first << " from " VLOG(1) << "Transform output " << output->first << " from "
<< output_tensor->data_format() << " to " << static_cast<int>(output_tensor->data_format()) << " to "
<< output->second.data_format(); << static_cast<int>(output->second.data_format());
std::vector<index_t> shape = std::vector<index_t> shape =
TransposeShape<index_t, index_t>(output_tensor->shape(), TransposeShape<index_t, index_t>(output_tensor->shape(),
dst_dims); dst_dims);
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "mace/ops/activation.h" #include "mace/ops/activation.h"
#include <memory> #include <memory>
#include <set>
#include "mace/core/operator.h" #include "mace/core/operator.h"
#if defined(MACE_ENABLE_NEON) #if defined(MACE_ENABLE_NEON)
...@@ -94,7 +96,7 @@ class ActivationOp<DeviceType::GPU, T> : public Operation { ...@@ -94,7 +96,7 @@ class ActivationOp<DeviceType::GPU, T> : public Operation {
auto leakyrelu_coefficient = static_cast<T>( auto leakyrelu_coefficient = static_cast<T>(
Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f)); Operation::GetOptionalArg<float>("leakyrelu_coefficient", 0.0f));
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::ActivationKernel<T>>( kernel_ = make_unique<opencl::image::ActivationKernel<T>>(
type, relux_max_limit, leakyrelu_coefficient); type, relux_max_limit, leakyrelu_coefficient);
...@@ -132,6 +134,24 @@ void RegisterActivation(OpRegistryBase *op_registry) { ...@@ -132,6 +134,24 @@ void RegisterActivation(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Activation", ActivationOp, MACE_REGISTER_OP(op_registry, "Activation", ActivationOp,
DeviceType::GPU, half); DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Activation")
.SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0);
if (!has_data_format ||
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
} }
} // namespace ops } // namespace ops
......
...@@ -207,7 +207,8 @@ void TestSimplePrelu() { ...@@ -207,7 +207,8 @@ void TestSimplePrelu() {
// Run // Run
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Activation", "PreluTest") OpDefBuilder("Activation", "PreluTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Alpha") .Input("Alpha")
...@@ -217,7 +218,8 @@ void TestSimplePrelu() { ...@@ -217,7 +218,8 @@ void TestSimplePrelu() {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>( auto expected = net.CreateTensor<float>(
......
...@@ -67,7 +67,7 @@ class AddNOp<DeviceType::GPU, T> : public Operation { ...@@ -67,7 +67,7 @@ class AddNOp<DeviceType::GPU, T> : public Operation {
public: public:
explicit AddNOp(OpConstructContext *context) explicit AddNOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::AddNKernel<T>>(); kernel_ = make_unique<opencl::image::AddNKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -101,6 +101,24 @@ void RegisterAddN(OpRegistryBase *op_registry) { ...@@ -101,6 +101,24 @@ void RegisterAddN(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "AddN", AddNOp, DeviceType::GPU, half); MACE_REGISTER_OP(op_registry, "AddN", AddNOp, DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("AddN")
.SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0);
if (!has_data_format ||
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
} }
} // namespace ops } // namespace ops
......
...@@ -54,7 +54,7 @@ MaceStatus Deconv2dBase::ResizeOutAndPadOut( ...@@ -54,7 +54,7 @@ MaceStatus Deconv2dBase::ResizeOutAndPadOut(
out_pad_size, out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......
...@@ -174,7 +174,7 @@ class BatchNormOp<DeviceType::GPU, T> : public Operation { ...@@ -174,7 +174,7 @@ class BatchNormOp<DeviceType::GPU, T> : public Operation {
float leakyrelu_coefficient = Operation::GetOptionalArg<float>( float leakyrelu_coefficient = Operation::GetOptionalArg<float>(
"leakyrelu_coefficient", 0.0f); "leakyrelu_coefficient", 0.0f);
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::BatchNormKernel<T>>( kernel_ = make_unique<opencl::image::BatchNormKernel<T>>(
epsilon, activation, relux_max_limit, leakyrelu_coefficient); epsilon, activation, relux_max_limit, leakyrelu_coefficient);
......
...@@ -34,7 +34,8 @@ void Simple() { ...@@ -34,7 +34,8 @@ void Simple() {
net.AddInputFromArray<D, float>("Var", {1}, {11.67f}, true); net.AddInputFromArray<D, float>("Var", {1}, {11.67f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Scale") .Input("Scale")
...@@ -47,7 +48,8 @@ void Simple() { ...@@ -47,7 +48,8 @@ void Simple() {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("Input") .Input("Input")
...@@ -93,8 +95,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { ...@@ -93,8 +95,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
...@@ -112,8 +114,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) { ...@@ -112,8 +114,8 @@ TEST_F(BatchNormOpTest, SimpleRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -163,8 +165,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -163,8 +165,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -179,8 +181,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -179,8 +181,8 @@ TEST_F(BatchNormOpTest, SimpleRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -230,8 +232,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -230,8 +232,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -246,8 +248,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) { ...@@ -246,8 +248,8 @@ TEST_F(BatchNormOpTest, ComplexRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -296,8 +298,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -296,8 +298,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Mean", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Var", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "BatchNormTest") OpDefBuilder("BatchNorm", "BatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -312,8 +314,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -312,8 +314,8 @@ TEST_F(BatchNormOpTest, ComplexRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -264,7 +264,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, T> : public BatchToSpaceOpBase { ...@@ -264,7 +264,7 @@ class BatchToSpaceNDOp<DeviceType::GPU, T> : public BatchToSpaceOpBase {
public: public:
explicit BatchToSpaceNDOp(OpConstructContext *context) explicit BatchToSpaceNDOp(OpConstructContext *context)
: BatchToSpaceOpBase(context) { : BatchToSpaceOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::BatchToSpaceKernel<T>>(); kernel_ = make_unique<opencl::image::BatchToSpaceKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -103,7 +103,7 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation { ...@@ -103,7 +103,7 @@ class BiasAddOp<DeviceType::GPU, T> : public Operation {
: Operation(context), : Operation(context),
has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 1)) { has_data_format_(Operation::GetOptionalArg<int>("has_data_format", 1)) {
MemoryType mem_type = MemoryType::CPU_BUFFER; MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::BiasAddKernel<T>>(); kernel_ = make_unique<opencl::image::BiasAddKernel<T>>();
} else { } else {
...@@ -145,6 +145,24 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) { ...@@ -145,6 +145,24 @@ void RegisterBiasAdd(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp, MACE_REGISTER_OP(op_registry, "BiasAdd", BiasAddOp,
DeviceType::GPU, half); DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("BiasAdd")
.SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0);
if (!has_data_format ||
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
} }
} // namespace ops } // namespace ops
......
...@@ -27,9 +27,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) { ...@@ -27,9 +27,7 @@ void BiasAdd(int iters, int batch, int channels, int height, int width) {
OpsTestNet net; OpsTestNet net;
// Add input data // Add input data
DataFormat data_format = NHWC;
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
data_format = NCHW;
net.AddRandomInput<D, T>("Input", {batch, channels, height, width}); net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
net.AddRandomInput<D, T>("Input", {batch, height, width, channels}); net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
......
...@@ -31,8 +31,8 @@ void BiasAddSimple() { ...@@ -31,8 +31,8 @@ void BiasAddSimple() {
net.AddInputFromArray<D, float>("Bias", {1}, {0.5f}, true); net.AddInputFromArray<D, float>("Bias", {1}, {0.5f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Bias") .Input("Bias")
...@@ -41,8 +41,8 @@ void BiasAddSimple() { ...@@ -41,8 +41,8 @@ void BiasAddSimple() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
.Input("Input") .Input("Input")
...@@ -83,8 +83,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { ...@@ -83,8 +83,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
{batch, height, width, channels}); {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true); net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
...@@ -97,8 +97,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) { ...@@ -97,8 +97,8 @@ TEST_F(BiasAddOpTest, SimpleRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -132,8 +132,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { ...@@ -132,8 +132,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
{batch, height, width, channels}); {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true); net.AddRandomInput<DeviceType::GPU, float>("Bias", {channels}, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("BiasAdd", "BiasAddTest") OpDefBuilder("BiasAdd", "BiasAddTest")
...@@ -146,8 +146,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) { ...@@ -146,8 +146,8 @@ TEST_F(BiasAddOpTest, ComplexRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
......
...@@ -48,7 +48,6 @@ void FilterBufferToImage(int iters, ...@@ -48,7 +48,6 @@ void FilterBufferToImage(int iters,
OpenCLBufferType::IN_OUT_CHANNEL, OpenCLBufferType::IN_OUT_CHANNEL,
MemoryType::GPU_IMAGE, MemoryType::GPU_IMAGE,
0, 0,
DataFormat::NHWC,
b2i_output); b2i_output);
}; };
......
...@@ -37,14 +37,14 @@ void TestBidirectionTransform(const OpenCLBufferType type, ...@@ -37,14 +37,14 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); type, MemoryType::GPU_IMAGE, 0, b2i_output);
// Inverse Transform // Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor( Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value); "I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value);
OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output, .Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); type, MemoryType::GPU_BUFFER, 0, i2b_output);
// Check // Check
ExpectTensorNear<T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), ExpectTensorNear<T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
...@@ -178,14 +178,14 @@ void TestDiffTypeBidirectionTransform(const OpenCLBufferType type, ...@@ -178,14 +178,14 @@ void TestDiffTypeBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); type, MemoryType::GPU_IMAGE, 0, b2i_output);
// Inverse Transform // Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor( Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DT_FLOAT); "I2BOutput", context.device()->allocator(), DT_FLOAT);
OpenCLBufferTransformer<float>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) OpenCLBufferTransformer<float>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output, .Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); type, MemoryType::GPU_BUFFER, 0, i2b_output);
// Check // Check
ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
...@@ -218,14 +218,14 @@ void TestStringHalfBidirectionTransform(const OpenCLBufferType type, ...@@ -218,14 +218,14 @@ void TestStringHalfBidirectionTransform(const OpenCLBufferType type,
// Transform // Transform
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE) OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output); type, MemoryType::GPU_IMAGE, 0, b2i_output);
// Inverse Transform // Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor( Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value); "I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value);
OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER) OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output, .Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output); type, MemoryType::GPU_BUFFER, 0, i2b_output);
// Check // Check
ExpectTensorNear<half>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"), ExpectTensorNear<half>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
......
...@@ -39,14 +39,11 @@ class BufferTransformOp<DeviceType::GPU, T> : public Operation { ...@@ -39,14 +39,11 @@ class BufferTransformOp<DeviceType::GPU, T> : public Operation {
auto type = auto type =
static_cast<OpenCLBufferType>(Operation::GetOptionalArg<int>( static_cast<OpenCLBufferType>(Operation::GetOptionalArg<int>(
"buffer_type", static_cast<int>(CONV2D_FILTER))); "buffer_type", static_cast<int>(CONV2D_FILTER)));
bool has_data_format = Operation::GetOptionalArg<int>("has_data_format", 0)
!= 0;
MemoryType in_mem_type = context->workspace()->GetTensor( MemoryType in_mem_type = context->workspace()->GetTensor(
operator_def_->input(0))->memory_type(); operator_def_->input(0))->memory_type();
return OpenCLBufferTransformer<T>(in_mem_type, out_mem_type_).Transform( return OpenCLBufferTransformer<T>(in_mem_type, out_mem_type_).Transform(
context, input, type, out_mem_type_, wino_blk_size_, context, input, type, out_mem_type_, wino_blk_size_, output);
has_data_format, output);
} }
private: private:
......
...@@ -48,7 +48,7 @@ void TestBidirectionTransform(const OpenCLBufferType type, ...@@ -48,7 +48,7 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<DstType>(MemoryType::GPU_BUFFER, OpenCLBufferTransformer<DstType>(MemoryType::GPU_BUFFER,
MemoryType::GPU_BUFFER) MemoryType::GPU_BUFFER)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, bt_output); type, MemoryType::GPU_BUFFER, 0, bt_output);
// Inverse Transform // Inverse Transform
Tensor *output = net.ws()->CreateTensor( Tensor *output = net.ws()->CreateTensor(
...@@ -57,7 +57,7 @@ void TestBidirectionTransform(const OpenCLBufferType type, ...@@ -57,7 +57,7 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<OrgType>(MemoryType::GPU_BUFFER, OpenCLBufferTransformer<OrgType>(MemoryType::GPU_BUFFER,
MemoryType::GPU_BUFFER) MemoryType::GPU_BUFFER)
.Transform(&context, bt_output, .Transform(&context, bt_output,
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, output); type, MemoryType::GPU_BUFFER, 0, output);
if (DataTypeToEnum<OrgType>::value == DataTypeToEnum<DstType>::value) { if (DataTypeToEnum<OrgType>::value == DataTypeToEnum<DstType>::value) {
EXPECT_EQ(net.GetOutput("Input")->UnderlyingBuffer(), EXPECT_EQ(net.GetOutput("Input")->UnderlyingBuffer(),
...@@ -94,7 +94,7 @@ void TestArgumentTransform(const index_t input_size) { ...@@ -94,7 +94,7 @@ void TestArgumentTransform(const index_t input_size) {
MemoryType::GPU_BUFFER) MemoryType::GPU_BUFFER)
.Transform(&context, net.ws()->GetTensor("Input"), .Transform(&context, net.ws()->GetTensor("Input"),
OpenCLBufferType::ARGUMENT, MemoryType::GPU_BUFFER, OpenCLBufferType::ARGUMENT, MemoryType::GPU_BUFFER,
0, DataFormat::NHWC, output); 0, output);
index_t expected_size = RoundUp<index_t>(input_size, 4); index_t expected_size = RoundUp<index_t>(input_size, 4);
EXPECT_EQ(expected_size, output->buffer_shape()[0]); EXPECT_EQ(expected_size, output->buffer_shape()[0]);
......
...@@ -82,7 +82,7 @@ class ChannelShuffleOp<DeviceType::GPU, T> : public Operation { ...@@ -82,7 +82,7 @@ class ChannelShuffleOp<DeviceType::GPU, T> : public Operation {
explicit ChannelShuffleOp(OpConstructContext *context) explicit ChannelShuffleOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
const int groups = Operation::GetOptionalArg<int>("group", 1); const int groups = Operation::GetOptionalArg<int>("group", 1);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ChannelShuffleKernel<T>>(groups); kernel_ = make_unique<opencl::image::ChannelShuffleKernel<T>>(groups);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -116,7 +116,7 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) { ...@@ -116,7 +116,7 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) {
op_registry, op_registry,
OpConditionBuilder("ChannelShuffle") OpConditionBuilder("ChannelShuffle")
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU }; return { DeviceType::CPU, DeviceType::GPU };
......
...@@ -28,8 +28,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) { ...@@ -28,8 +28,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
"Input", {1, 1, 2, 8}, "Input", {1, 1, 2, 8},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("ChannelShuffle", "ChannelShuffleTest") OpDefBuilder("ChannelShuffle", "ChannelShuffleTest")
...@@ -40,8 +40,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) { ...@@ -40,8 +40,8 @@ TEST_F(ChannelShuffleOpTest, C8G4_CPU) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>( auto expected = net.CreateTensor<float>(
......
...@@ -40,19 +40,19 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, ...@@ -40,19 +40,19 @@ void CalcPaddingAndOutputSize(const index_t *input_shape,
index_t input_height = 0, input_width = 0; index_t input_height = 0, input_width = 0;
index_t kernel_height = 0, kernel_width = 0; index_t kernel_height = 0, kernel_width = 0;
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
input_height = input_shape[2]; input_height = input_shape[2];
input_width = input_shape[3]; input_width = input_shape[3];
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
input_height = input_shape[1]; input_height = input_shape[1];
input_width = input_shape[2]; input_width = input_shape[2];
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
if (filter_format == OIHW) { if (filter_format == DataFormat::OIHW) {
kernel_height = filter_shape[2]; kernel_height = filter_shape[2];
kernel_width = filter_shape[3]; kernel_width = filter_shape[3];
} else if (filter_format == OHWI) { } else if (filter_format == DataFormat::OHWI) {
kernel_height = filter_shape[1]; kernel_height = filter_shape[1];
kernel_width = filter_shape[2]; kernel_width = filter_shape[2];
} else { } else {
...@@ -97,11 +97,11 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, ...@@ -97,11 +97,11 @@ void CalcPaddingAndOutputSize(const index_t *input_shape,
0, (output_width - 1) * strides[1] + k_extent_width - input_width); 0, (output_width - 1) * strides[1] + k_extent_width - input_width);
output_shape[0] = input_shape[0]; output_shape[0] = input_shape[0];
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
output_shape[1] = output_channels; output_shape[1] = output_channels;
output_shape[2] = output_height; output_shape[2] = output_height;
output_shape[3] = output_width; output_shape[3] = output_width;
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
output_shape[1] = output_height; output_shape[1] = output_height;
output_shape[2] = output_width; output_shape[2] = output_width;
output_shape[3] = output_channels; output_shape[3] = output_channels;
...@@ -117,7 +117,8 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW ...@@ -117,7 +117,8 @@ void CalcNCHWPaddingAndOutputSize(const index_t *input_shape, // NCHW
Padding padding, Padding padding,
index_t *output_shape, index_t *output_shape,
int *padding_size) { int *padding_size) {
CalcPaddingAndOutputSize(input_shape, NCHW, filter_shape, OIHW, dilations, CalcPaddingAndOutputSize(input_shape, DataFormat::NCHW, filter_shape,
DataFormat::OIHW, dilations,
strides, padding, output_shape, padding_size); strides, padding, output_shape, padding_size);
} }
...@@ -128,7 +129,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC ...@@ -128,7 +129,8 @@ void CalcNHWCPaddingAndOutputSize(const index_t *input_shape, // NHWC
Padding padding, Padding padding,
index_t *output_shape, index_t *output_shape,
int *padding_size) { int *padding_size) {
CalcPaddingAndOutputSize(input_shape, NHWC, filter_shape, OIHW, dilations, CalcPaddingAndOutputSize(input_shape, DataFormat::NHWC, filter_shape,
DataFormat::OIHW, dilations,
strides, padding, output_shape, padding_size); strides, padding, output_shape, padding_size);
} }
...@@ -151,19 +153,19 @@ void CalcOutputSize(const index_t *input_shape, ...@@ -151,19 +153,19 @@ void CalcOutputSize(const index_t *input_shape,
index_t input_height = 0, input_width = 0; index_t input_height = 0, input_width = 0;
index_t kernel_height = 0, kernel_width = 0; index_t kernel_height = 0, kernel_width = 0;
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
input_height = input_shape[2]; input_height = input_shape[2];
input_width = input_shape[3]; input_width = input_shape[3];
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
input_height = input_shape[1]; input_height = input_shape[1];
input_width = input_shape[2]; input_width = input_shape[2];
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
} }
if (filter_format == OIHW) { if (filter_format == DataFormat::OIHW) {
kernel_height = filter_shape[2]; kernel_height = filter_shape[2];
kernel_width = filter_shape[3]; kernel_width = filter_shape[3];
} else if (filter_format == OHWI) { } else if (filter_format == DataFormat::OHWI) {
kernel_height = filter_shape[1]; kernel_height = filter_shape[1];
kernel_width = filter_shape[2]; kernel_width = filter_shape[2];
} else { } else {
...@@ -195,11 +197,11 @@ void CalcOutputSize(const index_t *input_shape, ...@@ -195,11 +197,11 @@ void CalcOutputSize(const index_t *input_shape,
} }
output_shape[0] = input_shape[0]; output_shape[0] = input_shape[0];
if (input_format == NCHW) { if (input_format == DataFormat::NCHW) {
output_shape[1] = output_channels; output_shape[1] = output_channels;
output_shape[2] = output_height; output_shape[2] = output_height;
output_shape[3] = output_width; output_shape[3] = output_width;
} else if (input_format == NHWC) { } else if (input_format == DataFormat::NHWC) {
output_shape[1] = output_height; output_shape[1] = output_height;
output_shape[2] = output_width; output_shape[2] = output_width;
output_shape[3] = output_channels; output_shape[3] = output_channels;
...@@ -215,7 +217,8 @@ void CalcOutputSize(const index_t *input_shape, // NHWC ...@@ -215,7 +217,8 @@ void CalcOutputSize(const index_t *input_shape, // NHWC
const int *strides, const int *strides,
const RoundType round_type, const RoundType round_type,
index_t *output_shape) { index_t *output_shape) {
CalcOutputSize(input_shape, NHWC, filter_shape, OIHW, padding_size, dilations, CalcOutputSize(input_shape, DataFormat::NHWC, filter_shape,
DataFormat::OIHW, padding_size, dilations,
strides, round_type, output_shape); strides, round_type, output_shape);
} }
...@@ -226,7 +229,8 @@ void CalcNCHWOutputSize(const index_t *input_shape, // NCHW ...@@ -226,7 +229,8 @@ void CalcNCHWOutputSize(const index_t *input_shape, // NCHW
const int *strides, const int *strides,
const RoundType round_type, const RoundType round_type,
index_t *output_shape) { index_t *output_shape) {
CalcOutputSize(input_shape, NCHW, filter_shape, OIHW, padding_size, dilations, CalcOutputSize(input_shape, DataFormat::NCHW, filter_shape,
DataFormat::OIHW, padding_size, dilations,
strides, round_type, output_shape); strides, round_type, output_shape);
} }
...@@ -241,14 +245,18 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape, ...@@ -241,14 +245,18 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape,
std::vector<index_t> *padded_out_shape, std::vector<index_t> *padded_out_shape,
DataFormat data_format) { DataFormat data_format) {
const index_t const index_t
in_height = data_format == NCHW ? input_shape[2] : input_shape[1]; in_height =
data_format == DataFormat::NCHW ? input_shape[2] : input_shape[1];
const index_t const index_t
in_width = data_format == NCHW ? input_shape[3] : input_shape[2]; in_width =
data_format == DataFormat::NCHW ? input_shape[3] : input_shape[2];
const index_t const index_t
out_height = data_format == NCHW ? output_shape[2] : output_shape[1]; out_height =
data_format == DataFormat::NCHW ? output_shape[2] : output_shape[1];
const index_t const index_t
out_width = data_format == NCHW ? output_shape[3] : output_shape[2]; out_width =
data_format == DataFormat::NCHW ? output_shape[3] : output_shape[2];
const index_t extended_in_height = (in_height - 1) * strides[0] + 1; const index_t extended_in_height = (in_height - 1) * strides[0] + 1;
const index_t extended_in_width = (in_width - 1) * strides[1] + 1; const index_t extended_in_width = (in_width - 1) * strides[1] + 1;
...@@ -307,11 +315,11 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape, ...@@ -307,11 +315,11 @@ void CalcDeconvShape_TF(const std::vector<index_t> &input_shape,
padded_out_shape->resize(4); padded_out_shape->resize(4);
(*padded_out_shape)[0] = output_shape[0]; (*padded_out_shape)[0] = output_shape[0];
(*padded_out_shape)[1] = (*padded_out_shape)[1] =
data_format == NCHW ? output_channel : padded_out_height; data_format == DataFormat::NCHW ? output_channel : padded_out_height;
(*padded_out_shape)[2] = (*padded_out_shape)[2] =
data_format == NCHW ? padded_out_height : padded_out_width; data_format == DataFormat::NCHW ? padded_out_height : padded_out_width;
(*padded_out_shape)[3] = (*padded_out_shape)[3] =
data_format == NCHW ? padded_out_width : output_channel; data_format == DataFormat::NCHW ? padded_out_width : output_channel;
} }
} }
...@@ -325,9 +333,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape, ...@@ -325,9 +333,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape,
std::vector<index_t> *padded_out_shape, std::vector<index_t> *padded_out_shape,
DataFormat data_format) { DataFormat data_format) {
const index_t const index_t
in_height = data_format == NCHW ? input_shape[2] : input_shape[1]; in_height =
data_format == DataFormat::NCHW ? input_shape[2] : input_shape[1];
const index_t const index_t
in_width = data_format == NCHW ? input_shape[3] : input_shape[2]; in_width =
data_format == DataFormat::NCHW ? input_shape[3] : input_shape[2];
const index_t output_channel = filter_shape[0] * group; const index_t output_channel = filter_shape[0] * group;
...@@ -351,11 +361,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape, ...@@ -351,11 +361,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape,
padded_out_shape->resize(4); padded_out_shape->resize(4);
(*padded_out_shape)[0] = input_shape[0]; (*padded_out_shape)[0] = input_shape[0];
(*padded_out_shape)[1] = (*padded_out_shape)[1] =
data_format == NCHW ? output_channel : padded_out_height; data_format == DataFormat::NCHW ? output_channel : padded_out_height;
(*padded_out_shape)[2] = (*padded_out_shape)[2] =
data_format == NCHW ? padded_out_height : padded_out_width; data_format == DataFormat::NCHW ? padded_out_height : padded_out_width;
(*padded_out_shape)[3] = (*padded_out_shape)[3] =
data_format == NCHW ? padded_out_width : output_channel; data_format == DataFormat::NCHW ? padded_out_width : output_channel;
} }
if (out_shape != nullptr) { if (out_shape != nullptr) {
...@@ -363,9 +373,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape, ...@@ -363,9 +373,11 @@ void CalcDeconvShape_Caffe(const std::vector<index_t> &input_shape,
index_t out_width = padded_out_width - out_pad_size[1]; index_t out_width = padded_out_width - out_pad_size[1];
out_shape->resize(4); out_shape->resize(4);
(*out_shape)[0] = input_shape[0]; (*out_shape)[0] = input_shape[0];
(*out_shape)[1] = data_format == NCHW ? output_channel : out_height; (*out_shape)[1] =
(*out_shape)[2] = data_format == NCHW ? out_height : out_width; data_format == DataFormat::NCHW ? output_channel : out_height;
(*out_shape)[3] = data_format == NCHW ? out_width : output_channel; (*out_shape)[2] = data_format == DataFormat::NCHW ? out_height : out_width;
(*out_shape)[3] =
data_format == DataFormat::NCHW ? out_width : output_channel;
} }
} }
...@@ -385,7 +397,7 @@ void CalDeconvOutputShapeAndPadSize(const std::vector<index_t> &input_shape, ...@@ -385,7 +397,7 @@ void CalDeconvOutputShapeAndPadSize(const std::vector<index_t> &input_shape,
MACE_CHECK(output_shape->size() == 4, MACE_CHECK(output_shape->size() == 4,
"deconv output shape shoud be 4-dims"); "deconv output shape shoud be 4-dims");
std::vector<index_t> &out_shape = *output_shape; std::vector<index_t> &out_shape = *output_shape;
if (data_format == NCHW) { if (data_format == DataFormat::NCHW) {
const index_t t = out_shape[1]; const index_t t = out_shape[1];
out_shape[1] = out_shape[3]; out_shape[1] = out_shape[3];
out_shape[3] = out_shape[2]; out_shape[3] = out_shape[2];
......
...@@ -199,7 +199,7 @@ class ConcatOp<DeviceType::GPU, T> : public ConcatOpBase { ...@@ -199,7 +199,7 @@ class ConcatOp<DeviceType::GPU, T> : public ConcatOpBase {
public: public:
explicit ConcatOp(OpConstructContext *context) explicit ConcatOp(OpConstructContext *context)
: ConcatOpBase(context) { : ConcatOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ConcatKernel<T>>(); kernel_ = make_unique<opencl::image::ConcatKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -241,12 +241,12 @@ void RegisterConcat(OpRegistryBase *op_registry) { ...@@ -241,12 +241,12 @@ void RegisterConcat(OpRegistryBase *op_registry) {
op_registry, op_registry,
OpConditionBuilder("Concat") OpConditionBuilder("Concat")
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
auto tensor_shape_info = context->tensor_shape_info();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU }; return { DeviceType::CPU, DeviceType::GPU };
} }
auto tensor_shape_info = context->tensor_shape_info();
if (op->output_shape(0).dims_size() != 4) { if (op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU }; return { DeviceType::CPU };
} else { } else {
......
...@@ -231,9 +231,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase { ...@@ -231,9 +231,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
std::vector<int> paddings(2); std::vector<int> paddings(2);
if (paddings_.empty()) { if (paddings_.empty()) {
CalcPaddingAndOutputSize(input->shape().data(), CalcPaddingAndOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
filter->shape().data(), filter->shape().data(),
OHWI, DataFormat::OHWI,
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
padding_type_, padding_type_,
...@@ -242,9 +242,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase { ...@@ -242,9 +242,9 @@ class Conv2dOp<DeviceType::CPU, uint8_t> : public ConvPool2dOpBase {
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), CalcOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
filter->shape().data(), filter->shape().data(),
OHWI, DataFormat::OHWI,
paddings_.data(), paddings_.data(),
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
...@@ -459,14 +459,13 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase { ...@@ -459,14 +459,13 @@ class Conv2dOp<DeviceType::GPU, T> : public ConvPool2dOpBase {
"leakyrelu_coefficient", 0.0f)), "leakyrelu_coefficient", 0.0f)),
wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) { wino_block_size_(Operation::GetOptionalArg<int>("wino_block_size", 0)) {
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::Conv2dKernel<T>>(); kernel_ = make_unique<opencl::image::Conv2dKernel<T>>();
} else { } else {
mem_type = MemoryType::GPU_BUFFER; mem_type = MemoryType::GPU_BUFFER;
kernel_ = make_unique<opencl::buffer::Conv2dKernel<T>>(); kernel_ = make_unique<opencl::buffer::Conv2dKernel<T>>();
} }
context->set_output_mem_type(mem_type);
// Transform filter tensor to target format // Transform filter tensor to target format
if ((wino_block_size_ == 2 || wino_block_size_ == 4) && if ((wino_block_size_ == 2 || wino_block_size_ == 4) &&
(kernel_->CheckUseWinograd( (kernel_->CheckUseWinograd(
......
...@@ -47,8 +47,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -47,8 +47,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 1, 1, 1}; const std::vector<index_t> output_shape = {1, 1, 1, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -60,8 +60,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -60,8 +60,8 @@ void TestNHWCSimple3x3VALID(int wino_blk_size = 0) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -105,8 +105,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) { ...@@ -105,8 +105,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 3, 3, 1}; const std::vector<index_t> output_shape = {1, 3, 3, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -118,8 +118,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) { ...@@ -118,8 +118,8 @@ void TestNHWCSimple3x3SAME(int wino_blk_size = 0) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -189,8 +189,8 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -189,8 +189,8 @@ void TestNHWCSimple3x3WithoutBias() {
1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, true); 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -203,8 +203,8 @@ void TestNHWCSimple3x3WithoutBias() { ...@@ -203,8 +203,8 @@ void TestNHWCSimple3x3WithoutBias() {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("Input") .Input("Input")
...@@ -256,8 +256,8 @@ void TestNHWCCombined3x3() { ...@@ -256,8 +256,8 @@ void TestNHWCCombined3x3() {
net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f}, true); net.AddInputFromArray<D, T>("Bias", {2}, {0.1f, 0.2f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -270,8 +270,8 @@ void TestNHWCCombined3x3() { ...@@ -270,8 +270,8 @@ void TestNHWCCombined3x3() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -321,8 +321,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -321,8 +321,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 1, 1, 1}; const std::vector<index_t> output_shape = {1, 1, 1, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -336,8 +336,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) { ...@@ -336,8 +336,8 @@ void TestFusedNHWCSimple3x3VALID(int wino_blk_size = 0) {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -376,8 +376,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) { ...@@ -376,8 +376,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) {
const std::vector<index_t> output_shape = {1, 1, 1, 1}; const std::vector<index_t> output_shape = {1, 1, 1, 1};
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -391,8 +391,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) { ...@@ -391,8 +391,8 @@ void TestFusedNHWCSimple3x3WithoutBias(int wino_blk_size = 0) {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -459,8 +459,8 @@ void TestConv1x1() { ...@@ -459,8 +459,8 @@ void TestConv1x1() {
net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}, true); net.AddInputFromArray<D, float>("Bias", {2}, {0.1f, 0.2f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -472,8 +472,8 @@ void TestConv1x1() { ...@@ -472,8 +472,8 @@ void TestConv1x1() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Conv2D", "Conv2DTest") OpDefBuilder("Conv2D", "Conv2DTest")
.Input("Input") .Input("Input")
...@@ -532,8 +532,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -532,8 +532,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true, "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true,
false); false);
net.AddRandomInput<D, T>("Bias", {output_channels}, true, false); net.AddRandomInput<D, T>("Bias", {output_channels}, true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -552,8 +552,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape, ...@@ -552,8 +552,8 @@ void TestComplexConvNxNS12(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -651,8 +651,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape, ...@@ -651,8 +651,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
float_bias_data, float_bias_data,
true); true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -667,8 +667,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape, ...@@ -667,8 +667,8 @@ void TestHalfComplexConvNxNS12(const std::vector<index_t> &input_shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -811,8 +811,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape, ...@@ -811,8 +811,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true); "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
net.AddRandomInput<D, T>("Bias", {output_channels}, true); net.AddRandomInput<D, T>("Bias", {output_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
...@@ -828,8 +828,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape, ...@@ -828,8 +828,8 @@ void TestDilationConvNxN(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -900,8 +900,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape, ...@@ -900,8 +900,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true); "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
net.AddRandomInput<D, float>("Bias", {output_channels}, true); net.AddRandomInput<D, float>("Bias", {output_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -916,8 +916,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape, ...@@ -916,8 +916,8 @@ void TestGeneralHalfAtrousConv(const std::vector<index_t> &image_shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -979,8 +979,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, ...@@ -979,8 +979,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true); "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true);
net.AddRandomInput<D, float>("Bias", {output_channels}, true); net.AddRandomInput<D, float>("Bias", {output_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Construct graph // Construct graph
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -994,8 +994,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape, ...@@ -994,8 +994,8 @@ void TestArbitraryPadConvNxN(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -1118,12 +1118,12 @@ void TestQuant(const index_t batch, ...@@ -1118,12 +1118,12 @@ void TestQuant(const index_t batch,
net.AddRandomInput<CPU, float>("Filter", {out_channels, k_height, k_width, net.AddRandomInput<CPU, float>("Filter", {out_channels, k_height, k_width,
in_channels}, true); in_channels}, true);
net.AddRandomInput<CPU, float>("Bias", {out_channels}, true); net.AddRandomInput<CPU, float>("Bias", {out_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.TransformFilterDataFormat<DeviceType::CPU, float>("Filter", net.TransformFilterDataFormat<DeviceType::CPU, float>("Filter",
OHWI, DataFormat::OHWI,
"FilterOIHW", "FilterOIHW",
OIHW); DataFormat::OIHW);
OpDefBuilder("Conv2D", "Conv2dTest") OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -1136,8 +1136,8 @@ void TestQuant(const index_t batch, ...@@ -1136,8 +1136,8 @@ void TestQuant(const index_t batch,
.AddIntArg("T", static_cast<int>(DT_FLOAT)) .AddIntArg("T", static_cast<int>(DT_FLOAT))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeFilter") OpDefBuilder("Quantize", "QuantizeFilter")
.Input("Filter") .Input("Filter")
......
...@@ -117,7 +117,7 @@ class CropOp<DeviceType::GPU, T> : public Operation { ...@@ -117,7 +117,7 @@ class CropOp<DeviceType::GPU, T> : public Operation {
public: public:
explicit CropOp(OpConstructContext *context) explicit CropOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::CropKernel<T>>( kernel_ = make_unique<opencl::image::CropKernel<T>>(
Operation::GetRepeatedArgs<int>("offset")); Operation::GetRepeatedArgs<int>("offset"));
} else { } else {
...@@ -145,6 +145,24 @@ void RegisterCrop(OpRegistryBase *op_registry) { ...@@ -145,6 +145,24 @@ void RegisterCrop(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Crop", CropOp, MACE_REGISTER_OP(op_registry, "Crop", CropOp,
DeviceType::GPU, half); DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Crop")
.SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0);
if (!has_data_format ||
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
} }
} // namespace ops } // namespace ops
......
...@@ -42,13 +42,13 @@ void RunCrop(const std::vector<index_t> &input_shape, ...@@ -42,13 +42,13 @@ void RunCrop(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == CPU) { } else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input0", net.TransformDataFormat<DeviceType::CPU, float>("Input0",
NHWC, DataFormat::NHWC,
"InputNCHW0", "InputNCHW0",
NCHW); DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", net.TransformDataFormat<DeviceType::CPU, float>("Input1",
NHWC, DataFormat::NHWC,
"InputNCHW1", "InputNCHW1",
NCHW); DataFormat::NCHW);
OpDefBuilder("Crop", "CropTest") OpDefBuilder("Crop", "CropTest")
.Input("InputNCHW0") .Input("InputNCHW0")
.Input("InputNCHW1") .Input("InputNCHW1")
...@@ -62,8 +62,8 @@ void RunCrop(const std::vector<index_t> &input_shape, ...@@ -62,8 +62,8 @@ void RunCrop(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == CPU) { if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
auto expected = net.CreateTensor<float>(expected_shape, expected_data); auto expected = net.CreateTensor<float>(expected_shape, expected_data);
......
...@@ -32,8 +32,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape, ...@@ -32,8 +32,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape,
OpsTestNet net; OpsTestNet net;
net.AddInputFromArray<CPU, T>("Input", shape, input); net.AddInputFromArray<CPU, T>("Input", shape, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Cumsum", "CumsumTest") OpDefBuilder("Cumsum", "CumsumTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -48,8 +48,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape, ...@@ -48,8 +48,8 @@ void SimpleTestWithDataFormat(const std::vector<index_t> &shape,
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, T>("ExpectedOutput", shape, output); net.AddInputFromArray<CPU, T>("ExpectedOutput", shape, output);
ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"), ExpectTensorNear<T>(*net.GetOutput("ExpectedOutput"),
......
...@@ -173,7 +173,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -173,7 +173,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
explicit Deconv2dOp(OpConstructContext *context) explicit Deconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context) { : Deconv2dOpBase(context) {
MemoryType mem_type = MemoryType::GPU_IMAGE; MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::Deconv2dKernel<T>>(); kernel_ = make_unique<opencl::image::Deconv2dKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -197,7 +197,6 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -197,7 +197,6 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
OpenCLBufferType::ARGUMENT, OpenCLBufferType::ARGUMENT,
mem_type) == MaceStatus::MACE_SUCCESS); mem_type) == MaceStatus::MACE_SUCCESS);
} }
context->SetInputInfo(2, MemoryType::CPU_BUFFER, DataType::DT_INT32);
} }
} }
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
...@@ -241,7 +240,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -241,7 +240,7 @@ class Deconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
&out_paddings, &out_paddings,
nullptr, nullptr,
model_type_, model_type_,
NHWC); DataFormat::NHWC);
return kernel_->Compute(context, input, filter, bias, return kernel_->Compute(context, input, filter, bias,
strides_.data(), in_paddings.data(), activation_, strides_.data(), in_paddings.data(), activation_,
...@@ -264,6 +263,30 @@ void RegisterDeconv2D(OpRegistryBase *op_registry) { ...@@ -264,6 +263,30 @@ void RegisterDeconv2D(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Deconv2D", Deconv2dOp, MACE_REGISTER_OP(op_registry, "Deconv2D", Deconv2dOp,
DeviceType::GPU, half); DeviceType::GPU, half);
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Deconv2D")
.SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->device_type() == DeviceType::GPU) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
} else {
MACE_NOT_IMPLEMENTED;
}
FrameworkType framework_type =
static_cast<FrameworkType>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*(context->operator_def()), "framework_type",
FrameworkType::TENSORFLOW));
if (framework_type == FrameworkType::TENSORFLOW) {
context->SetInputInfo(2, MemoryType::CPU_BUFFER,
DataType::DT_INT32);
}
}
context->set_output_mem_type(mem_type);
}));
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
} }
......
...@@ -47,7 +47,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape, ...@@ -47,7 +47,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true); net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true);
net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true); net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true);
// TODO(liutuo): remove the unused transform // TODO(liutuo): remove the unused transform
net.TransformFilterDataFormat<D, float>("Filter", HWOI, "FilterOIHW", OIHW); net.TransformFilterDataFormat<D, float>(
"Filter", DataFormat::HWOI, "FilterOIHW", DataFormat::OIHW);
if (D == DeviceType::GPU) { if (D == DeviceType::GPU) {
if (model_type == FrameworkType::CAFFE) { if (model_type == FrameworkType::CAFFE) {
OpDefBuilder("Deconv2D", "Deconv2dTest") OpDefBuilder("Deconv2D", "Deconv2dTest")
...@@ -77,8 +78,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape, ...@@ -77,8 +78,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
} }
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
if (model_type == FrameworkType::CAFFE) { if (model_type == FrameworkType::CAFFE) {
OpDefBuilder("Deconv2D", "Deconv2dTest") OpDefBuilder("Deconv2D", "Deconv2dTest")
...@@ -109,8 +110,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape, ...@@ -109,8 +110,8 @@ void RunTestSimple(const std::vector<index_t> &input_shape,
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>(expected_shape, expected_data); auto expected = net.CreateTensor<float>(expected_shape, expected_data);
...@@ -380,8 +381,8 @@ void TestComplexDeconvNxN(const int batch, ...@@ -380,8 +381,8 @@ void TestComplexDeconvNxN(const int batch,
"Filter", {output_channels, input_channels, kernel_h, kernel_w}, true, "Filter", {output_channels, input_channels, kernel_h, kernel_w}, true,
false); false);
net.AddRandomInput<D, T>("Bias", {output_channels}, true, false); net.AddRandomInput<D, T>("Bias", {output_channels}, true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
int out_h = 0; int out_h = 0;
int out_w = 0; int out_w = 0;
...@@ -440,8 +441,8 @@ void TestComplexDeconvNxN(const int batch, ...@@ -440,8 +441,8 @@ void TestComplexDeconvNxN(const int batch,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -96,7 +96,7 @@ class DepthToSpaceOp<DeviceType::GPU, T> : public Operation { ...@@ -96,7 +96,7 @@ class DepthToSpaceOp<DeviceType::GPU, T> : public Operation {
explicit DepthToSpaceOp(OpConstructContext *context) explicit DepthToSpaceOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
int block_size = Operation::GetOptionalArg<int>("block_size", 1); int block_size = Operation::GetOptionalArg<int>("block_size", 1);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::DepthToSpaceKernel<T>>(block_size); kernel_ = make_unique<opencl::image::DepthToSpaceKernel<T>>(block_size);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -32,8 +32,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape, ...@@ -32,8 +32,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input_data); net.AddInputFromArray<D, float>("Input", input_shape, input_data);
// Construct graph // Construct graph
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -41,8 +41,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape, ...@@ -41,8 +41,8 @@ void RunDepthToSpace(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
...@@ -114,8 +114,8 @@ void RandomTest(const int block_size, ...@@ -114,8 +114,8 @@ void RandomTest(const int block_size,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", shape); net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntArg("block_size", block_size) .AddIntArg("block_size", block_size)
...@@ -125,8 +125,8 @@ void RandomTest(const int block_size, ...@@ -125,8 +125,8 @@ void RandomTest(const int block_size,
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("DepthToSpace", "DepthToSpaceTest") OpDefBuilder("DepthToSpace", "DepthToSpaceTest")
.Input("Input") .Input("Input")
......
...@@ -188,9 +188,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t> ...@@ -188,9 +188,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
filter->dim(2) * filter->dim(3), filter->dim(0), filter->dim(1), 1}; filter->dim(2) * filter->dim(3), filter->dim(0), filter->dim(1), 1};
if (paddings_.empty()) { if (paddings_.empty()) {
CalcPaddingAndOutputSize(input->shape().data(), CalcPaddingAndOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
ohwi_shape.data(), ohwi_shape.data(),
OHWI, DataFormat::OHWI,
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
padding_type_, padding_type_,
...@@ -199,9 +199,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t> ...@@ -199,9 +199,9 @@ class DepthwiseConv2dOp<DeviceType::CPU, uint8_t>
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input->shape().data(), CalcOutputSize(input->shape().data(),
NHWC, DataFormat::NHWC,
ohwi_shape.data(), ohwi_shape.data(),
OHWI, DataFormat::OHWI,
paddings_.data(), paddings_.data(),
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
...@@ -375,14 +375,13 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase { ...@@ -375,14 +375,13 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
explicit DepthwiseConv2dOp(OpConstructContext *context) explicit DepthwiseConv2dOp(OpConstructContext *context)
: DepthwiseConv2dOpBase(context) { : DepthwiseConv2dOpBase(context) {
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::DepthwiseConv2dKernel<T>>(); kernel_ = make_unique<opencl::image::DepthwiseConv2dKernel<T>>();
} else { } else {
mem_type = MemoryType::GPU_BUFFER; mem_type = MemoryType::GPU_BUFFER;
kernel_ = make_unique<opencl::buffer::DepthwiseConv2dKernel<T>>(); kernel_ = make_unique<opencl::buffer::DepthwiseConv2dKernel<T>>();
} }
context->set_output_mem_type(mem_type);
Tensor *filter_tensor = context->workspace()->GetTensor( Tensor *filter_tensor = context->workspace()->GetTensor(
operator_def_->input(1)); operator_def_->input(1));
if (filter_tensor != nullptr && filter_tensor->is_weight()) { if (filter_tensor != nullptr && filter_tensor->is_weight()) {
...@@ -393,8 +392,6 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase { ...@@ -393,8 +392,6 @@ class DepthwiseConv2dOp<DeviceType::GPU, T> : public DepthwiseConv2dOpBase {
1, 1,
OpenCLBufferType::DW_CONV2D_FILTER, OpenCLBufferType::DW_CONV2D_FILTER,
mem_type) == MaceStatus::MACE_SUCCESS); mem_type) == MaceStatus::MACE_SUCCESS);
} else {
context->SetInputOpenCLBufferType(1, OpenCLBufferType::DW_CONV2D_FILTER);
} }
if (operator_def_->input_size() > 2) { if (operator_def_->input_size() > 2) {
MACE_CHECK(TransformFilter<T>( MACE_CHECK(TransformFilter<T>(
...@@ -440,7 +437,40 @@ void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) { ...@@ -440,7 +437,40 @@ void RegisterDepthwiseConv2d(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "DepthwiseConv2d", MACE_REGISTER_OP(op_registry, "DepthwiseConv2d",
DepthwiseConv2dOp, DeviceType::GPU, half); DepthwiseConv2dOp, DeviceType::GPU, half);
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("DepthwiseConv2d")
.SetInputMemoryTypeSetter(
[](OpConditionContext *context) -> void {
MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->device_type() == DeviceType::GPU) {
if (context->device()->gpu_runtime()->UseImageMemory()) {
mem_type = MemoryType::GPU_IMAGE;
} else {
mem_type = MemoryType::GPU_BUFFER;
}
auto filter_tensor = context->workspace()->GetTensor(
context->operator_def()->input(1));
if (filter_tensor == nullptr || !filter_tensor->is_weight()) {
context->SetInputOpenCLBufferType(
1, OpenCLBufferType::DW_CONV2D_FILTER);
}
}
context->set_output_mem_type(mem_type);
}));
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("DepthwiseConv2d")
.SetInputsDataFormatSelector(
[](OpConditionContext *context) -> std::vector<DataFormat> {
DataFormat op_data_format =
static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
return {op_data_format, DataFormat::OIHW, DataFormat::NONE};
}));
} }
} // namespace ops } // namespace ops
......
...@@ -39,8 +39,8 @@ void SimpleValidTest() { ...@@ -39,8 +39,8 @@ void SimpleValidTest() {
true); true);
net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f}, true); net.AddInputFromArray<D, float>("Bias", {2}, {.1f, .2f}, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -52,8 +52,8 @@ void SimpleValidTest() { ...@@ -52,8 +52,8 @@ void SimpleValidTest() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
...@@ -127,8 +127,8 @@ void ComplexValidTest(index_t batch, ...@@ -127,8 +127,8 @@ void ComplexValidTest(index_t batch,
true); true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -141,8 +141,8 @@ void ComplexValidTest(index_t batch, ...@@ -141,8 +141,8 @@ void ComplexValidTest(index_t batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("Input") .Input("Input")
...@@ -249,8 +249,8 @@ void TestNxNS12(const index_t height, const index_t width) { ...@@ -249,8 +249,8 @@ void TestNxNS12(const index_t height, const index_t width) {
{multiplier * channel}, {multiplier * channel},
true, false); true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -267,8 +267,8 @@ void TestNxNS12(const index_t height, const index_t width) { ...@@ -267,8 +267,8 @@ void TestNxNS12(const index_t height, const index_t width) {
// Run on cpu // Run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -389,9 +389,9 @@ void TestQuant(const index_t batch, ...@@ -389,9 +389,9 @@ void TestQuant(const index_t batch,
"Filter", {k_height, k_width, in_channels, multiplier}, true, false); "Filter", {k_height, k_width, in_channels, multiplier}, true, false);
net.AddRandomInput<CPU, float>("Bias", {out_channels}, true); net.AddRandomInput<CPU, float>("Bias", {out_channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.TransformFilterDataFormat<DeviceType::CPU, float>( net.TransformFilterDataFormat<DeviceType::CPU, float>(
"Filter", HWIO, "FilterOIHW", OIHW); "Filter", DataFormat::HWIO, "FilterOIHW", DataFormat::OIHW);
OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest") OpDefBuilder("DepthwiseConv2d", "DepthwiseConv2DTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -405,7 +405,7 @@ void TestQuant(const index_t batch, ...@@ -405,7 +405,7 @@ void TestQuant(const index_t batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeFilter") OpDefBuilder("Quantize", "QuantizeFilter")
.Input("Filter") .Input("Filter")
......
...@@ -190,7 +190,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -190,7 +190,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
explicit DepthwiseDeconv2dOp(OpConstructContext *context) explicit DepthwiseDeconv2dOp(OpConstructContext *context)
: Deconv2dOpBase(context) { : Deconv2dOpBase(context) {
MemoryType mem_type = MemoryType::GPU_IMAGE; MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::DepthwiseDeconv2dKernel<T>>(); kernel_ = make_unique<opencl::image::DepthwiseDeconv2dKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -230,7 +230,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase { ...@@ -230,7 +230,7 @@ class DepthwiseDeconv2dOp<DeviceType::GPU, T> : public Deconv2dOpBase {
&out_paddings, &out_paddings,
nullptr, nullptr,
CAFFE, CAFFE,
NHWC); DataFormat::NHWC);
return kernel_->Compute(context, return kernel_->Compute(context,
input, input,
......
...@@ -39,7 +39,8 @@ void RunTestSimple(const int group, ...@@ -39,7 +39,8 @@ void RunTestSimple(const int group,
// Add input data // Add input data
net.AddInputFromArray<D, float>("Input", input_shape, input_data); net.AddInputFromArray<D, float>("Input", input_shape, input_data);
net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true); net.AddInputFromArray<D, float>("Filter", filter_shape, filter_data, true);
net.TransformFilterDataFormat<D, float>("Filter", HWOI, "FilterOIHW", OIHW); net.TransformFilterDataFormat<D, float>(
"Filter", DataFormat::HWOI, "FilterOIHW", DataFormat::OIHW);
const index_t out_channels = expected_shape[3]; const index_t out_channels = expected_shape[3];
net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true); net.AddInputFromArray<D, float>("Bias", {out_channels}, bias_data, true);
...@@ -56,8 +57,8 @@ void RunTestSimple(const int group, ...@@ -56,8 +57,8 @@ void RunTestSimple(const int group,
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, net.TransformDataFormat<DeviceType::CPU, float>(
"InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest") OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("FilterOIHW") .Input("FilterOIHW")
...@@ -69,8 +70,8 @@ void RunTestSimple(const int group, ...@@ -69,8 +70,8 @@ void RunTestSimple(const int group,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>(expected_shape, expected_data); auto expected = net.CreateTensor<float>(expected_shape, expected_data);
...@@ -193,8 +194,8 @@ void RandomTest(index_t batch, ...@@ -193,8 +194,8 @@ void RandomTest(index_t batch,
{channel * multiplier}, {channel * multiplier},
bias_data, true, false); bias_data, true, false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest") OpDefBuilder("DepthwiseDeconv2d", "DepthwiseDeconv2dTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Filter") .Input("Filter")
...@@ -210,8 +211,8 @@ void RandomTest(index_t batch, ...@@ -210,8 +211,8 @@ void RandomTest(index_t batch,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
......
...@@ -1145,7 +1145,7 @@ class EltwiseOp<DeviceType::GPU, T> : public Operation { ...@@ -1145,7 +1145,7 @@ class EltwiseOp<DeviceType::GPU, T> : public Operation {
int32_t scalar_input_index = Operation::GetOptionalArg<int32_t>( int32_t scalar_input_index = Operation::GetOptionalArg<int32_t>(
"scalar_input_index", 1); "scalar_input_index", 1);
MemoryType mem_type; MemoryType mem_type;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::EltwiseKernel<T>>( kernel_ = make_unique<opencl::image::EltwiseKernel<T>>(
type, coeff, scalar_input, scalar_input_index); type, coeff, scalar_input, scalar_input_index);
......
...@@ -69,7 +69,8 @@ void SimpleTensorScalar(const ops::EltwiseType type, ...@@ -69,7 +69,8 @@ void SimpleTensorScalar(const ops::EltwiseType type,
net.AddInputFromArray<D, T>("Input", shape, input); net.AddInputFromArray<D, T>("Input", shape, input);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, T>("Input", NHWC, "TInput", NCHW); net.TransformDataFormat<D, T>(
"Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("T", DataTypeToEnum<T>::v()) .AddIntArg("T", DataTypeToEnum<T>::v())
...@@ -81,7 +82,8 @@ void SimpleTensorScalar(const ops::EltwiseType type, ...@@ -81,7 +82,8 @@ void SimpleTensorScalar(const ops::EltwiseType type,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, DstType>("TOutput", NCHW, "Output", NHWC); net.TransformDataFormat<D, DstType>(
"TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input") .Input("Input")
...@@ -124,13 +126,15 @@ void SimpleTensorEltwise(const ops::EltwiseType type, ...@@ -124,13 +126,15 @@ void SimpleTensorEltwise(const ops::EltwiseType type,
.OutputType({ops::IsLogicalType(type) ? DT_INT32 : DT_FLOAT}) .OutputType({ops::IsLogicalType(type) ? DT_INT32 : DT_FLOAT})
.Output("TOutput"); .Output("TOutput");
if (shape0.size() > 1) { if (shape0.size() > 1) {
net.TransformDataFormat<D, T>("Input0", NHWC, "TInput0", NCHW); net.TransformDataFormat<D, T>(
"Input0", DataFormat::NHWC, "TInput0", DataFormat::NCHW);
op_builder.Input("TInput0"); op_builder.Input("TInput0");
} else { } else {
op_builder.Input("Input0"); op_builder.Input("Input0");
} }
if (shape1.size() > 1) { if (shape1.size() > 1) {
net.TransformDataFormat<D, T>("Input1", NHWC, "TInput1", NCHW); net.TransformDataFormat<D, T>(
"Input1", DataFormat::NHWC, "TInput1", DataFormat::NCHW);
op_builder.Input("TInput1"); op_builder.Input("TInput1");
} else { } else {
op_builder.Input("Input1"); op_builder.Input("Input1");
...@@ -139,7 +143,8 @@ void SimpleTensorEltwise(const ops::EltwiseType type, ...@@ -139,7 +143,8 @@ void SimpleTensorEltwise(const ops::EltwiseType type,
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, DstType>("TOutput", NCHW, "Output", NHWC); net.TransformDataFormat<D, DstType>(
"TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("Input0") .Input("Input0")
...@@ -560,7 +565,8 @@ void GPUOverflowTest(const ops::EltwiseType type, ...@@ -560,7 +565,8 @@ void GPUOverflowTest(const ops::EltwiseType type,
net.AddInputFromArray<DeviceType::GPU, T>( net.AddInputFromArray<DeviceType::GPU, T>(
"Filter", "Filter",
{output_shape.back(), shape0.back(), 3, 3}, {output_shape.back(), shape0.back(), 3, 3},
std::vector<float>(output_shape.back() * shape0.back() * 9, 1)); std::vector<float>(output_shape.back() * shape0.back() * 9, 1),
true);
OpDefBuilder("Conv2D", "Conv2D") OpDefBuilder("Conv2D", "Conv2D")
.AddIntArg("T", DataTypeToEnum<T>::v()) .AddIntArg("T", DataTypeToEnum<T>::v())
.Input("EltOutput") .Input("EltOutput")
...@@ -636,8 +642,8 @@ void RandomTensorScalar(const ops::EltwiseType type, ...@@ -636,8 +642,8 @@ void RandomTensorScalar(const ops::EltwiseType type,
// Add input data // Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input", shape, false, true, true); net.AddRandomInput<DeviceType::GPU, float>("Input", shape, false, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput") .Input("TInput")
.AddIntArg("type", static_cast<int>(type)) .AddIntArg("type", static_cast<int>(type))
...@@ -647,8 +653,8 @@ void RandomTensorScalar(const ops::EltwiseType type, ...@@ -647,8 +653,8 @@ void RandomTensorScalar(const ops::EltwiseType type,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -690,10 +696,10 @@ void RandomTensorEltwise(const ops::EltwiseType type, ...@@ -690,10 +696,10 @@ void RandomTensorEltwise(const ops::EltwiseType type,
true, true,
true); true);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", NHWC, "TInput0", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input0", DataFormat::NHWC, "TInput0", DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", NHWC, "TInput1", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input1", DataFormat::NHWC, "TInput1", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput0") .Input("TInput0")
.Input("TInput1") .Input("TInput1")
...@@ -705,8 +711,8 @@ void RandomTensorEltwise(const ops::EltwiseType type, ...@@ -705,8 +711,8 @@ void RandomTensorEltwise(const ops::EltwiseType type,
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -746,10 +752,10 @@ void Quantized(const std::vector<index_t> &shape, ...@@ -746,10 +752,10 @@ void Quantized(const std::vector<index_t> &shape,
true, true,
true); true);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", NHWC, "TInput0", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input0", DataFormat::NHWC, "TInput0", DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", NHWC, "TInput1", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input1", DataFormat::NHWC, "TInput1", DataFormat::NCHW);
OpDefBuilder("Eltwise", "EltwiseTest") OpDefBuilder("Eltwise", "EltwiseTest")
.Input("TInput0") .Input("TInput0")
...@@ -761,8 +767,8 @@ void Quantized(const std::vector<index_t> &shape, ...@@ -761,8 +767,8 @@ void Quantized(const std::vector<index_t> &shape,
// Run // Run
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeInput0") OpDefBuilder("Quantize", "QuantizeInput0")
.Input("Input0") .Input("Input0")
......
...@@ -14,7 +14,6 @@ ...@@ -14,7 +14,6 @@
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/ops/common/transpose.h"
#include "mace/utils/math.h" #include "mace/utils/math.h"
namespace mace { namespace mace {
...@@ -44,27 +43,8 @@ class ExpandDimsOp<DeviceType::CPU, T> : public Operation { ...@@ -44,27 +43,8 @@ class ExpandDimsOp<DeviceType::CPU, T> : public Operation {
std::vector<index_t> output_shape(input_shape); std::vector<index_t> output_shape(input_shape);
output_shape.insert(output_shape.begin() + axis_, 1); output_shape.insert(output_shape.begin() + axis_, 1);
bool has_data_format = Operation::GetOptionalArg<int>( output->ReuseTensorBuffer(*input);
"has_data_format", 0) == 1; output->Reshape(output_shape);
if (has_data_format && output_shape.size() == 4) {
// only tensorflow support expand dim, so the default format is NHWC
// transform NHWC to NCHW
auto t_output_shape = TransposeShape<int64_t, int64_t>(output_shape,
{0, 3, 1, 2});
output->Resize(t_output_shape);
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
auto input_data = input->data<T>();
auto output_data = output->mutable_data<T>();
Transpose(&context->device()->cpu_runtime()->thread_pool(),
input_data, output_shape, {0, 3, 1, 2}, output_data);
} else {
output->Resize(output_shape);
Tensor::MappingGuard input_guard(input);
auto input_data = input->data<T>();
output->Copy<T>(input_data, input->size());
}
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
} }
......
...@@ -49,7 +49,8 @@ void Simple() { ...@@ -49,7 +49,8 @@ void Simple() {
net.AddInputFromArray<D, float>("Offset", {1}, offset, true); net.AddInputFromArray<D, float>("Offset", {1}, offset, true);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Scale") .Input("Scale")
...@@ -58,7 +59,8 @@ void Simple() { ...@@ -58,7 +59,8 @@ void Simple() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("Input") .Input("Input")
...@@ -100,8 +102,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { ...@@ -100,8 +102,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -113,8 +115,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) { ...@@ -113,8 +115,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -151,8 +153,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -151,8 +153,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -164,8 +166,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) { ...@@ -164,8 +166,8 @@ TEST_F(FoldedBatchNormOpTest, SimpleRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -205,8 +207,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { ...@@ -205,8 +207,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -218,8 +220,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) { ...@@ -218,8 +220,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -254,11 +256,11 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -254,11 +256,11 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
// Add input data // Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input", net.AddRandomInput<DeviceType::GPU, float>("Input",
{batch, height, width, channels}); {batch, height, width, channels});
net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}); net.AddRandomInput<DeviceType::GPU, float>("Scale", {channels}, true);
net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}); net.AddRandomInput<DeviceType::GPU, float>("Offset", {channels}, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchNorm", "FoldedBatchNormTest") OpDefBuilder("BatchNorm", "FoldedBatchNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -270,8 +272,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) { ...@@ -270,8 +272,8 @@ TEST_F(FoldedBatchNormOpTest, ComplexRandomHalfOPENCL) {
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -190,7 +190,7 @@ class FullyConnectedOp<DeviceType::GPU, T> : public FullyConnectedOpBase { ...@@ -190,7 +190,7 @@ class FullyConnectedOp<DeviceType::GPU, T> : public FullyConnectedOpBase {
explicit FullyConnectedOp(OpConstructContext *context) explicit FullyConnectedOp(OpConstructContext *context)
: FullyConnectedOpBase(context) { : FullyConnectedOpBase(context) {
MemoryType mem_type = MemoryType::CPU_BUFFER; MemoryType mem_type = MemoryType::CPU_BUFFER;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
mem_type = MemoryType::GPU_IMAGE; mem_type = MemoryType::GPU_IMAGE;
kernel_ = make_unique<opencl::image::FullyConnectedKernel<T>>(); kernel_ = make_unique<opencl::image::FullyConnectedKernel<T>>();
} else { } else {
......
...@@ -48,7 +48,8 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -48,7 +48,8 @@ void Simple(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("FullyConnected", "FullyConnectedTest") OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("Input") .Input("Input")
...@@ -129,8 +130,8 @@ void Random(const index_t batch, ...@@ -129,8 +130,8 @@ void Random(const index_t batch,
net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel}, true, net.AddRandomInput<DeviceType::GPU, float>("Bias", {out_channel}, true,
false); false);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("FullyConnected", "FullyConnectedTest") OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputNCHW") .Input("InputNCHW")
.Input("Weight") .Input("Weight")
...@@ -143,7 +144,8 @@ void Random(const index_t batch, ...@@ -143,7 +144,8 @@ void Random(const index_t batch,
// run cpu // run cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
...@@ -215,8 +217,10 @@ void QuantRandom(const index_t batch, ...@@ -215,8 +217,10 @@ void QuantRandom(const index_t batch,
net.AddRandomInput<CPU, float>( net.AddRandomInput<CPU, float>(
"Weight", {out_channel, height, width, channels}, true); "Weight", {out_channel, height, width, channels}, true);
net.AddRandomInput<CPU, float>("Bias", {out_channel}, true); net.AddRandomInput<CPU, float>("Bias", {out_channel}, true);
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<CPU, float>(
net.TransformFilterDataFormat<CPU, float>("Weight", OHWI, "WeightOIHW", OIHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.TransformFilterDataFormat<CPU, float>(
"Weight", DataFormat::OHWI, "WeightOIHW", DataFormat::OIHW);
OpDefBuilder("FullyConnected", "FullyConnectedTest") OpDefBuilder("FullyConnected", "FullyConnectedTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -226,7 +230,8 @@ void QuantRandom(const index_t batch, ...@@ -226,7 +230,8 @@ void QuantRandom(const index_t batch,
.AddIntArg("T", DT_FLOAT) .AddIntArg("T", DT_FLOAT)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(); net.RunOp();
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeWeight") OpDefBuilder("Quantize", "QuantizeWeight")
.Input("Weight") .Input("Weight")
......
...@@ -29,7 +29,8 @@ void Simple() { ...@@ -29,7 +29,8 @@ void Simple() {
{5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15}); {5, 5, 7, 7, 9, 9, 11, 11, 13, 13, 15, 15});
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest") OpDefBuilder("LocalResponseNorm", "LocalResponseNormTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -41,7 +42,8 @@ void Simple() { ...@@ -41,7 +42,8 @@ void Simple() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
......
...@@ -36,7 +36,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation { ...@@ -36,7 +36,7 @@ class LSTMCellOp<DeviceType::GPU, T> : public Operation {
Operation::GetOptionalArg<float>("scalar_input", Operation::GetOptionalArg<float>("scalar_input",
0.0)); 0.0));
MemoryType mem_type = MemoryType::GPU_IMAGE; MemoryType mem_type = MemoryType::GPU_IMAGE;
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::LSTMCellKernel<T>>(forget_bias); kernel_ = make_unique<opencl::image::LSTMCellKernel<T>>(forget_bias);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -518,14 +518,6 @@ void RegisterMatMul(OpRegistryBase *op_registry) { ...@@ -518,14 +518,6 @@ void RegisterMatMul(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp, MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp,
DeviceType::CPU, uint8_t); DeviceType::CPU, uint8_t);
#endif // MACE_ENABLE_QUANTIZE #endif // MACE_ENABLE_QUANTIZE
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp,
DeviceType::GPU, float);
MACE_REGISTER_OP(op_registry, "MatMul", MatMulOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
} }
} // namespace ops } // namespace ops
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include "mace/ops/opencl/image/buffer_to_image.h" #include "mace/ops/opencl/image/buffer_to_image.h"
#include "mace/ops/opencl/image/image_to_buffer.h" #include "mace/ops/opencl/image/image_to_buffer.h"
#include "mace/ops/opencl/buffer/buffer_transform.h" #include "mace/ops/opencl/buffer/buffer_transform.h"
#include "mace/ops/common/transpose.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
namespace mace { namespace mace {
...@@ -48,7 +47,6 @@ class OpenCLBufferTransformer { ...@@ -48,7 +47,6 @@ class OpenCLBufferTransformer {
const OpenCLBufferType type, const OpenCLBufferType type,
const MemoryType out_mem_type, const MemoryType out_mem_type,
const int wino_blk_size, const int wino_blk_size,
bool has_data_format,
Tensor *output) { Tensor *output) {
Workspace *ws = context->workspace(); Workspace *ws = context->workspace();
DataType dt = DataTypeToEnum<T>::value; DataType dt = DataTypeToEnum<T>::value;
...@@ -67,31 +65,11 @@ class OpenCLBufferTransformer { ...@@ -67,31 +65,11 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform CPU Buffer " << input->name() VLOG(2) << "Transform CPU Buffer " << input->name()
<< " to GPU Buffer " << internal_tensor->name() << " to GPU Buffer " << internal_tensor->name()
<< " with data type " << dt; << " with data type " << dt;
if (has_data_format && input->shape().size() == 4) { internal_tensor->Resize(input->shape());
// 1. (NCHW -> NHWC) const uint8_t *input_ptr = input->data<uint8_t>();
std::vector<int> dst_dims = {0, 2, 3, 1}; Tensor::MappingGuard guard(internal_tensor);
std::vector<index_t> output_shape = uint8_t *internal_ptr = internal_tensor->mutable_data<uint8_t>();
TransposeShape<index_t, index_t>(input->shape(), memcpy(internal_ptr, input_ptr, input->raw_size());
dst_dims);
internal_tensor->Resize(output_shape);
internal_tensor->set_data_format(DataFormat::NHWC);
// TODO(liuqi): Only support float now
const float *input_ptr = input->data<float>();
Tensor::MappingGuard guard(internal_tensor);
float *internal_ptr = internal_tensor->mutable_data<float>();
MACE_RETURN_IF_ERROR(ops::Transpose(
&context->device()->cpu_runtime()->thread_pool(),
input_ptr,
input->shape(),
dst_dims,
internal_ptr));
} else {
internal_tensor->Resize(input->shape());
const uint8_t *input_ptr = input->data<uint8_t>();
Tensor::MappingGuard guard(internal_tensor);
uint8_t *internal_ptr = internal_tensor->mutable_data<uint8_t>();
memcpy(internal_ptr, input_ptr, input->raw_size());
}
// 2. convert the internal GPU Buffer to output. // 2. convert the internal GPU Buffer to output.
return kernel_->Compute( return kernel_->Compute(
context, internal_tensor, type, wino_blk_size, output); context, internal_tensor, type, wino_blk_size, output);
...@@ -108,30 +86,12 @@ class OpenCLBufferTransformer { ...@@ -108,30 +86,12 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform GPU Buffer " << internal_tensor.name() VLOG(2) << "Transform GPU Buffer " << internal_tensor.name()
<< " to CPU Buffer " << output->name() << " to CPU Buffer " << output->name()
<< " with data type " << dt; << " with data type " << dt;
if (has_data_format && internal_tensor.shape().size() == 4) { Tensor::MappingGuard guard(&internal_tensor);
// NHWC -> NCHW const T *internal_ptr = internal_tensor.data<T>();
std::vector<int> dst_dims = {0, 3, 1, 2}; output->Resize(internal_tensor.shape());
std::vector<index_t> output_shape = T *output_ptr = output->mutable_data<T>();
TransposeShape<index_t, index_t>(internal_tensor.shape(), memcpy(output_ptr, internal_ptr, internal_tensor.size() * sizeof(T));
dst_dims); return MaceStatus::MACE_SUCCESS;
output->set_data_format(DataFormat::NCHW);
Tensor::MappingGuard guard(&internal_tensor);
const float *internal_ptr = internal_tensor.data<float>();
output->Resize(output_shape);
float *output_ptr = output->mutable_data<float>();
return ops::Transpose(&context->device()->cpu_runtime()->thread_pool(),
internal_ptr,
internal_tensor.shape(),
dst_dims,
output_ptr);
} else {
Tensor::MappingGuard guard(&internal_tensor);
const T *internal_ptr = internal_tensor.data<T>();
output->Resize(internal_tensor.shape());
T *output_ptr = output->mutable_data<T>();
memcpy(output_ptr, internal_ptr, internal_tensor.size() * sizeof(T));
return MaceStatus::MACE_SUCCESS;
}
} else { } else {
LOG(FATAL) << "Unexpected error: " << out_mem_type; LOG(FATAL) << "Unexpected error: " << out_mem_type;
return MaceStatus::MACE_SUCCESS; return MaceStatus::MACE_SUCCESS;
...@@ -172,7 +132,7 @@ MaceStatus TransformFilter( ...@@ -172,7 +132,7 @@ MaceStatus TransformFilter(
input->MarkUnused(); input->MarkUnused();
return OpenCLBufferTransformer<T>(input->memory_type(), mem_type). return OpenCLBufferTransformer<T>(input->memory_type(), mem_type).
Transform(&op_context, input, buffer_type, mem_type, wino_blk_size, Transform(&op_context, input, buffer_type, mem_type, wino_blk_size,
DataFormat::DF_NONE, output); output);
} }
} // namespace ops } // namespace ops
......
...@@ -71,14 +71,17 @@ MaceStatus EltwiseKernel<T>::Compute( ...@@ -71,14 +71,17 @@ MaceStatus EltwiseKernel<T>::Compute(
if (input1 == nullptr) { if (input1 == nullptr) {
input1_type = "INPUT_SCALAR"; input1_type = "INPUT_SCALAR";
} else { } else {
MACE_CHECK(input0->dim_size() == input1->dim_size() || MACE_CHECK((input0->dim_size() == input1->dim_size()
&& input0->dim_size() == 4) ||
input0->dim_size() == 1 || input1->dim_size() == 1) input0->dim_size() == 1 || input1->dim_size() == 1)
<< "Inputs of Eltwise op must be same shape"; << "Inputs of Eltwise op must be same shape or fulfill broadcast logic";
MACE_CHECK(type_ != EltwiseType::EQUAL) MACE_CHECK(type_ != EltwiseType::EQUAL)
<< "Eltwise op on GPU does not support EQUAL"; << "Eltwise op on GPU does not support EQUAL";
// broadcast // broadcast
if (input0->size() != input1->size()) { if (input0->size() != input1->size() ||
if (input0->size() < input1->size()) { input0->dim_size() != input1->dim_size()) {
if (input0->size() < input1->size()
|| input0->dim_size() < input1->dim_size()) {
std::swap(input0, input1); std::swap(input0, input1);
swapped = true; swapped = true;
} }
......
...@@ -59,11 +59,6 @@ MaceStatus ReduceKernel<T>::Compute( ...@@ -59,11 +59,6 @@ MaceStatus ReduceKernel<T>::Compute(
const Tensor *input, const Tensor *input,
Tensor *output) { Tensor *output) {
MACE_CHECK_NOTNULL(input); MACE_CHECK_NOTNULL(input);
MACE_CHECK(keep_dims_, "reduce mean gpu only support keep dims.");
MACE_CHECK(input->dim_size() == 4,
"reduce gpu only support 4-dim input");
MACE_CHECK(axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2,
"reduce gpu only support 1,2-axis reduce");
index_t batch = input->dim(0); index_t batch = input->dim(0);
const index_t in_height = input->dim(1); const index_t in_height = input->dim(1);
const index_t in_width = input->dim(2); const index_t in_width = input->dim(2);
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/core/memory_optimizer.h" #include "mace/core/memory_optimizer.h"
#include "mace/utils/memory.h" #include "mace/utils/memory.h"
#include "mace/core/net_def_adapter.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
...@@ -175,26 +176,27 @@ void OpTestContext::SetOCLImageAndBufferTestFlag() { ...@@ -175,26 +176,27 @@ void OpTestContext::SetOCLImageAndBufferTestFlag() {
bool OpsTestNet::Setup(mace::DeviceType device) { bool OpsTestNet::Setup(mace::DeviceType device) {
NetDef net_def; NetDef net_def;
for (auto &op_def : op_defs_) { for (auto &op_def : op_defs_) {
net_def.add_op()->CopyFrom(op_def); auto target_op = net_def.add_op();
target_op->CopyFrom(op_def);
auto has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "has_data_format", 0);
auto is_quantized_op = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DT_FLOAT))
== static_cast<int>(DT_UINT8);
for (auto input : op_def.input()) { for (auto input : op_def.input()) {
if (ws_.GetTensor(input) != nullptr && if (ws_.GetTensor(input) != nullptr &&
!ws_.GetTensor(input)->is_weight()) { !ws_.GetTensor(input)->is_weight()) {
auto input_info = net_def.add_input_info(); auto input_info = net_def.add_input_info();
input_info->set_name(input); input_info->set_name(input);
auto has_data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "has_data_format", 1);
auto is_quantized_op = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def, "T", static_cast<int>(DT_FLOAT))
== static_cast<int>(DT_UINT8);
if (has_data_format) { if (has_data_format) {
if (is_quantized_op || device == DeviceType::GPU) { if (is_quantized_op || device == DeviceType::GPU) {
input_info->set_data_format(NHWC); input_info->set_data_format(static_cast<int>(DataFormat::NHWC));
} else { } else {
input_info->set_data_format(NCHW); input_info->set_data_format(static_cast<int>(DataFormat::NCHW));
} }
} else { } else {
input_info->set_data_format(DataFormat::DF_NONE); input_info->set_data_format(static_cast<int>(DataFormat::NONE));
} }
auto &shape = ws_.GetTensor(input)->shape(); auto &shape = ws_.GetTensor(input)->shape();
for (auto d : shape) { for (auto d : shape) {
...@@ -202,6 +204,10 @@ bool OpsTestNet::Setup(mace::DeviceType device) { ...@@ -202,6 +204,10 @@ bool OpsTestNet::Setup(mace::DeviceType device) {
} }
} }
} }
if (has_data_format) {
SetProtoArg<int>(target_op, "data_format",
static_cast<int>(DataFormat::AUTO));
}
} }
if (!op_defs_.empty()) { if (!op_defs_.empty()) {
auto op_def = op_defs_.back(); auto op_def = op_defs_.back();
...@@ -216,15 +222,21 @@ bool OpsTestNet::Setup(mace::DeviceType device) { ...@@ -216,15 +222,21 @@ bool OpsTestNet::Setup(mace::DeviceType device) {
} }
} }
} }
NetDef adapted_net_def;
NetDefAdapter net_def_adapter(op_registry_.get(), &ws_);
net_def_adapter.AdaptNetDef(&net_def,
OpTestContext::Get()->GetDevice(device),
&adapted_net_def);
MemoryOptimizer mem_optimizer; MemoryOptimizer mem_optimizer;
net_ = make_unique<SerialNet>( net_ = make_unique<SerialNet>(
op_registry_.get(), op_registry_.get(),
&net_def, &adapted_net_def,
&ws_, &ws_,
OpTestContext::Get()->GetDevice(device), OpTestContext::Get()->GetDevice(device),
&mem_optimizer); &mem_optimizer);
MaceStatus status = (ws_.PreallocateOutputTensor( MaceStatus status = (ws_.PreallocateOutputTensor(
net_def, adapted_net_def,
&mem_optimizer, &mem_optimizer,
OpTestContext::Get()->GetDevice(device))); OpTestContext::Get()->GetDevice(device)));
if (status != MaceStatus::MACE_SUCCESS) return false; if (status != MaceStatus::MACE_SUCCESS) return false;
...@@ -267,15 +279,20 @@ MaceStatus OpsTestNet::RunOp() { ...@@ -267,15 +279,20 @@ MaceStatus OpsTestNet::RunOp() {
MaceStatus OpsTestNet::RunNet(const mace::NetDef &net_def, MaceStatus OpsTestNet::RunNet(const mace::NetDef &net_def,
const mace::DeviceType device) { const mace::DeviceType device) {
device_type_ = device; device_type_ = device;
NetDef adapted_net_def;
NetDefAdapter net_def_adapter(op_registry_.get(), &ws_);
net_def_adapter.AdaptNetDef(&net_def,
OpTestContext::Get()->GetDevice(device),
&adapted_net_def);
MemoryOptimizer mem_optimizer; MemoryOptimizer mem_optimizer;
net_ = make_unique<SerialNet>( net_ = make_unique<SerialNet>(
op_registry_.get(), op_registry_.get(),
&net_def, &adapted_net_def,
&ws_, &ws_,
OpTestContext::Get()->GetDevice(device), OpTestContext::Get()->GetDevice(device),
&mem_optimizer); &mem_optimizer);
MACE_RETURN_IF_ERROR(ws_.PreallocateOutputTensor( MACE_RETURN_IF_ERROR(ws_.PreallocateOutputTensor(
net_def, adapted_net_def,
&mem_optimizer, &mem_optimizer,
OpTestContext::Get()->GetDevice(device))); OpTestContext::Get()->GetDevice(device)));
MACE_RETURN_IF_ERROR(net_->Init()); MACE_RETURN_IF_ERROR(net_->Init());
......
...@@ -223,7 +223,7 @@ class OpsTestNet { ...@@ -223,7 +223,7 @@ class OpsTestNet {
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
MACE_CHECK(input_shape.size() == 4, "input shape != 4"); MACE_CHECK(input_shape.size() == 4, "input shape != 4");
if (src_format == NHWC && dst_format == NCHW) { if (src_format == DataFormat::NHWC && dst_format == DataFormat::NCHW) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
index_t height = input_shape[1]; index_t height = input_shape[1];
index_t width = input_shape[2]; index_t width = input_shape[2];
...@@ -243,7 +243,8 @@ class OpsTestNet { ...@@ -243,7 +243,8 @@ class OpsTestNet {
} }
} }
} }
} else if (src_format == NCHW && dst_format == NHWC) { } else if (src_format == DataFormat::NCHW &&
dst_format == DataFormat::NHWC) {
index_t batch = input_shape[0]; index_t batch = input_shape[0];
index_t channels = input_shape[1]; index_t channels = input_shape[1];
index_t height = input_shape[2]; index_t height = input_shape[2];
...@@ -281,7 +282,7 @@ class OpsTestNet { ...@@ -281,7 +282,7 @@ class OpsTestNet {
input->is_weight()); input->is_weight());
const std::vector<index_t> input_shape = input->shape(); const std::vector<index_t> input_shape = input->shape();
MACE_CHECK(input_shape.size() == 4, "input shape != 4"); MACE_CHECK(input_shape.size() == 4, "input shape != 4");
if (src_format == HWOI && dst_format == OIHW) { if (src_format == DataFormat::HWOI && dst_format == DataFormat::OIHW) {
index_t height = input_shape[0]; index_t height = input_shape[0];
index_t width = input_shape[1]; index_t width = input_shape[1];
index_t out_channels = input_shape[2]; index_t out_channels = input_shape[2];
...@@ -299,7 +300,8 @@ class OpsTestNet { ...@@ -299,7 +300,8 @@ class OpsTestNet {
input_data[j * out_channels * in_channels + i]; input_data[j * out_channels * in_channels + i];
} }
} }
} else if (src_format == OIHW && dst_format == HWOI) { } else if (src_format == DataFormat::OIHW &&
dst_format == DataFormat::HWOI) {
index_t out_channels = input_shape[0]; index_t out_channels = input_shape[0];
index_t in_channels = input_shape[1]; index_t in_channels = input_shape[1];
index_t height = input_shape[2]; index_t height = input_shape[2];
...@@ -317,7 +319,8 @@ class OpsTestNet { ...@@ -317,7 +319,8 @@ class OpsTestNet {
input_data[j * height * width + i]; input_data[j * height * width + i];
} }
} }
} else if (src_format == HWIO && dst_format == OIHW) { } else if (src_format == DataFormat::HWIO &&
dst_format == DataFormat::OIHW) {
index_t height = input_shape[0]; index_t height = input_shape[0];
index_t width = input_shape[1]; index_t width = input_shape[1];
index_t in_channels = input_shape[2]; index_t in_channels = input_shape[2];
...@@ -337,7 +340,8 @@ class OpsTestNet { ...@@ -337,7 +340,8 @@ class OpsTestNet {
} }
} }
} }
} else if (src_format == OHWI && dst_format == OIHW) { } else if (src_format == DataFormat::OHWI &&
dst_format == DataFormat::OIHW) {
index_t out_channels = input_shape[0]; index_t out_channels = input_shape[0];
index_t height = input_shape[1]; index_t height = input_shape[1];
index_t width = input_shape[2]; index_t width = input_shape[2];
......
...@@ -179,7 +179,7 @@ class PadOp<DeviceType::GPU, T> : public Operation { ...@@ -179,7 +179,7 @@ class PadOp<DeviceType::GPU, T> : public Operation {
std::vector<int> paddings = Operation::GetRepeatedArgs<int>("paddings"); std::vector<int> paddings = Operation::GetRepeatedArgs<int>("paddings");
float constant_value = Operation::GetOptionalArg<float>( float constant_value = Operation::GetOptionalArg<float>(
"constant_value", 0.0); "constant_value", 0.0);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::PadKernel<T>>( kernel_ = make_unique<opencl::image::PadKernel<T>>(
type, paddings, constant_value); type, paddings, constant_value);
} else { } else {
......
...@@ -45,8 +45,8 @@ void SimpleConstant() { ...@@ -45,8 +45,8 @@ void SimpleConstant() {
// Run // Run
net.RunOp(D); net.RunOp(D);
} else { } else {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
.Input("TInput") .Input("TInput")
.Output("TOutput") .Output("TOutput")
...@@ -58,8 +58,8 @@ void SimpleConstant() { ...@@ -58,8 +58,8 @@ void SimpleConstant() {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto output = net.GetTensor("Output"); auto output = net.GetTensor("Output");
...@@ -93,7 +93,8 @@ void Result(const std::vector<index_t> &input_shape, ...@@ -93,7 +93,8 @@ void Result(const std::vector<index_t> &input_shape,
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
t_input = "TInput"; t_input = "TInput";
t_output = "TOutput"; t_output = "TOutput";
net.TransformDataFormat<DeviceType::CPU, T>(input, NHWC, t_input, NCHW); net.TransformDataFormat<DeviceType::CPU, T>(
input, DataFormat::NHWC, t_input, DataFormat::NCHW);
} }
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
...@@ -108,7 +109,8 @@ void Result(const std::vector<index_t> &input_shape, ...@@ -108,7 +109,8 @@ void Result(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, T>(t_output, NCHW, output, NHWC); net.TransformDataFormat<DeviceType::CPU, T>(
t_output, DataFormat::NCHW, output, DataFormat::NHWC);
} }
auto actual = net.GetTensor(output.c_str()); auto actual = net.GetTensor(output.c_str());
...@@ -172,8 +174,8 @@ TEST_F(PadTest, ComplexCPU) { ...@@ -172,8 +174,8 @@ TEST_F(PadTest, ComplexCPU) {
// Add input data // Add input data
net.AddRepeatedInput<DeviceType::CPU, float>("Input", {1, 1, 1, 2}, 2); net.AddRepeatedInput<DeviceType::CPU, float>("Input", {1, 1, 1, 2}, 2);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
.Input("TInput") .Input("TInput")
.Output("TOutput") .Output("TOutput")
...@@ -184,8 +186,8 @@ TEST_F(PadTest, ComplexCPU) { ...@@ -184,8 +186,8 @@ TEST_F(PadTest, ComplexCPU) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto output = net.GetTensor("Output"); auto output = net.GetTensor("Output");
...@@ -209,8 +211,8 @@ void Complex(const std::vector<index_t> &input_shape, ...@@ -209,8 +211,8 @@ void Complex(const std::vector<index_t> &input_shape,
// Add input data // Add input data
net.AddRandomInput<DeviceType::GPU, float>("Input", input_shape); net.AddRandomInput<DeviceType::GPU, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "TInput", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "TInput", DataFormat::NCHW);
OpDefBuilder("Pad", "PadTest") OpDefBuilder("Pad", "PadTest")
.Input("TInput") .Input("TInput")
.Output("TOutput") .Output("TOutput")
...@@ -222,8 +224,8 @@ void Complex(const std::vector<index_t> &input_shape, ...@@ -222,8 +224,8 @@ void Complex(const std::vector<index_t> &input_shape,
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("TOutput", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "TOutput", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
......
...@@ -270,9 +270,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -270,9 +270,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
std::vector<int> paddings(2); std::vector<int> paddings(2);
if (paddings_.empty()) { if (paddings_.empty()) {
CalcPaddingAndOutputSize(input_tensor->shape().data(), CalcPaddingAndOutputSize(input_tensor->shape().data(),
NHWC, DataFormat::NHWC,
filter_shape.data(), filter_shape.data(),
OHWI, DataFormat::OHWI,
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
padding_type_, padding_type_,
...@@ -281,9 +281,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase { ...@@ -281,9 +281,9 @@ class PoolingOp<DeviceType::CPU, uint8_t> : public PoolingOpBase {
} else { } else {
paddings = paddings_; paddings = paddings_;
CalcOutputSize(input_tensor->shape().data(), CalcOutputSize(input_tensor->shape().data(),
NHWC, DataFormat::NHWC,
filter_shape.data(), filter_shape.data(),
OHWI, DataFormat::OHWI,
paddings_.data(), paddings_.data(),
dilations_.data(), dilations_.data(),
strides_.data(), strides_.data(),
...@@ -477,10 +477,9 @@ class PoolingOp<DeviceType::GPU, T> : public PoolingOpBase { ...@@ -477,10 +477,9 @@ class PoolingOp<DeviceType::GPU, T> : public PoolingOpBase {
public: public:
explicit PoolingOp(OpConstructContext *context) explicit PoolingOp(OpConstructContext *context)
: PoolingOpBase(context) { : PoolingOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::PoolingKernel<T>>(); kernel_ = make_unique<opencl::image::PoolingKernel<T>>();
} else { } else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
kernel_ = make_unique<opencl::buffer::PoolingKernel<T>>(); kernel_ = make_unique<opencl::buffer::PoolingKernel<T>>();
} }
} }
......
...@@ -34,8 +34,8 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -34,8 +34,8 @@ TEST_F(PoolingOpTest, MAX_VALID) {
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}); 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -50,8 +50,8 @@ TEST_F(PoolingOpTest, MAX_VALID) { ...@@ -50,8 +50,8 @@ TEST_F(PoolingOpTest, MAX_VALID) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = auto expected =
...@@ -68,8 +68,8 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -68,8 +68,8 @@ TEST_F(PoolingOpTest, MAX_SAME) {
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1}, net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 3, 3, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8}); {0, 1, 2, 3, 4, 5, 6, 7, 8});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -84,8 +84,8 @@ TEST_F(PoolingOpTest, MAX_SAME) { ...@@ -84,8 +84,8 @@ TEST_F(PoolingOpTest, MAX_SAME) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8}); auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {4, 5, 7, 8});
...@@ -102,8 +102,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -102,8 +102,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
"Input", {1, 4, 4, 1}, "Input", {1, 4, 4, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -118,8 +118,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) { ...@@ -118,8 +118,8 @@ TEST_F(PoolingOpTest, MAX_VALID_DILATION) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {10, 11, 14, 15}); auto expected = net.CreateTensor<float>({1, 2, 2, 1}, {10, 11, 14, 15});
...@@ -136,8 +136,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -136,8 +136,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
"Input", {1, 2, 9, 1}, "Input", {1, 2, 9, 1},
{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17}); {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -152,8 +152,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) { ...@@ -152,8 +152,8 @@ TEST_F(PoolingOpTest, MAX_k2x2s2x2) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 5, 1}, {10, 12, 14, 16, 17}); auto expected = net.CreateTensor<float>({1, 1, 5, 1}, {10, 12, 14, 16, 17});
...@@ -174,8 +174,8 @@ void SimpleMaxPooling3S2() { ...@@ -174,8 +174,8 @@ void SimpleMaxPooling3S2() {
14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26}); 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26});
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
// Run // Run
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -187,8 +187,8 @@ void SimpleMaxPooling3S2() { ...@@ -187,8 +187,8 @@ void SimpleMaxPooling3S2() {
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else if (D == DeviceType::GPU) { } else if (D == DeviceType::GPU) {
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("Input") .Input("Input")
...@@ -224,8 +224,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape, ...@@ -224,8 +224,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", input_shape); net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -240,8 +240,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape, ...@@ -240,8 +240,8 @@ void MaxPooling3S2(const std::vector<index_t> &input_shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -304,8 +304,8 @@ TEST_F(PoolingOpTest, AVG_VALID) { ...@@ -304,8 +304,8 @@ TEST_F(PoolingOpTest, AVG_VALID) {
{0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23, {0, 16, 1, 17, 2, 18, 3, 19, 4, 20, 5, 21, 6, 22, 7, 23,
8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31}); 8, 24, 9, 25, 10, 26, 11, 27, 12, 28, 13, 29, 14, 30, 15, 31});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -320,8 +320,8 @@ TEST_F(PoolingOpTest, AVG_VALID) { ...@@ -320,8 +320,8 @@ TEST_F(PoolingOpTest, AVG_VALID) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>( auto expected = net.CreateTensor<float>(
...@@ -373,8 +373,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape, ...@@ -373,8 +373,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", shape); net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Pooling", "PoolingTest") OpDefBuilder("Pooling", "PoolingTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -389,8 +389,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape, ...@@ -389,8 +389,8 @@ void AvgPoolingTest(const std::vector<index_t> &shape,
// run on cpu // run on cpu
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -563,7 +563,7 @@ void TestQuant(const index_t batch, ...@@ -563,7 +563,7 @@ void TestQuant(const index_t batch,
net.AddRandomInput<CPU, float>( net.AddRandomInput<CPU, float>(
"Input", input_shape, false, false); "Input", input_shape, false, false);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddRandomInput<DeviceType::CPU, float>( net.AddRandomInput<DeviceType::CPU, float>(
"OutputNCHW", input_shape, false, true, true); "OutputNCHW", input_shape, false, true, true);
...@@ -580,7 +580,7 @@ void TestQuant(const index_t batch, ...@@ -580,7 +580,7 @@ void TestQuant(const index_t batch,
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
.Input("Input") .Input("Input")
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <set>
#include <vector> #include <vector>
#include "mace/core/future.h" #include "mace/core/future.h"
...@@ -872,7 +873,7 @@ class ReduceOp<DeviceType::GPU, T> : public ReduceOpBase { ...@@ -872,7 +873,7 @@ class ReduceOp<DeviceType::GPU, T> : public ReduceOpBase {
public: public:
explicit ReduceOp(OpConstructContext *context) explicit ReduceOp(OpConstructContext *context)
: ReduceOpBase(context) { : ReduceOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ReduceKernel<T>>(reduce_type_, kernel_ = make_unique<opencl::image::ReduceKernel<T>>(reduce_type_,
axis_, axis_,
keep_dims_); keep_dims_);
...@@ -907,6 +908,34 @@ void RegisterReduce(OpRegistryBase *op_registry) { ...@@ -907,6 +908,34 @@ void RegisterReduce(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp, MACE_REGISTER_OP(op_registry, "Reduce", ReduceOp,
DeviceType::GPU, half); DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Reduce")
.SetDevicePlacerFunc(
[](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
bool keep_dims =
ProtoArgHelper::GetOptionalArg<OperatorDef, bool>(
*op, "keepdims", false);
if (!keep_dims) {
return { DeviceType::CPU };
}
auto axis =
ProtoArgHelper::GetRepeatedArgs<OperatorDef, int>(
*op, "axis");
if (axis.size() != 2 || axis[0] != 1 || axis[1] != 2) {
return { DeviceType::CPU };
}
auto tensor_shape_info = context->tensor_shape_info();
if (tensor_shape_info->count(op->input(0)) == 0
|| tensor_shape_info->at(op->input(0)).size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
} }
} // namespace ops } // namespace ops
......
...@@ -38,7 +38,8 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -38,7 +38,8 @@ void Simple(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input); net.AddInputFromArray<D, float>("Input", input_shape, input);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<D, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<D, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
...@@ -49,7 +50,8 @@ void Simple(const std::vector<index_t> &input_shape, ...@@ -49,7 +50,8 @@ void Simple(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<D, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<D, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("Input") .Input("Input")
...@@ -289,8 +291,8 @@ void RandomTest(const std::vector<index_t> &input_shape, ...@@ -289,8 +291,8 @@ void RandomTest(const std::vector<index_t> &input_shape,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", input_shape); net.AddRandomInput<D, float>("Input", input_shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
...@@ -301,8 +303,8 @@ void RandomTest(const std::vector<index_t> &input_shape, ...@@ -301,8 +303,8 @@ void RandomTest(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Reduce", "ReduceTest") OpDefBuilder("Reduce", "ReduceTest")
.Input("Input") .Input("Input")
.AddIntsArg("axis", axis) .AddIntsArg("axis", axis)
...@@ -353,7 +355,7 @@ void TestQuant(const std::vector<index_t> &input_shape, ...@@ -353,7 +355,7 @@ void TestQuant(const std::vector<index_t> &input_shape,
net.AddRandomInput<CPU, float>( net.AddRandomInput<CPU, float>(
"Input", input_shape, false, false); "Input", input_shape, false, false);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"Input", NHWC, "InputNCHW", NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddRandomInput<DeviceType::CPU, float>( net.AddRandomInput<DeviceType::CPU, float>(
"OutputNCHW", input_shape, false, true, true); "OutputNCHW", input_shape, false, true, true);
...@@ -368,7 +370,7 @@ void TestQuant(const std::vector<index_t> &input_shape, ...@@ -368,7 +370,7 @@ void TestQuant(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>( net.TransformDataFormat<DeviceType::CPU, float>(
"OutputNCHW", NCHW, "Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
.Input("Input") .Input("Input")
......
...@@ -51,7 +51,7 @@ MaceStatus Deconv2d<float>::Compute(const OpContext *context, ...@@ -51,7 +51,7 @@ MaceStatus Deconv2d<float>::Compute(const OpContext *context,
&out_pad_size, &out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......
...@@ -50,7 +50,7 @@ MaceStatus DepthwiseDeconv2d<float>::Compute(const OpContext *context, ...@@ -50,7 +50,7 @@ MaceStatus DepthwiseDeconv2d<float>::Compute(const OpContext *context,
&out_pad_size, &out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
...@@ -185,7 +185,7 @@ MaceStatus GroupDeconv2d<float>::Compute(const OpContext *context, ...@@ -185,7 +185,7 @@ MaceStatus GroupDeconv2d<float>::Compute(const OpContext *context,
&out_pad_size, &out_pad_size,
&padded_out_shape, &padded_out_shape,
framework_type_, framework_type_,
NCHW); DataFormat::NCHW);
MACE_RETURN_IF_ERROR(output->Resize(out_shape)); MACE_RETURN_IF_ERROR(output->Resize(out_shape));
......
...@@ -212,7 +212,7 @@ class ResizeBicubicOp<DeviceType::GPU, T> : public Operation { ...@@ -212,7 +212,7 @@ class ResizeBicubicOp<DeviceType::GPU, T> : public Operation {
std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>( std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>(
"size", {-1, -1}); "size", {-1, -1});
MACE_CHECK(size.size() == 2); MACE_CHECK(size.size() == 2);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ResizeBicubicKernel<T>>( kernel_ = make_unique<opencl::image::ResizeBicubicKernel<T>>(
align_corners, size[0], size[1]); align_corners, size[0], size[1]);
} else { } else {
......
...@@ -31,8 +31,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) { ...@@ -31,8 +31,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -42,8 +42,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) { ...@@ -42,8 +42,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
...@@ -60,8 +60,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) { ...@@ -60,8 +60,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
std::vector<float> input(48); std::vector<float> input(48);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 4, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 4, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -71,8 +71,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) { ...@@ -71,8 +71,8 @@ TEST_F(ResizeBicubicTest, CPUResizeBicubicWOAlignCornersFloat) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 2, 3, 3}, auto expected = net.CreateTensor<float>({1, 2, 3, 3},
...@@ -92,8 +92,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { ...@@ -92,8 +92,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -104,8 +104,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) { ...@@ -104,8 +104,8 @@ TEST_F(ResizeBicubicTest, ResizeBicubicWAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
...@@ -133,8 +133,8 @@ void TestRandomResizeBicubic() { ...@@ -133,8 +133,8 @@ void TestRandomResizeBicubic() {
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
{batch, in_height, in_width, channels}, {batch, in_height, in_width, channels},
false, true, true); false, true, true);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBicubic", "ResizeBicubicTest") OpDefBuilder("ResizeBicubic", "ResizeBicubicTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -144,8 +144,8 @@ void TestRandomResizeBicubic() { ...@@ -144,8 +144,8 @@ void TestRandomResizeBicubic() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
Tensor expected; Tensor expected;
expected.Copy(*net.GetOutput("Output")); expected.Copy(*net.GetOutput("Output"));
......
...@@ -346,7 +346,7 @@ class ResizeBilinearOp<DeviceType::GPU, T> : public Operation { ...@@ -346,7 +346,7 @@ class ResizeBilinearOp<DeviceType::GPU, T> : public Operation {
std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>( std::vector<index_t> size = Operation::GetRepeatedArgs<index_t>(
"size", {-1, -1}); "size", {-1, -1});
MACE_CHECK(size.size() == 2); MACE_CHECK(size.size() == 2);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ResizeBilinearKernel<T>>( kernel_ = make_unique<opencl::image::ResizeBilinearKernel<T>>(
align_corners, size[0], size[1]); align_corners, size[0], size[1]);
} else { } else {
......
...@@ -31,8 +31,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { ...@@ -31,8 +31,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -42,8 +42,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) { ...@@ -42,8 +42,8 @@ TEST_F(ResizeBilinearTest, CPUResizeBilinearWOAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
...@@ -60,8 +60,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { ...@@ -60,8 +60,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
std::vector<float> input(24); std::vector<float> input(24);
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -72,8 +72,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) { ...@@ -72,8 +72,8 @@ TEST_F(ResizeBilinearTest, ResizeBilinearWAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
...@@ -100,8 +100,8 @@ void TestRandomResizeBilinear() { ...@@ -100,8 +100,8 @@ void TestRandomResizeBilinear() {
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
{batch, in_height, in_width, channels}); {batch, in_height, in_width, channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -111,8 +111,8 @@ void TestRandomResizeBilinear() { ...@@ -111,8 +111,8 @@ void TestRandomResizeBilinear() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
...@@ -155,8 +155,8 @@ void TestQuantizedResizeBilinear() { ...@@ -155,8 +155,8 @@ void TestQuantizedResizeBilinear() {
true, true,
-1.f, -1.f,
1.f); 1.f);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("ResizeBilinear", "ResizeBilinearTest") OpDefBuilder("ResizeBilinear", "ResizeBilinearTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -166,8 +166,8 @@ void TestQuantizedResizeBilinear() { ...@@ -166,8 +166,8 @@ void TestQuantizedResizeBilinear() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// run quantize // run quantize
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
......
...@@ -149,7 +149,7 @@ class ResizeNearestNeighborOp<DeviceType::GPU, T> : public Operation { ...@@ -149,7 +149,7 @@ class ResizeNearestNeighborOp<DeviceType::GPU, T> : public Operation {
: Operation(context) { : Operation(context) {
bool align_corners = Operation::GetOptionalArg<bool>( bool align_corners = Operation::GetOptionalArg<bool>(
"align_corners", false); "align_corners", false);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::ResizeNearestNeighborKernel<T>>( kernel_ = make_unique<opencl::image::ResizeNearestNeighborKernel<T>>(
align_corners); align_corners);
} else { } else {
......
...@@ -32,8 +32,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) { ...@@ -32,8 +32,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) {
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
std::vector<int32_t> size = {1, 2}; std::vector<int32_t> size = {1, 2};
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size); net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest") OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
...@@ -45,8 +45,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) { ...@@ -45,8 +45,8 @@ TEST_F(ResizeNearestNeighborTest, CPUResizeNearestNeighborWOAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 6, 7, 8});
...@@ -64,8 +64,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) { ...@@ -64,8 +64,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) {
std::iota(begin(input), end(input), 0); std::iota(begin(input), end(input), 0);
std::vector<int32_t> size = {1, 2}; std::vector<int32_t> size = {1, 2};
net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input); net.AddInputFromArray<DeviceType::CPU, float>("Input", {1, 2, 4, 3}, input);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size); net.AddInputFromArray<DeviceType::CPU, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest") OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
...@@ -78,8 +78,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) { ...@@ -78,8 +78,8 @@ TEST_F(ResizeNearestNeighborTest, ResizeNearestNeighborWAlignCorners) {
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
// Check // Check
auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11}); auto expected = net.CreateTensor<float>({1, 1, 2, 3}, {0, 1, 2, 9, 10, 11});
...@@ -105,8 +105,8 @@ void TestRandomResizeNearestNeighbor() { ...@@ -105,8 +105,8 @@ void TestRandomResizeNearestNeighbor() {
std::vector<int32_t> size = {20, 40}; std::vector<int32_t> size = {20, 40};
net.AddRandomInput<D, float>("Input", net.AddRandomInput<D, float>("Input",
{batch, in_height, in_width, channels}); {batch, in_height, in_width, channels});
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
net.AddInputFromArray<D, int32_t>("Size", {2}, size); net.AddInputFromArray<D, int32_t>("Size", {2}, size);
OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest") OpDefBuilder("ResizeNearestNeighbor", "ResizeNearestNeighborTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -116,8 +116,8 @@ void TestRandomResizeNearestNeighbor() { ...@@ -116,8 +116,8 @@ void TestRandomResizeNearestNeighbor() {
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run on CPU // Run on CPU
net.RunOp(DeviceType::CPU); net.RunOp(DeviceType::CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
expected->Copy(*net.GetOutput("Output")); expected->Copy(*net.GetOutput("Output"));
......
...@@ -100,11 +100,7 @@ class ScalarMathOp : public Operation { ...@@ -100,11 +100,7 @@ class ScalarMathOp : public Operation {
coeff_(Operation::GetRepeatedArgs<float>("coeff")), coeff_(Operation::GetRepeatedArgs<float>("coeff")),
scalar_input_(Operation::GetOptionalArg<float>("scalar_input", 1.0)), scalar_input_(Operation::GetOptionalArg<float>("scalar_input", 1.0)),
scalar_input_index_(Operation::GetOptionalArg<int32_t>( scalar_input_index_(Operation::GetOptionalArg<int32_t>(
"scalar_input_index", 1)) { "scalar_input_index", 1)) {}
if (D == DeviceType::GPU) {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
}
}
MaceStatus Run(OpContext *context) override { MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context); MACE_UNUSED(context);
......
...@@ -414,10 +414,9 @@ class SoftmaxOp<DeviceType::GPU, T> : public Operation { ...@@ -414,10 +414,9 @@ class SoftmaxOp<DeviceType::GPU, T> : public Operation {
: Operation(context) { : Operation(context) {
bool use_log = ( bool use_log = (
Operation::GetOptionalArg<bool>("use_log", false)); Operation::GetOptionalArg<bool>("use_log", false));
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SoftmaxKernel<T>>(use_log); kernel_ = make_unique<opencl::image::SoftmaxKernel<T>>(use_log);
} else { } else {
context->set_output_mem_type(MemoryType::GPU_BUFFER);
kernel_ = make_unique<opencl::buffer::SoftmaxKernel<T>>(use_log); kernel_ = make_unique<opencl::buffer::SoftmaxKernel<T>>(use_log);
} }
} }
...@@ -456,7 +455,7 @@ void RegisterSoftmax(OpRegistryBase *op_registry) { ...@@ -456,7 +455,7 @@ void RegisterSoftmax(OpRegistryBase *op_registry) {
op_registry, op_registry,
OpConditionBuilder("Softmax") OpConditionBuilder("Softmax")
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU }; return { DeviceType::CPU, DeviceType::GPU };
......
...@@ -50,7 +50,8 @@ void Simple(bool use_log = false) { ...@@ -50,7 +50,8 @@ void Simple(bool use_log = false) {
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
// test 4d softmax // test 4d softmax
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Softmax", "SoftmaxTest") OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -59,7 +60,8 @@ void Simple(bool use_log = false) { ...@@ -59,7 +60,8 @@ void Simple(bool use_log = false) {
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
...@@ -109,7 +111,8 @@ void Complex(const std::vector<index_t> &logits_shape, ...@@ -109,7 +111,8 @@ void Complex(const std::vector<index_t> &logits_shape,
net.AddRandomInput<D, float>("Input", logits_shape); net.AddRandomInput<D, float>("Input", logits_shape);
if (logits_shape.size() == 4) { if (logits_shape.size() == 4) {
net.TransformDataFormat<CPU, float>("Input", NHWC, "InputNCHW", NCHW); net.TransformDataFormat<CPU, float>(
"Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("Softmax", "SoftmaxTest") OpDefBuilder("Softmax", "SoftmaxTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -127,7 +130,8 @@ void Complex(const std::vector<index_t> &logits_shape, ...@@ -127,7 +130,8 @@ void Complex(const std::vector<index_t> &logits_shape,
net.RunOp(); net.RunOp();
if (logits_shape.size() == 4) { if (logits_shape.size() == 4) {
net.TransformDataFormat<CPU, float>("OutputNCHW", NCHW, "Output", NHWC); net.TransformDataFormat<CPU, float>(
"OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
auto expected = net.CreateTensor<float>(); auto expected = net.CreateTensor<float>();
......
...@@ -307,7 +307,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, T> : public SpaceToBatchOpBase { ...@@ -307,7 +307,7 @@ class SpaceToBatchNDOp<DeviceType::GPU, T> : public SpaceToBatchOpBase {
public: public:
explicit SpaceToBatchNDOp(OpConstructContext *context) explicit SpaceToBatchNDOp(OpConstructContext *context)
: SpaceToBatchOpBase(context) { : SpaceToBatchOpBase(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SpaceToBatchKernel<T>>(); kernel_ = make_unique<opencl::image::SpaceToBatchKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -39,8 +39,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape, ...@@ -39,8 +39,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == CPU) { } else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -53,8 +53,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape, ...@@ -53,8 +53,8 @@ void RunSpaceToBatch(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == CPU) { if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
ExpectTensorNear<float>(*expected, *net.GetOutput("Output")); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
...@@ -78,8 +78,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape, ...@@ -78,8 +78,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else if (D == CPU) { } else if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -92,8 +92,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape, ...@@ -92,8 +92,8 @@ void RunBatchToSpace(const std::vector<index_t> &input_shape,
net.RunOp(D); net.RunOp(D);
if (D == CPU) { if (D == CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} }
// Check // Check
ExpectTensorNear<float>(*expected, *net.GetOutput("Output")); ExpectTensorNear<float>(*expected, *net.GetOutput("Output"));
...@@ -155,8 +155,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape, ...@@ -155,8 +155,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape,
net.RunOp(GPU); net.RunOp(GPU);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -164,8 +164,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape, ...@@ -164,8 +164,8 @@ void TestSpaceToBatchLargeInput(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// Check // Check
ExpectTensorNear<float>(*net.GetOutput("OutputCPU"), ExpectTensorNear<float>(*net.GetOutput("OutputCPU"),
...@@ -188,8 +188,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape, ...@@ -188,8 +188,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape,
net.RunOp(GPU); net.RunOp(GPU);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -197,8 +197,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape, ...@@ -197,8 +197,8 @@ void TestoBatchToSpaceLargeInput(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// Check // Check
ExpectTensorNear<float>(*net.GetOutput("OutputCPU"), ExpectTensorNear<float>(*net.GetOutput("OutputCPU"),
...@@ -218,8 +218,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape, ...@@ -218,8 +218,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape,
1.f); 1.f);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest") OpDefBuilder("SpaceToBatchND", "SpaceToBatchNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -227,8 +227,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape, ...@@ -227,8 +227,8 @@ void TestSpaceToBatchQuantize(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// run quantize // run quantize
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
...@@ -279,8 +279,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape, ...@@ -279,8 +279,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape,
1.f); 1.f);
// run cpu // run cpu
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest") OpDefBuilder("BatchToSpaceND", "BatchToSpaceNDTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -288,8 +288,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape, ...@@ -288,8 +288,8 @@ void TestoBatchToSpaceQuantize(const std::vector<index_t> &input_shape,
.AddIntsArg("block_shape", block_shape_data) .AddIntsArg("block_shape", block_shape_data)
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
net.RunOp(CPU); net.RunOp(CPU);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"OutputCPU", NHWC); "OutputNCHW", DataFormat::NCHW, "OutputCPU", DataFormat::NHWC);
// run quantize // run quantize
OpDefBuilder("Quantize", "QuantizeInput") OpDefBuilder("Quantize", "QuantizeInput")
......
...@@ -94,7 +94,7 @@ class SpaceToDepthOp<DeviceType::GPU, T> : public Operation { ...@@ -94,7 +94,7 @@ class SpaceToDepthOp<DeviceType::GPU, T> : public Operation {
explicit SpaceToDepthOp(OpConstructContext *context) explicit SpaceToDepthOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
int block_size = Operation::GetOptionalArg<int>("block_size", 1); int block_size = Operation::GetOptionalArg<int>("block_size", 1);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SpaceToDepthKernel<T>>(block_size); kernel_ = make_unique<opencl::image::SpaceToDepthKernel<T>>(block_size);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -32,8 +32,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape, ...@@ -32,8 +32,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape,
net.AddInputFromArray<D, float>("Input", input_shape, input_data); net.AddInputFromArray<D, float>("Input", input_shape, input_data);
// Construct graph // Construct graph
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("InputNCHW") .Input("InputNCHW")
.Output("OutputNCHW") .Output("OutputNCHW")
...@@ -41,8 +41,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape, ...@@ -41,8 +41,8 @@ void RunSpaceToDepth(const std::vector<index_t> &input_shape,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
} else { } else {
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
...@@ -107,8 +107,8 @@ void RandomTest(const int block_size, ...@@ -107,8 +107,8 @@ void RandomTest(const int block_size,
// Add input data // Add input data
net.AddRandomInput<D, float>("Input", shape); net.AddRandomInput<D, float>("Input", shape);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("InputNCHW") .Input("InputNCHW")
.AddIntArg("block_size", block_size) .AddIntArg("block_size", block_size)
...@@ -118,8 +118,8 @@ void RandomTest(const int block_size, ...@@ -118,8 +118,8 @@ void RandomTest(const int block_size,
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("SpaceToDepth", "SpaceToDepthTest") OpDefBuilder("SpaceToDepth", "SpaceToDepthTest")
.Input("Input") .Input("Input")
......
...@@ -106,7 +106,7 @@ class SplitOp<DeviceType::GPU, T> : public Operation { ...@@ -106,7 +106,7 @@ class SplitOp<DeviceType::GPU, T> : public Operation {
explicit SplitOp(OpConstructContext *context) explicit SplitOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
int32_t axis = Operation::GetOptionalArg<int>("axis", 3); int32_t axis = Operation::GetOptionalArg<int>("axis", 3);
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SplitKernel<T>>(axis); kernel_ = make_unique<opencl::image::SplitKernel<T>>(axis);
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
...@@ -144,7 +144,7 @@ void RegisterSplit(OpRegistryBase *op_registry) { ...@@ -144,7 +144,7 @@ void RegisterSplit(OpRegistryBase *op_registry) {
op_registry, op_registry,
OpConditionBuilder("Split") OpConditionBuilder("Split")
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return {DeviceType::CPU, DeviceType::GPU}; return {DeviceType::CPU, DeviceType::GPU};
......
...@@ -83,7 +83,7 @@ class SqrDiffMeanOp<DeviceType::GPU, T> : public Operation { ...@@ -83,7 +83,7 @@ class SqrDiffMeanOp<DeviceType::GPU, T> : public Operation {
public: public:
explicit SqrDiffMeanOp(OpConstructContext *context) explicit SqrDiffMeanOp(OpConstructContext *context)
: Operation(context) { : Operation(context) {
if (context->device()->gpu_runtime()->UseImageMemory()) { if (context->GetOpMemoryType() == MemoryType::GPU_IMAGE) {
kernel_ = make_unique<opencl::image::SqrDiffMeanKernel<T>>(); kernel_ = make_unique<opencl::image::SqrDiffMeanKernel<T>>();
} else { } else {
MACE_NOT_IMPLEMENTED; MACE_NOT_IMPLEMENTED;
......
...@@ -36,13 +36,13 @@ void Simple(const std::vector<index_t> &input_shape0, ...@@ -36,13 +36,13 @@ void Simple(const std::vector<index_t> &input_shape0,
net.AddInputFromArray<D, float>("Input1", input_shape1, input1); net.AddInputFromArray<D, float>("Input1", input_shape1, input1);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", net.TransformDataFormat<DeviceType::CPU, float>("Input0",
NHWC, DataFormat::NHWC,
"InputNCHW0", "InputNCHW0",
NCHW); DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", net.TransformDataFormat<DeviceType::CPU, float>("Input1",
NHWC, DataFormat::NHWC,
"InputNCHW1", "InputNCHW1",
NCHW); DataFormat::NCHW);
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
...@@ -54,9 +54,9 @@ void Simple(const std::vector<index_t> &input_shape0, ...@@ -54,9 +54,9 @@ void Simple(const std::vector<index_t> &input_shape0,
net.RunOp(D); net.RunOp(D);
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW",
NCHW, DataFormat::NCHW,
"Output", "Output",
NHWC); DataFormat::NHWC);
} else { } else {
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
.Input("Input0") .Input("Input0")
...@@ -107,10 +107,10 @@ void RandomTest(const std::vector<index_t> &input_shape0, ...@@ -107,10 +107,10 @@ void RandomTest(const std::vector<index_t> &input_shape0,
net.AddRandomInput<D, float>("Input0", input_shape0); net.AddRandomInput<D, float>("Input0", input_shape0);
net.AddRandomInput<D, float>("Input1", input_shape1); net.AddRandomInput<D, float>("Input1", input_shape1);
net.TransformDataFormat<DeviceType::CPU, float>("Input0", NHWC, "InputNCHW0", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input0", DataFormat::NHWC, "InputNCHW0", DataFormat::NCHW);
net.TransformDataFormat<DeviceType::CPU, float>("Input1", NHWC, "InputNCHW1", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input1", DataFormat::NHWC, "InputNCHW1", DataFormat::NCHW);
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
.Input("InputNCHW0") .Input("InputNCHW0")
.Input("InputNCHW1") .Input("InputNCHW1")
...@@ -118,8 +118,8 @@ void RandomTest(const std::vector<index_t> &input_shape0, ...@@ -118,8 +118,8 @@ void RandomTest(const std::vector<index_t> &input_shape0,
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
// Run // Run
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, net.TransformDataFormat<DeviceType::CPU, float>(
"Output", NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest") OpDefBuilder("SqrDiffMean", "SqrDiffMeanTest")
.Input("Input0") .Input("Input0")
.Input("Input1") .Input("Input1")
......
...@@ -77,7 +77,7 @@ void RegisterSqueeze(OpRegistryBase *op_registry) { ...@@ -77,7 +77,7 @@ void RegisterSqueeze(OpRegistryBase *op_registry) {
op_registry, op_registry,
OpConditionBuilder("Squeeze") OpConditionBuilder("Squeeze")
.SetDevicePlacerFunc( .SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> { [](OpConditionContext *context) -> std::set<DeviceType> {
auto op = context->operator_def(); auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) { if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU }; return { DeviceType::CPU, DeviceType::GPU };
......
...@@ -86,8 +86,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -86,8 +86,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"Strides", {static_cast<int32_t>(strides.size())}, strides); "Strides", {static_cast<int32_t>(strides.size())}, strides);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("StridedSlice", "StridedSliceOpTest") OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -105,8 +105,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -105,8 +105,8 @@ void TestStridedSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output); net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"), ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output")); *net.GetOutput("Output"));
...@@ -154,8 +154,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -154,8 +154,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.AddInputFromArray<CPU, int32_t>( net.AddInputFromArray<CPU, int32_t>(
"IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size); "IndicesSize", {static_cast<int32_t>(indices_size.size())}, indices_size);
net.TransformDataFormat<DeviceType::CPU, float>("Input", NHWC, "InputNCHW", net.TransformDataFormat<DeviceType::CPU, float>(
NCHW); "Input", DataFormat::NHWC, "InputNCHW", DataFormat::NCHW);
OpDefBuilder("StridedSlice", "StridedSliceOpTest") OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("InputNCHW") .Input("InputNCHW")
...@@ -168,8 +168,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape, ...@@ -168,8 +168,8 @@ void TestSliceWithDataFormat(const std::vector<index_t> &input_shape,
net.RunOp(); net.RunOp();
net.TransformDataFormat<DeviceType::CPU, float>("OutputNCHW", NCHW, "Output", net.TransformDataFormat<DeviceType::CPU, float>(
NHWC); "OutputNCHW", DataFormat::NCHW, "Output", DataFormat::NHWC);
net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output); net.AddInputFromArray<CPU, float>("ExpectedOutput", output_shape, output);
ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"), ExpectTensorNear<float>(*net.GetOutput("ExpectedOutput"),
*net.GetOutput("Output")); *net.GetOutput("Output"));
......
...@@ -34,9 +34,10 @@ class NetDef; ...@@ -34,9 +34,10 @@ class NetDef;
enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4 }; enum DeviceType { CPU = 0, GPU = 2, HEXAGON = 3, HTA = 4 };
enum DataFormat { enum class DataFormat {
DF_NONE = 0, NHWC = 1, NCHW = 2, NONE = 0, NHWC = 1, NCHW = 2,
HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103 HWOI = 100, OIHW = 101, HWIO = 102, OHWI = 103,
AUTO = 1000,
}; };
enum GPUPerfHint { enum GPUPerfHint {
......
...@@ -41,7 +41,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value, ...@@ -41,7 +41,7 @@ device_type_map = {'cpu': cvt.DeviceType.CPU.value,
'cpu+gpu': cvt.DeviceType.CPU.value} 'cpu+gpu': cvt.DeviceType.CPU.value}
data_format_map = { data_format_map = {
'NONE': cvt.DataFormat.DF_NONE, 'NONE': cvt.DataFormat.NONE,
'NHWC': cvt.DataFormat.NHWC, 'NHWC': cvt.DataFormat.NHWC,
'NCHW': cvt.DataFormat.NCHW, 'NCHW': cvt.DataFormat.NCHW,
'OIHW': cvt.DataFormat.OIHW, 'OIHW': cvt.DataFormat.OIHW,
......
...@@ -26,13 +26,14 @@ class DeviceType(Enum): ...@@ -26,13 +26,14 @@ class DeviceType(Enum):
class DataFormat(Enum): class DataFormat(Enum):
DF_NONE = 0 NONE = 0
NHWC = 1 NHWC = 1
NCHW = 2 NCHW = 2
HWIO = 100 HWIO = 100
OIHW = 101 OIHW = 101
HWOI = 102 HWOI = 102
OHWI = 103 OHWI = 103
AUTO = 1000
# SAME_LOWER: if the amount of paddings to be added is odd, # SAME_LOWER: if the amount of paddings to be added is odd,
...@@ -161,13 +162,39 @@ MaceSupportedOps = [ ...@@ -161,13 +162,39 @@ MaceSupportedOps = [
'SumGroup', 'SumGroup',
'TargetRMSNorm', 'TargetRMSNorm',
'Transpose', 'Transpose',
'WinogradInverseTransform',
'WinogradTransform',
'Cumsum', 'Cumsum',
] ]
MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str) MaceOp = Enum('MaceOp', [(op, op) for op in MaceSupportedOps], type=str)
MaceHasDataFormatOps = [MaceOp.BatchNorm,
MaceOp.BatchToSpaceND,
MaceOp.Conv2D,
MaceOp.Deconv2D,
MaceOp.DepthToSpace,
MaceOp.DepthwiseConv2d,
MaceOp.DepthwiseDeconv2d,
MaceOp.FullyConnected,
MaceOp.Pooling,
MaceOp.ResizeBicubic,
MaceOp.ResizeBilinear,
MaceOp.ResizeNearestNeighbor,
MaceOp.SpaceToBatchND,
MaceOp.SpaceToDepth]
MaceMayHasDataFormatOps = [MaceOp.Activation,
MaceOp.AddN,
MaceOp.BiasAdd,
MaceOp.ChannelShuffle,
MaceOp.Concat,
MaceOp.Crop,
MaceOp.Eltwise,
MaceOp.Pad,
MaceOp.Reduce,
MaceOp.Softmax,
MaceOp.Split,
MaceOp.SqrDiffMean]
class MaceKeyword(object): class MaceKeyword(object):
# node related str # node related str
...@@ -505,12 +532,11 @@ class ConverterOption(object): ...@@ -505,12 +532,11 @@ class ConverterOption(object):
TransformerRule.TRANSFORM_CHANNEL_SHUFFLE, TransformerRule.TRANSFORM_CHANNEL_SHUFFLE,
# Model data format related transformation # Model data format related transformation
TransformerRule.TRANSPOSE_FILTERS, TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT, # Mace model structure related transformation
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.TRANSPOSE_MATMUL_WEIGHT, TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
# Add winograd argument # Add winograd argument
TransformerRule.ADD_WINOGRAD_ARG, TransformerRule.ADD_WINOGRAD_ARG,
# Mace model structure related transformation
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
# Data type related transformation # Data type related transformation
TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE, TransformerRule.UPDATE_FLOAT_OP_DATA_TYPE,
# Transform finalization # Transform finalization
...@@ -519,6 +545,7 @@ class ConverterOption(object): ...@@ -519,6 +545,7 @@ class ConverterOption(object):
TransformerRule.SORT_BY_EXECUTION, TransformerRule.SORT_BY_EXECUTION,
# update the data format of ops # update the data format of ops
TransformerRule.UPDATE_DATA_FORMAT, TransformerRule.UPDATE_DATA_FORMAT,
TransformerRule.TRANSPOSE_DATA_FORMAT,
# Need to be put after SORT_BY_EXECUTION # Need to be put after SORT_BY_EXECUTION
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE, TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
] ]
...@@ -571,6 +598,8 @@ class ConverterUtil(object): ...@@ -571,6 +598,8 @@ class ConverterUtil(object):
return DataFormat.NHWC return DataFormat.NHWC
elif arg.i == DataFormat.NCHW.value: elif arg.i == DataFormat.NCHW.value:
return DataFormat.NCHW return DataFormat.NCHW
elif arg.i == DataFormat.AUTO.value:
return DataFormat.AUTO
else: else:
return None return None
......
...@@ -195,6 +195,7 @@ class CaffeConverter(base_converter.ConverterInterface): ...@@ -195,6 +195,7 @@ class CaffeConverter(base_converter.ConverterInterface):
self._option = option self._option = option
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NCHW)
self._caffe_net = CaffeNet() self._caffe_net = CaffeNet()
self._caffe_layers = caffe_pb2.NetParameter() self._caffe_layers = caffe_pb2.NetParameter()
caffe_weights = caffe_pb2.NetParameter() caffe_weights = caffe_pb2.NetParameter()
......
...@@ -387,6 +387,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -387,6 +387,8 @@ class OnnxConverter(base_converter.ConverterInterface):
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
self._data_format = DataFormat.NCHW self._data_format = DataFormat.NCHW
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.OIHW)
ConverterUtil.add_data_format_arg(self._mace_net_def,
self._data_format)
onnx_model = onnx.load(src_model_file) onnx_model = onnx.load(src_model_file)
ir_version = onnx_model.ir_version ir_version = onnx_model.ir_version
...@@ -402,7 +404,7 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -402,7 +404,7 @@ class OnnxConverter(base_converter.ConverterInterface):
print("constains ops domain: ", domain, "version:", version) print("constains ops domain: ", domain, "version:", version)
if 'kaldi2onnx' in domain: if 'kaldi2onnx' in domain:
polish_available = False polish_available = False
self._data_format = DataFormat.DF_NONE self._data_format = DataFormat.NONE
self._isKaldi = True self._isKaldi = True
if polish_available: if polish_available:
onnx_model = onnx.utils.polish_model(onnx_model) onnx_model = onnx.utils.polish_model(onnx_model)
......
...@@ -270,6 +270,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -270,6 +270,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._option = option self._option = option
self._mace_net_def = mace_pb2.NetDef() self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO) ConverterUtil.set_filter_format(self._mace_net_def, DataFormat.HWIO)
ConverterUtil.add_data_format_arg(self._mace_net_def, DataFormat.NHWC)
# import tensorflow graph # import tensorflow graph
tf_graph_def = tf.GraphDef() tf_graph_def = tf.GraphDef()
......
...@@ -27,6 +27,8 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType ...@@ -27,6 +27,8 @@ from mace.python.tools.converter_tool.base_converter import EltwiseType
from mace.python.tools.converter_tool.base_converter import FrameworkType from mace.python.tools.converter_tool.base_converter import FrameworkType
from mace.python.tools.converter_tool.base_converter import MaceKeyword from mace.python.tools.converter_tool.base_converter import MaceKeyword
from mace.python.tools.converter_tool.base_converter import MaceOp from mace.python.tools.converter_tool.base_converter import MaceOp
from mace.python.tools.converter_tool.base_converter import MaceHasDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import MaceMayHasDataFormatOps # noqa
from mace.python.tools.converter_tool.base_converter import PaddingMode from mace.python.tools.converter_tool.base_converter import PaddingMode
from mace.python.tools.converter_tool.base_converter import ReduceType from mace.python.tools.converter_tool.base_converter import ReduceType
from mace.python.tools.converter_tool.base_converter import TransformerRule from mace.python.tools.converter_tool.base_converter import TransformerRule
...@@ -77,10 +79,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -77,10 +79,9 @@ class Transformer(base_converter.ConverterInterface):
self.transpose_matmul_weight, self.transpose_matmul_weight,
TransformerRule.FOLD_FC_RESHAPE: TransformerRule.FOLD_FC_RESHAPE:
self.fold_fc_reshape, self.fold_fc_reshape,
TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
TransformerRule.ADD_WINOGRAD_ARG: self.add_winograd_arg,
TransformerRule.ADD_IN_OUT_TENSOR_INFO: TransformerRule.ADD_IN_OUT_TENSOR_INFO:
self.add_in_out_tensor_info, self.add_in_out_tensor_info,
TransformerRule.ADD_WINOGRAD_ARG: self.add_winograd_arg,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC: TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC:
self.transform_global_conv_to_fc, self.transform_global_conv_to_fc,
TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight, TransformerRule.RESHAPE_FC_WEIGHT: self.reshape_fc_weight,
...@@ -96,6 +97,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -96,6 +97,7 @@ class Transformer(base_converter.ConverterInterface):
self.add_opencl_informations, self.add_opencl_informations,
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution, TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
TransformerRule.UPDATE_DATA_FORMAT: self.update_data_format, TransformerRule.UPDATE_DATA_FORMAT: self.update_data_format,
TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
TransformerRule.CHECK_QUANTIZE_INFO: TransformerRule.CHECK_QUANTIZE_INFO:
self.check_quantize_info, self.check_quantize_info,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN: TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN:
...@@ -194,21 +196,19 @@ class Transformer(base_converter.ConverterInterface): ...@@ -194,21 +196,19 @@ class Transformer(base_converter.ConverterInterface):
op.type = "Input" op.type = "Input"
data_type_arg = op.arg.add() data_type_arg = op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_FLOAT data_type_arg.i = input_node.data_type
op.output.extend([input_node.name]) op.output.extend([input_node.name])
output_shape = op.output_shape.add() output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape) output_shape.dims.extend(input_node.shape)
if input_node.name in self._consumers: if input_node.data_format != DataFormat.NONE:
if ConverterUtil.data_format( if input_node.data_format == DataFormat.NCHW:
self._consumers[input_node.name][0]) \
== DataFormat.NCHW:
self.transpose_shape(output_shape.dims, self.transpose_shape(output_shape.dims,
[0, 3, 1, 2]) [0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, ConverterUtil.add_data_format_arg(op,
DataFormat.NCHW) DataFormat.AUTO)
else: else:
ConverterUtil.add_data_format_arg(op, ConverterUtil.add_data_format_arg(op,
DataFormat.NHWC) DataFormat.NONE)
self._producer[op.output[0]] = op self._producer[op.output[0]] = op
@staticmethod @staticmethod
...@@ -256,6 +256,13 @@ class Transformer(base_converter.ConverterInterface): ...@@ -256,6 +256,13 @@ class Transformer(base_converter.ConverterInterface):
else: else:
return None return None
def get_tensor_data_format(self, tensor):
if tensor in self._producer:
producer = self._producer[tensor]
return ConverterUtil.data_format(producer)
else:
return DataFormat.NONE
def consumer_count(self, tensor_name): def consumer_count(self, tensor_name):
return len(self._consumers.get(tensor_name, [])) return len(self._consumers.get(tensor_name, []))
...@@ -838,8 +845,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -838,8 +845,6 @@ class Transformer(base_converter.ConverterInterface):
or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.DepthwiseConv2d.name
or op.type == MaceOp.FullyConnected.name) or op.type == MaceOp.FullyConnected.name)
and len(op.input) == 2) and len(op.input) == 2)
or (op.type == MaceOp.WinogradInverseTransform.name
and len(op.input) == 1)
or (op.type == MaceOp.Deconv2D.name or (op.type == MaceOp.Deconv2D.name
and ((ConverterUtil.get_arg( and ((ConverterUtil.get_arg(
op, op,
...@@ -930,8 +935,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -930,8 +935,7 @@ class Transformer(base_converter.ConverterInterface):
or op.type == MaceOp.Deconv2D.name or op.type == MaceOp.Deconv2D.name
or op.type == MaceOp.DepthwiseConv2d.name or op.type == MaceOp.DepthwiseConv2d.name
or op.type == MaceOp.FullyConnected.name or op.type == MaceOp.FullyConnected.name
or op.type == MaceOp.BatchNorm.name or op.type == MaceOp.BatchNorm.name) \
or op.type == MaceOp.WinogradInverseTransform.name) \
and len(self._consumers.get(op.output[0], [])) == 1: and len(self._consumers.get(op.output[0], [])) == 1:
consumer_op = self._consumers[op.output[0]][0] consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.Activation.name \ if consumer_op.type == MaceOp.Activation.name \
...@@ -1017,97 +1021,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1017,97 +1021,6 @@ class Transformer(base_converter.ConverterInterface):
filter_format.name) filter_format.name)
return False return False
def transpose_data_format(self):
net = self._model
for op in net.op:
# transpose args
if op.type == MaceOp.Pad.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_paddings_str:
mace_check(len(arg.ints) == 8,
"pad dim rank should be 8.")
if ConverterUtil.data_format(op) == DataFormat.NCHW:
print("Transpose pad args: %s(%s)"
% (op.name, op.type))
self.transpose_shape(arg.ints,
[0, 1, 4, 5, 6, 7, 2, 3])
elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if (ConverterUtil.data_format(op) == DataFormat.NCHW
and len(op.output_shape[0].dims) == 4):
print("Transpose concat/split args: %s(%s)"
% (op.name, op.type))
if arg.i == 1:
arg.i = 3
elif arg.i == 2:
arg.i = 1
elif arg.i == 3:
arg.i = 2
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \
len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1:
axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if ConverterUtil.data_format(op) == DataFormat.NCHW:
print("Transpose squeeze args: %s(%s)"
% (op.name, op.type))
mace_check(list(arg.ints) == [2, 3],
'only support squeeze at at [2, 3]')
arg.ints[:] = [1, 2]
elif op.type == MaceOp.Reduce.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if ConverterUtil.data_format(
op) == DataFormat.NCHW:
print("Transpose reduce args: %s(%s)"
% (op.name, op.type))
reduce_axises = list(arg.ints)
new_axises = []
for i in range(len(reduce_axises)):
idx = reduce_axises[i]
if idx == 2 or idx == 3:
new_axises.append(idx - 1)
elif idx == 1:
new_axises.append(3)
else:
new_axises.append(idx)
new_axises.sort()
arg.ints[:] = []
arg.ints.extend(new_axises)
elif op.type == MaceOp.Crop.name:
offset_arg = ConverterUtil.get_arg(op,
MaceKeyword.mace_offset_str)
mace_check(offset_arg and
ConverterUtil.data_format(op) == DataFormat.NCHW and
len(op.output_shape[0].dims) == 4,
"MACE only support crop with NCHW format")
print("Transpose crop args: %s(%s)"
% (op.name, op.type))
self.transpose_shape(offset_arg.ints, [0, 2, 3, 1])
# transpose op output shape
data_format = ConverterUtil.data_format(op)
if data_format is not None \
and data_format != DataFormat.NHWC:
print("Transpose output shapes: %s(%s)" % (op.name, op.type))
for output_shape in op.output_shape:
if len(output_shape.dims) == 4:
self.transpose_shape(output_shape.dims,
[0, 2, 3, 1])
return False
def add_winograd_arg(self): def add_winograd_arg(self):
if self._wino_arg == 0: if self._wino_arg == 0:
return False return False
...@@ -1428,17 +1341,122 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1428,17 +1341,122 @@ class Transformer(base_converter.ConverterInterface):
def update_data_format(self): def update_data_format(self):
print("update data format") print("update data format")
data_format_flag = 1
for input_node in self._option.input_nodes.values():
if input_node.data_format.value == DataFormat.DF_NONE.value:
data_format_flag = 0
net = self._model net = self._model
for op in net.op: for op in net.op:
ConverterUtil.del_arg( df_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_data_format_str) op, MaceKeyword.mace_data_format_str)
has_data_format_arg = op.arg.add() if not df_arg:
has_data_format_arg.name = MaceKeyword.mace_has_data_format_str df_arg = op.arg.add()
has_data_format_arg.i = data_format_flag df_arg.name = MaceKeyword.mace_data_format_str
if op.type in MaceHasDataFormatOps:
df_arg.i = DataFormat.AUTO.value
elif op.type in MaceMayHasDataFormatOps:
input_df = DataFormat.AUTO.value
for input_tensor in op.input:
if input_tensor in self._consts:
continue
mace_check(
input_tensor in self._producer,
"Input tensor %s not in producer" % input_tensor)
father_op = self._producer[input_tensor]
temp_input_df = ConverterUtil.get_arg(
father_op, MaceKeyword.mace_data_format_str)
if temp_input_df.i != DataFormat.AUTO.value:
input_df = temp_input_df.i
if input_df == DataFormat.AUTO.value:
df_arg.i = input_df
# add flag to mark the ops may has data format
has_data_format_arg = op.arg.add()
has_data_format_arg.name = \
MaceKeyword.mace_has_data_format_str
has_data_format_arg.i = 1
return False
def transpose_data_format(self):
print("Transpose arguments based on data format")
net = self._model
src_data_format = ConverterUtil.data_format(net)
for op in net.op:
has_data_format = ConverterUtil.data_format(op) == \
DataFormat.AUTO
# transpose args
if op.type == MaceOp.Pad.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_paddings_str:
mace_check(len(arg.ints) == 8,
"pad dim rank should be 8.")
if src_data_format == DataFormat.NCHW and \
has_data_format:
print("Transpose pad args: %s(%s)"
% (op.name, op.type))
self.transpose_shape(arg.ints,
[0, 1, 4, 5, 6, 7, 2, 3])
elif op.type == MaceOp.Concat.name or op.type == MaceOp.Split.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if (src_data_format == DataFormat.NCHW
and has_data_format
and len(op.output_shape[0].dims) == 4):
print("Transpose concat/split args: %s(%s)"
% (op.name, op.type))
if arg.i == 1:
arg.i = 3
elif arg.i == 2:
arg.i = 1
elif arg.i == 3:
arg.i = 2
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \
len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1:
axis_arg.i = 3
elif op.type == MaceOp.Reduce.name:
for arg in op.arg:
if arg.name == MaceKeyword.mace_axis_str:
if src_data_format == DataFormat.NCHW and \
has_data_format:
print("Transpose reduce args: %s(%s)"
% (op.name, op.type))
reduce_axises = list(arg.ints)
new_axises = []
for i in range(len(reduce_axises)):
idx = reduce_axises[i]
if idx == 2 or idx == 3:
new_axises.append(idx - 1)
elif idx == 1:
new_axises.append(3)
else:
new_axises.append(idx)
new_axises.sort()
arg.ints[:] = []
arg.ints.extend(new_axises)
elif op.type == MaceOp.Crop.name:
offset_arg = ConverterUtil.get_arg(op,
MaceKeyword.mace_offset_str)
mace_check(offset_arg and
src_data_format == DataFormat.NCHW
and has_data_format
and len(op.output_shape[0].dims) == 4,
"MACE only support crop with NCHW format")
print("Transpose crop args: %s(%s)"
% (op.name, op.type))
self.transpose_shape(offset_arg.ints, [0, 2, 3, 1])
# transpose op output shape
if src_data_format == DataFormat.NCHW and \
has_data_format:
print("Transpose output shapes: %s(%s)" % (op.name, op.type))
for output_shape in op.output_shape:
if len(output_shape.dims) == 4:
self.transpose_shape(output_shape.dims,
[0, 2, 3, 1])
return False return False
def quantize_nodes(self): def quantize_nodes(self):
...@@ -1493,7 +1511,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1493,7 +1511,7 @@ class Transformer(base_converter.ConverterInterface):
self._model.input_info[i].zero_point = quantize_info.zero_point self._model.input_info[i].zero_point = quantize_info.zero_point
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8) ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC) ConverterUtil.add_data_format_arg(op_def, input_node.data_format)
# use actual ranges for model input quantize # use actual ranges for model input quantize
find_range_every_time_arg = op_def.arg.add() find_range_every_time_arg = op_def.arg.add()
find_range_every_time_arg.name = \ find_range_every_time_arg.name = \
...@@ -1516,6 +1534,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1516,6 +1534,7 @@ class Transformer(base_converter.ConverterInterface):
self._model.output_info[i].zero_point = quantize_info.zero_point self._model.output_info[i].zero_point = quantize_info.zero_point
ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8) ConverterUtil.add_data_type_arg(op_def, mace_pb2.DT_UINT8)
ConverterUtil.add_data_format_arg(op_def, output_node.data_format)
quantize_flag_arg = self._model.arg.add() quantize_flag_arg = self._model.arg.add()
quantize_flag_arg.name = MaceKeyword.mace_quantize_flag_arg_str quantize_flag_arg.name = MaceKeyword.mace_quantize_flag_arg_str
...@@ -1886,9 +1905,6 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1886,9 +1905,6 @@ class Transformer(base_converter.ConverterInterface):
shape_tensor.data_type = mace_pb2.DT_INT32 shape_tensor.data_type = mace_pb2.DT_INT32
else: else:
mace_check(False, "Only support reshape and flatten") mace_check(False, "Only support reshape and flatten")
# NCHW -> NHWC
if len(dims) == 4:
self.transpose_shape(dims, [0, 2, 3, 1])
shape_tensor.int32_data.extend(dims) shape_tensor.int32_data.extend(dims)
op.input.append(shape_tensor.name) op.input.append(shape_tensor.name)
...@@ -2030,6 +2046,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -2030,6 +2046,9 @@ class Transformer(base_converter.ConverterInterface):
data_type_arg = quantize_op.arg.add() data_type_arg = quantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = mace_pb2.DT_UINT8 data_type_arg.i = mace_pb2.DT_UINT8
ConverterUtil.add_data_format_arg(
quantize_op,
self.get_tensor_data_format(input_tensor))
data_type_arg = quantize_op.arg.add() data_type_arg = quantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_non_zero data_type_arg.name = MaceKeyword.mace_non_zero
...@@ -2050,8 +2069,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -2050,8 +2069,8 @@ class Transformer(base_converter.ConverterInterface):
del op.input[:] del op.input[:]
op.input.extend(quantized_inputs_names) op.input.extend(quantized_inputs_names)
orginal_output_name = op.output[0] original_output_name = op.output[0]
op.output[0] = orginal_output_name + "_quant" op.output[0] = original_output_name + "_quant"
op.output_type.extend([to_quantize_ops_output_type[op.type]]) op.output_type.extend([to_quantize_ops_output_type[op.type]])
data_type_arg = ConverterUtil.get_arg(op, data_type_arg = ConverterUtil.get_arg(op,
MaceKeyword.mace_op_data_type_str) # noqa MaceKeyword.mace_op_data_type_str) # noqa
...@@ -2064,13 +2083,15 @@ class Transformer(base_converter.ConverterInterface): ...@@ -2064,13 +2083,15 @@ class Transformer(base_converter.ConverterInterface):
dequantize_op.name = op.name + "_dequant" dequantize_op.name = op.name + "_dequant"
dequantize_op.type = MaceOp.Dequantize.name dequantize_op.type = MaceOp.Dequantize.name
dequantize_op.input.extend([op.output[0]]) dequantize_op.input.extend([op.output[0]])
dequantize_op.output.extend([orginal_output_name]) dequantize_op.output.extend([original_output_name])
dequantize_op.output_shape.extend(op.output_shape) dequantize_op.output_shape.extend(op.output_shape)
dequantize_op.output_type.extend([mace_pb2.DT_FLOAT]) dequantize_op.output_type.extend([mace_pb2.DT_FLOAT])
data_type_arg = dequantize_op.arg.add() data_type_arg = dequantize_op.arg.add()
data_type_arg.name = MaceKeyword.mace_op_data_type_str data_type_arg.name = MaceKeyword.mace_op_data_type_str
data_type_arg.i = to_quantize_ops_output_type[op.type] data_type_arg.i = to_quantize_ops_output_type[op.type]
ConverterUtil.add_data_format_arg(
dequantize_op,
self.get_tensor_data_format(original_output_name))
quantize_flag_arg = ConverterUtil.get_arg(self._model, quantize_flag_arg = ConverterUtil.get_arg(self._model,
MaceKeyword.mace_quantize_flag_arg_str) # noqa MaceKeyword.mace_quantize_flag_arg_str) # noqa
if quantize_flag_arg is None: if quantize_flag_arg is None:
......
...@@ -80,7 +80,7 @@ void CreateInputInfo(NetDef *net_def) { ...@@ -80,7 +80,7 @@ void CreateInputInfo(NetDef *net_def) {
input_info = net_def->add_input_info(); input_info = net_def->add_input_info();
input_info->set_name({{ net.input_info[idx].name|tojson }}); input_info->set_name({{ net.input_info[idx].name|tojson }});
input_info->set_data_type(static_cast<DataType>({{ net.input_info[idx].data_type }})); input_info->set_data_type(static_cast<DataType>({{ net.input_info[idx].data_type }}));
input_info->set_data_format(static_cast<DataFormat>({{ net.input_info[idx].data_format }})); input_info->set_data_format({{ net.input_info[idx].data_format }});
input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }}); input_info->mutable_dims()->Reserve({{ net.input_info[idx].dims|length }});
{% for dim in net.input_info[idx].dims %} {% for dim in net.input_info[idx].dims %}
input_info->add_dims({{ dim }}); input_info->add_dims({{ dim }});
...@@ -97,7 +97,7 @@ void CreateOutputInfo(NetDef *net_def) { ...@@ -97,7 +97,7 @@ void CreateOutputInfo(NetDef *net_def) {
output_info = net_def->add_output_info(); output_info = net_def->add_output_info();
output_info->set_name({{ net.output_info[idx].name|tojson }}); output_info->set_name({{ net.output_info[idx].name|tojson }});
output_info->set_data_type(static_cast<DataType>({{ net.output_info[idx].data_type }})); output_info->set_data_type(static_cast<DataType>({{ net.output_info[idx].data_type }}));
output_info->set_data_format(static_cast<DataFormat>({{ net.output_info[idx].data_format }})); output_info->set_data_format({{ net.output_info[idx].data_format }});
output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }}); output_info->mutable_dims()->Reserve({{ net.output_info[idx].dims|length }});
{% for dim in net.output_info[idx].dims %} {% for dim in net.output_info[idx].dims %}
output_info->add_dims({{dim}}); output_info->add_dims({{dim}});
......
...@@ -48,7 +48,7 @@ void MaceRunFunc(const int in_out_size) { ...@@ -48,7 +48,7 @@ void MaceRunFunc(const int in_out_size) {
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
InputOutputInfo *info = net_def->add_input_info(); InputOutputInfo *info = net_def->add_input_info();
info->set_data_format(DataFormat::NHWC); info->set_data_format(static_cast<int>(DataFormat::NHWC));
info->set_name(input_names[i]); info->set_name(input_names[i]);
for (auto d : input_shapes[0]) { for (auto d : input_shapes[0]) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
......
...@@ -45,7 +45,7 @@ void MaceRun(const int in_out_size, ...@@ -45,7 +45,7 @@ void MaceRun(const int in_out_size,
for (size_t i = 0; i < input_names.size(); ++i) { for (size_t i = 0; i < input_names.size(); ++i) {
InputOutputInfo *info = net_def->add_input_info(); InputOutputInfo *info = net_def->add_input_info();
info->set_data_format(DataFormat::NHWC); info->set_data_format(static_cast<int>(DataFormat::NHWC));
info->set_name(input_names[i]); info->set_name(input_names[i]);
for (auto d : max_shape) { for (auto d : max_shape) {
info->add_dims(static_cast<int>(d)); info->add_dims(static_cast<int>(d));
......
...@@ -76,7 +76,7 @@ void Conv3x3(const std::string &input_name, ...@@ -76,7 +76,7 @@ void Conv3x3(const std::string &input_name,
.AddIntArg("padding", Padding::SAME) .AddIntArg("padding", Padding::SAME)
.AddIntsArg("dilations", {1, 1}) .AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("has_data_format", 1) .AddIntArg("data_format", static_cast<int>(DataFormat::AUTO))
.Finalize(&operator_def); .Finalize(&operator_def);
OutputShape *shape = operator_def.add_output_shape(); OutputShape *shape = operator_def.add_output_shape();
...@@ -99,7 +99,7 @@ void Relu(const std::string &input_name, ...@@ -99,7 +99,7 @@ void Relu(const std::string &input_name,
.AddStringArg("activation", "RELU") .AddStringArg("activation", "RELU")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("device", static_cast<int>(device_type)) .AddIntArg("device", static_cast<int>(device_type))
.AddIntArg("has_data_format", 1) .AddIntArg("data_format", static_cast<int>(DataFormat::AUTO))
.Finalize(&operator_def); .Finalize(&operator_def);
net_def->add_op()->CopyFrom(operator_def); net_def->add_op()->CopyFrom(operator_def);
...@@ -139,7 +139,8 @@ void CheckOutputs(const NetDef &net_def, ...@@ -139,7 +139,8 @@ void CheckOutputs(const NetDef &net_def,
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
std::string input_name = input.first + "NHWC"; std::string input_name = input.first + "NHWC";
net.AddInputFromArray<D, float>(input_name, input_shape, input_data); net.AddInputFromArray<D, float>(input_name, input_shape, input_data);
net.TransformDataFormat<D, float>(input_name, NHWC, input.first, NCHW); net.TransformDataFormat<D, float>(
input_name, DataFormat::NHWC, input.first, DataFormat::NCHW);
} else { } else {
net.AddInputFromArray<D, float>(input.first, input_shape, input_data); net.AddInputFromArray<D, float>(input.first, input_shape, input_data);
} }
...@@ -154,7 +155,7 @@ void CheckOutputs(const NetDef &net_def, ...@@ -154,7 +155,7 @@ void CheckOutputs(const NetDef &net_def,
memcpy(data.data(), memcpy(data.data(),
reinterpret_cast<const T *>(tensor_data.data()) + tensor.offset(), reinterpret_cast<const T *>(tensor_data.data()) + tensor.offset(),
tensor.data_size() * sizeof(T)); tensor.data_size() * sizeof(T));
net.AddInputFromArray<D, T>(tensor.name(), shape, data); net.AddInputFromArray<D, T>(tensor.name(), shape, data, true);
} }
net.RunNet(net_def, D); net.RunNet(net_def, D);
...@@ -175,9 +176,9 @@ void CheckOutputs(const NetDef &net_def, ...@@ -175,9 +176,9 @@ void CheckOutputs(const NetDef &net_def,
if (D == DeviceType::CPU) { if (D == DeviceType::CPU) {
output_name = output.first + "NHWC"; output_name = output.first + "NHWC";
net.TransformDataFormat<CPU, float>(output.first, net.TransformDataFormat<CPU, float>(output.first,
NCHW, DataFormat::NCHW,
output_name, output_name,
NHWC); DataFormat::NHWC);
} }
ops::test::ExpectTensorNear<float>(*tmp_tensor, ops::test::ExpectTensorNear<float>(*tmp_tensor,
*net.GetOutput(output_name.data()), *net.GetOutput(output_name.data()),
......
...@@ -91,7 +91,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) { ...@@ -91,7 +91,7 @@ DataFormat ParseDataFormat(const std::string &data_format_str) {
} else if (data_format_str == "OIHW") { } else if (data_format_str == "OIHW") {
return DataFormat::OIHW; return DataFormat::OIHW;
} else { } else {
return DataFormat::DF_NONE; return DataFormat::NONE;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册