Commit 4a298aac authored by liuqi

Feature: Add registration of Op conditions.

1. Add registration of Op conditions.
2. Add a data format argument to ops, derived from the model's input data format.
Parent 3bd4df8f
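
For orientation, the core of this change is a per-op device-placement condition registered alongside the op's kernels. The sketch below distills the pattern from the hunks that follow; `MyOp` is a hypothetical op used purely for illustration (the `Reshape` and `ChannelShuffle` registrations further down are real instances):

#include <set>
#include "mace/core/operator.h"

namespace mace {
namespace ops {

void RegisterMyOp(OpRegistryBase *op_registry) {
  // Attach a device-placement condition after the usual kernel
  // registrations. The placer runs at net-construction time and returns
  // the set of devices this op may be scheduled on.
  MACE_REGISTER_OP_CONDITION(
      op_registry,
      OpConditionBuilder("MyOp")
          .SetDevicePlacerFunc(
              [](OpConstructContext *context) -> std::set<DeviceType> {
                auto op = context->operator_def();
                // No inferred output shapes: cannot decide, allow both.
                if (op->output_shape_size() != op->output_size()) {
                  return { DeviceType::CPU, DeviceType::GPU };
                }
                // Assume the GPU image kernels only handle 4-D outputs.
                if (op->output_shape(0).dims_size() != 4) {
                  return { DeviceType::CPU };
                }
                return { DeviceType::CPU, DeviceType::GPU };
              }));
}

}  // namespace ops
}  // namespace mace

With this in place, SerialNet::CreateOperation queries the registry via op_registry->AvailableDevices(op_def->type(), construct_context), which dispatches to the registered placer (or to the default placer installed by OpRegistrationInfo's constructor), replacing the hard-coded FindAvailableDevicesForOp helper deleted below.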
......@@ -64,92 +64,26 @@ bool TransformRequiredOp(const std::string &op_type) {
}
#endif // MACE_ENABLE_OPENCL
// TODO(lichao): Move to runtime driver class after universality done.
// fallback to gpu buffer when kernels are implemented
void FindAvailableDevicesForOp(const OpRegistryBase &op_registry,
const OperatorDef &op,
const std::unordered_map<std::string,
std::vector<index_t>> &tensor_shape_info,
std::set<DeviceType>
*available_devices) {
auto devices = op_registry.AvailableDevices(op.type());
available_devices->insert(devices.begin(), devices.end());
std::string op_type = op.type();
// Ops whose outputs are not 4-D but which can still run on GPU
if (op_type == "BufferTransform"
|| op_type == "LSTMCell"
|| op_type == "FullyConnected"
|| op_type == "Softmax"
|| op_type == "Squeeze") {
return;
} else {
if (op.output_shape_size() != op.output_size()) {
return;
}
if (op.output_shape(0).dims_size() != 4) {
available_devices->erase(DeviceType::GPU);
}
if (op_type == "Split") {
if (op.output_shape(0).dims_size() != 4
|| op.output_shape(0).dims()[3] % 4 != 0) {
available_devices->erase(DeviceType::GPU);
}
} else if (op_type == "Concat") {
if (op.output_shape(0).dims_size() != 4) {
available_devices->erase(DeviceType::GPU);
} else {
if (op.input_size() != 2) {
for (const std::string &input : op.input()) {
if (tensor_shape_info.find(input) != tensor_shape_info.end()) {
auto &input_shape = tensor_shape_info.at(input);
if (input_shape[3] % 4 != 0) {
available_devices->erase(DeviceType::GPU);
break;
}
}
}
}
}
} else if (op_type == "ChannelShuffle") {
int groups = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op, "group", 1);
int channels = op.output_shape(0).dims(3);
int channels_per_group = channels / groups;
if (groups % 4 != 0 || channels_per_group % 4 != 0) {
available_devices->erase(DeviceType::GPU);
}
}
}
}
} // namespace
std::unique_ptr<Operation> SerialNet::CreateOperation(
const OpRegistryBase *op_registry,
OpConstructContext *construct_context,
std::shared_ptr<OperatorDef> op_def,
const std::unordered_map<std::string,
std::vector<index_t>> tensor_shape_info,
DataFormat data_format_flag,
bool is_quantize_model) {
// Create the Operation
DeviceType target_device_type = target_device_->device_type();
// Get available devices
std::set<DeviceType> available_devices;
FindAvailableDevicesForOp(*op_registry,
*op_def,
tensor_shape_info,
&available_devices);
// Find the device type to run the op.
// If target_device_type is among the available devices, use it;
// otherwise, fall back to the CPU device.
DeviceType device_type = DeviceType::CPU;
construct_context->set_device(cpu_device_);
construct_context->set_operator_def(op_def);
construct_context->set_output_mem_type(MemoryType::CPU_BUFFER);
// Get available devices
auto available_devices =
op_registry->AvailableDevices(op_def->type(), construct_context);
// Find the device type to run the op.
// If target_device_type is among the available devices, use it;
// otherwise, fall back to the CPU device.
for (auto device : available_devices) {
if (device == target_device_type) {
device_type = target_device_type;
......@@ -208,6 +142,23 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
MemoryType target_mem_type;
// quantize model flag
bool is_quantize_model = IsQuantizedModel(*net_def);
// Tensor Shape map
std::unordered_map<std::string, std::vector<index_t>> tensor_shape_map;
for (auto &op : net_def->op()) {
if (op.output_size() != op.output_shape_size()) {
continue;
}
for (int i = 0; i < op.output_size(); ++i) {
tensor_shape_map[op.output(i)] =
std::move(std::vector<index_t>(op.output_shape(i).dims().begin(),
op.output_shape(i).dims().end()));
}
}
for (auto &tensor : net_def->tensors()) {
tensor_shape_map[tensor.name()] =
std::move(std::vector<index_t>(tensor.dims().begin(),
tensor.dims().end()));
}
DataFormat data_format_flag = NHWC;
if (target_device_->device_type() == DeviceType::CPU) {
......@@ -216,11 +167,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
std::vector<index_t> input_shape =
std::vector<index_t>(input_info.dims().begin(),
input_info.dims().end());
// update tensor shape map
tensor_shape_map[input_info.name()] = input_shape;
// Can only be NONE or NHWC
auto input_data_format = static_cast<DataFormat>(
input_info.data_format());
if (!is_quantize_model &&
input_data_format == NHWC &&
if (!is_quantize_model && input_data_format == NHWC &&
input_info.dims_size() == 4) {
// NHWC -> NCHW
input_shape =
......@@ -237,39 +189,29 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
else { // GPU NOLINT[readability/braces]
target_mem_type = MemoryType::GPU_BUFFER;
for (auto &input_info : net_def->input_info()) {
auto input_data_format = static_cast<DataFormat>(
input_info.data_format());
if (input_data_format == DataFormat::DF_NONE) {
data_format_flag = DataFormat::DF_NONE;
}
std::vector<index_t> input_shape =
std::vector<index_t>(input_info.dims().begin(),
input_info.dims().end());
// update tensor shape map
tensor_shape_map[input_info.name()] = input_shape;
output_map.emplace(input_info.name(), InternalOutputInfo(
target_mem_type, DataType::DT_FLOAT, input_shape, -1));
}
}
#endif // MACE_ENABLE_OPENCL
std::unordered_map<std::string, std::vector<index_t>> tensor_shape_info;
for (auto &op : net_def->op()) {
if (op.output_size() != op.output_shape_size()) {
continue;
}
for (int i = 0; i < op.output_size(); ++i) {
tensor_shape_info[op.output(i)] =
std::move(std::vector<index_t>(op.output_shape(i).dims().begin(),
op.output_shape(i).dims().end()));
}
}
for (auto &tensor : net_def->tensors()) {
tensor_shape_info[tensor.name()] =
std::move(std::vector<index_t>(tensor.dims().begin(),
tensor.dims().end()));
}
OpConstructContext construct_context(ws_);
OpConstructContext construct_context(ws_, &tensor_shape_map);
for (int idx = 0; idx < net_def->op_size(); ++idx) {
std::shared_ptr<OperatorDef> op_def(new OperatorDef(net_def->op(idx)));
// Create operation
auto op = CreateOperation(op_registry,
&construct_context,
op_def,
tensor_shape_info,
data_format_flag,
is_quantize_model);
#ifdef MACE_ENABLE_OPENCL
......@@ -317,12 +259,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
}
auto transform_op_def = OpenCLUtil::CreateTransformOpDef(
input_name, input_shape, t_input_name,
wanted_in_dt, wanted_in_mem_type);
wanted_in_dt, wanted_in_mem_type, data_format_flag);
OpConstructContext t_construct_context(ws_);
auto transform_op = CreateOperation(
op_registry,
&construct_context,
&t_construct_context,
transform_op_def,
tensor_shape_info,
data_format_flag);
operators_.emplace_back(std::move(transform_op));
transformed_set.insert(t_input_name);
......@@ -405,12 +347,12 @@ SerialNet::SerialNet(const OpRegistryBase *op_registry,
internal_output_info.shape,
output_info.name(),
output_info.data_type(),
target_mem_type);
target_mem_type,
data_format_flag);
auto transform_op = CreateOperation(
op_registry,
&construct_context,
transform_op_def,
tensor_shape_info,
output_data_format);
operators_.emplace_back(std::move(transform_op));
// where to do graph reference count.
......
......@@ -59,8 +59,6 @@ class SerialNet : public NetBase {
const OpRegistryBase *op_registry,
OpConstructContext *construct_context,
std::shared_ptr<OperatorDef> op_def,
const std::unordered_map<std::string,
std::vector<index_t>> tensor_shape_info,
DataFormat input_format,
bool is_quantize_model = false);
......
......@@ -22,7 +22,18 @@
namespace mace {
OpConstructContext::OpConstructContext(Workspace *ws)
: operator_def_(nullptr), ws_(ws), device_(nullptr) {}
: operator_def_(nullptr),
ws_(ws),
device_(nullptr),
tensor_shape_info_(nullptr) {}
OpConstructContext::OpConstructContext(
mace::Workspace *ws,
mace::OpConstructContext::TensorShapeMap *info)
: operator_def_(nullptr),
ws_(ws),
device_(nullptr),
tensor_shape_info_(info) {}
void OpConstructContext::set_operator_def(
std::shared_ptr<mace::OperatorDef> operator_def) {
......@@ -169,6 +180,19 @@ const std::string OpKeyBuilder::Build() {
}
} // namespace
OpRegistrationInfo::OpRegistrationInfo() {
device_placer = [this](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
// By default, GPU ops only support 4-D input/output tensors
if (this->devices.count(DeviceType::CPU) == 1 &&
op->output_shape_size() == op->output_size() &&
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return this->devices;
};
}
void OpRegistrationInfo::AddDevice(mace::DeviceType device) {
devices.insert(device);
}
......@@ -179,10 +203,11 @@ void OpRegistrationInfo::Register(const std::string &key, OpCreator creator) {
creators[key] = creator;
}
MaceStatus OpRegistryBase::Register(const std::string &op_type,
const mace::DeviceType device_type,
const mace::DataType dt,
mace::OpRegistrationInfo::OpCreator creator) {
MaceStatus OpRegistryBase::Register(
const std::string &op_type,
const mace::DeviceType device_type,
const mace::DataType dt,
mace::OpRegistrationInfo::OpCreator creator) {
if (registry_.count(op_type) == 0) {
registry_[op_type] = std::unique_ptr<OpRegistrationInfo>(
new OpRegistrationInfo);
......@@ -197,15 +222,25 @@ MaceStatus OpRegistryBase::Register(const std::string &op_type,
return MaceStatus::MACE_SUCCESS;
}
MaceStatus OpRegistryBase::Register(
const OpConditionBuilder &builder) {
std::string op_type = builder.type();
if (registry_.count(op_type) == 0) {
registry_[op_type] = std::unique_ptr<OpRegistrationInfo>(
new OpRegistrationInfo);
}
builder.Finalize(registry_[op_type].get());
return MaceStatus::MACE_SUCCESS;
}
const std::set<DeviceType> OpRegistryBase::AvailableDevices(
const std::string &op_type) const {
const std::string &op_type, OpConstructContext *context) const {
MACE_CHECK(registry_.count(op_type) != 0,
op_type, " operation is not registered.");
return registry_.at(op_type)->devices;
return registry_.at(op_type)->device_placer(context);
}
std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
OpConstructContext *context,
DeviceType device_type) const {
......@@ -238,4 +273,24 @@ std::unique_ptr<Operation> OpRegistryBase::CreateOperation(
}
return registry_.at(op_type)->creators.at(key)(context);
}
OpConditionBuilder::OpConditionBuilder(const std::string &type)
: type_(type) {}
const std::string OpConditionBuilder::type() const {
return type_;
}
OpConditionBuilder &OpConditionBuilder::SetDevicePlacerFunc(
OpRegistrationInfo::DevicePlacer placer) {
placer_ = placer;
return *this;
}
void OpConditionBuilder::Finalize(OpRegistrationInfo *info) const {
if (info != nullptr && placer_) {
info->device_placer = placer_;
}
}
} // namespace mace
......@@ -31,8 +31,11 @@ namespace mace {
// memory_optimizer, device
class OpConstructContext {
typedef std::unordered_map<std::string, std::vector<index_t>> TensorShapeMap;
public:
explicit OpConstructContext(Workspace *ws);
OpConstructContext(Workspace *ws, TensorShapeMap *info);
~OpConstructContext() = default;
void set_operator_def(std::shared_ptr<OperatorDef> operator_def);
......@@ -53,6 +56,10 @@ class OpConstructContext {
return device_;
}
inline TensorShapeMap *tensor_shape_info() const {
return tensor_shape_info_;
}
void set_output_mem_type(MemoryType type);
inline MemoryType output_mem_type() const {
......@@ -69,6 +76,7 @@ class OpConstructContext {
std::shared_ptr<OperatorDef> operator_def_;
Workspace *ws_;
Device *device_;
TensorShapeMap *tensor_shape_info_;
// used for memory transform
std::vector<MemoryType> input_mem_types_;
std::vector<DataType> input_data_types_;
......@@ -188,8 +196,10 @@ struct OpRegistrationInfo {
public:
typedef std::function<std::unique_ptr<Operation>(OpConstructContext *)>
OpCreator;
typedef std::function<std::set<DeviceType>(OpConstructContext *)>
DevicePlacer;
OpRegistrationInfo() = default;
OpRegistrationInfo();
void AddDevice(DeviceType);
......@@ -197,8 +207,26 @@ struct OpRegistrationInfo {
std::set<DeviceType> devices;
std::unordered_map<std::string, OpCreator> creators;
DevicePlacer device_placer;
};
class OpConditionBuilder {
public:
explicit OpConditionBuilder(const std::string &type);
const std::string type() const;
OpConditionBuilder &SetDevicePlacerFunc(
OpRegistrationInfo::DevicePlacer placer);
void Finalize(OpRegistrationInfo *info) const;
private:
std::string type_;
OpRegistrationInfo::DevicePlacer placer_;
};
class OpRegistryBase {
public:
OpRegistryBase() = default;
......@@ -208,8 +236,10 @@ class OpRegistryBase {
const DataType dt,
OpRegistrationInfo::OpCreator creator);
MaceStatus Register(const OpConditionBuilder &builder);
const std::set<DeviceType> AvailableDevices(
const std::string &op_type) const;
const std::string &op_type, OpConstructContext *context) const;
std::unique_ptr<Operation> CreateOperation(
OpConstructContext *context,
......@@ -234,6 +264,9 @@ class OpRegistryBase {
DataTypeToEnum<dt>::value, \
OpRegistryBase::DefaultCreator<class_name<device, dt>>)
#define MACE_REGISTER_OP_CONDITION(op_registry, builder) \
op_registry->Register(builder)
} // namespace mace
#endif // MACE_CORE_OPERATOR_H_
......@@ -151,7 +151,8 @@ std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef(
const std::vector<mace::index_t> &input_shape,
const std::string &output_name,
const mace::DataType dt,
const mace::MemoryType mem_type) {
const mace::MemoryType mem_type,
const DataFormat data_format) {
std::unique_ptr<OperatorDef> op(new OperatorDef);
std::string op_name = "mace_node_" + output_name;
op->set_name(op_name);
......@@ -168,8 +169,8 @@ std::shared_ptr<OperatorDef> OpenCLUtil::CreateTransformOpDef(
arg->set_name("T");
arg->set_i(static_cast<int32_t>(dt));
arg = op->add_arg();
arg->set_name("device");
arg->set_i(DeviceType::GPU);
arg->set_name("data_format");
arg->set_i(data_format);
if (!input_shape.empty()) {
OutputShape *shape = op->add_output_shape();
for (auto value : input_shape) {
......
......@@ -20,6 +20,7 @@
#include <vector>
#include "mace/core/types.h"
#include "mace/public/mace.h"
namespace mace {
enum OpenCLBufferType {
......@@ -47,7 +48,8 @@ class OpenCLUtil {
const std::vector<mace::index_t> &input_shape,
const std::string &output_name,
const mace::DataType dt,
const MemoryType mem_type);
const MemoryType mem_type,
const DataFormat data_format);
};
} // namespace mace
......
......@@ -46,6 +46,7 @@ void FilterBufferToImage(int iters,
OpenCLBufferType::IN_OUT_CHANNEL,
MemoryType::GPU_IMAGE,
0,
DataFormat::NHWC,
b2i_output);
};
......
......@@ -35,14 +35,14 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, b2i_output);
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output);
// Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value);
OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, i2b_output);
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output);
// Check
ExpectTensorNear<T>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
......@@ -176,14 +176,14 @@ void TestDiffTypeBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, b2i_output);
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output);
// Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DT_FLOAT);
OpenCLBufferTransformer<float>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, i2b_output);
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output);
// Check
ExpectTensorNear<float>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
......@@ -216,14 +216,14 @@ void TestStringHalfBidirectionTransform(const OpenCLBufferType type,
// Transform
OpenCLBufferTransformer<T>(MemoryType::GPU_BUFFER, MemoryType::GPU_IMAGE)
.Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_IMAGE, 0, b2i_output);
type, MemoryType::GPU_IMAGE, 0, DataFormat::NHWC, b2i_output);
// Inverse Transform
Tensor *i2b_output = net.ws()->CreateTensor(
"I2BOutput", context.device()->allocator(), DataTypeToEnum<T>::value);
OpenCLBufferTransformer<T>(MemoryType::GPU_IMAGE, MemoryType::GPU_BUFFER)
.Transform(&context, b2i_output,
type, MemoryType::GPU_BUFFER, 0, i2b_output);
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, i2b_output);
// Check
ExpectTensorNear<half>(*net.GetOutput("Input"), *net.GetOutput("I2BOutput"),
......
......@@ -39,11 +39,14 @@ class BufferTransformOp<DeviceType::GPU, T> : public Operation {
auto type =
static_cast<OpenCLBufferType>(Operation::GetOptionalArg<int>(
"buffer_type", static_cast<int>(CONV2D_FILTER)));
auto data_format = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
MemoryType in_mem_type = context->workspace()->GetTensor(
operator_def_->input(0))->memory_type();
return OpenCLBufferTransformer<T>(in_mem_type, out_mem_type_).Transform(
context, input, type, out_mem_type_, wino_blk_size_, output);
context, input, type, out_mem_type_, wino_blk_size_,
data_format, output);
}
private:
......
......@@ -46,7 +46,7 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<DstType>(MemoryType::GPU_BUFFER,
MemoryType::GPU_BUFFER)
.Transform(&context, net.ws()->GetTensor("Input"),
type, MemoryType::GPU_BUFFER, 0, bt_output);
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, bt_output);
// Inverse Transform
Tensor *output = net.ws()->CreateTensor(
......@@ -55,7 +55,7 @@ void TestBidirectionTransform(const OpenCLBufferType type,
OpenCLBufferTransformer<OrgType>(MemoryType::GPU_BUFFER,
MemoryType::GPU_BUFFER)
.Transform(&context, bt_output,
type, MemoryType::GPU_BUFFER, 0, output);
type, MemoryType::GPU_BUFFER, 0, DataFormat::NHWC, output);
if (DataTypeToEnum<OrgType>::value == DataTypeToEnum<DstType>::value) {
EXPECT_EQ(net.GetOutput("Input")->UnderlyingBuffer(),
......@@ -92,7 +92,7 @@ void TestArgumentTransform(const index_t input_size) {
MemoryType::GPU_BUFFER)
.Transform(&context, net.ws()->GetTensor("Input"),
OpenCLBufferType::ARGUMENT, MemoryType::GPU_BUFFER,
0, output);
0, DataFormat::NHWC, output);
index_t expected_size = RoundUp<index_t>(input_size, 4);
EXPECT_EQ(expected_size, output->buffer_shape()[0]);
......
......@@ -111,6 +111,28 @@ void RegisterChannelShuffle(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "ChannelShuffle",
ChannelShuffleOp, DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("ChannelShuffle")
.SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int groups = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "group", 1);
if (op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
index_t channels = op->output_shape(0).dims(3);
index_t channels_per_group = channels / groups;
if (groups % 4 != 0 || channels_per_group % 4 != 0) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
}
} // namespace ops
......
......@@ -59,7 +59,9 @@ class ConcatOp<DeviceType::CPU, T> : public ConcatOpBase {
MACE_UNUSED(context);
if (!checked_) {
Validate();
if (this->Input(0)->dim_size() == 4) {
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
if (axis_ == 3) axis_ = 1;
else if (axis_ == 2) axis_ = 3;
else if (axis_ == 1) axis_ = 2;
......@@ -232,7 +234,42 @@ void RegisterConcat(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Concat", ConcatOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Concat")
.SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
auto tensor_shape_info = context->tensor_shape_info();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
if (op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
} else {
int axis = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "axis", 3);
if (axis != 3) {
return { DeviceType::CPU };
}
bool divisible_four = true;
for (const std::string &input : op->input()) {
if (tensor_shape_info->find(input)
!= tensor_shape_info->end()) {
divisible_four = divisible_four
&& (tensor_shape_info->at(input)[3] % 4 == 0);
}
}
// Channels not divisible by 4 are only supported with exactly 2 inputs.
if (op->input_size() > 2 && !divisible_four) {
return { DeviceType::CPU };
}
}
return { DeviceType::CPU, DeviceType::GPU };
}));
}
} // namespace ops
......
......@@ -186,6 +186,7 @@ TEST_F(ConcatOpTest, QuantizedCPURandom) {
builder = builder.Input(MakeString("Input", i));
}
builder.AddIntArg("axis", axis_arg)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("Output")
.Finalize(net.NewOperatorDef());
......@@ -245,8 +246,9 @@ TEST_F(ConcatOpTest, QuantizedCPURandom) {
namespace {
template <typename T>
void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
const int axis) {
void OpenCLRandomTest(const std::vector<std::vector<index_t>> &shapes,
const int axis,
DataFormat data_format) {
srand(time(nullptr));
int num_inputs = shapes.size();
int concat_axis_size = 0;
......@@ -262,6 +264,8 @@ void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
net.AddInputFromArray<DeviceType::GPU, float>(input_name, shapes[i],
inputs[i]);
}
std::vector<index_t> expected_shape = shapes[0];
expected_shape[axis] = concat_axis_size;
auto builder = OpDefBuilder("Concat", "ConcatTest");
for (int i = 0; i < num_inputs; ++i) {
......@@ -271,6 +275,8 @@ void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
builder.AddIntArg("axis", axis)
.Output("Output")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("data_format", data_format)
.OutputShape(expected_shape)
.Finalize(net.NewOperatorDef());
// Run
......@@ -279,8 +285,6 @@ void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
// Check
auto output = net.GetOutput("Output");
std::vector<index_t> expected_shape = shapes[0];
expected_shape[axis] = concat_axis_size;
EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape));
Tensor::MappingGuard output_mapper(output);
......@@ -305,20 +309,38 @@ void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
} // namespace
TEST_F(ConcatOpTest, OPENCLAligned) {
OpenclRandomTest<float>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3);
OpenCLRandomTest<float>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3,
DataFormat::NHWC);
}
TEST_F(ConcatOpTest, OPENCLHalfAligned) {
OpenclRandomTest<half>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3);
OpenCLRandomTest<half>({{3, 32, 32, 32}, {3, 32, 32, 64}}, 3,
DataFormat::NHWC);
}
TEST_F(ConcatOpTest, OPENCLUnAligned) {
OpenclRandomTest<float>({{3, 32, 32, 13}, {3, 32, 32, 17}}, 3);
OpenCLRandomTest<float>({{3, 32, 32, 13}, {3, 32, 32, 17}}, 3,
DataFormat::NHWC);
}
TEST_F(ConcatOpTest, OPENCLAlignedMultiInput) {
OpenclRandomTest<float>(
{{3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}}, 3);
OpenCLRandomTest<float>(
{{3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}, {3, 32, 32, 32}},
3, DataFormat::NHWC);
}
TEST_F(ConcatOpTest, GPUFallbackToCPU2DInput) {
OpenCLRandomTest<float>({{3, 4}, {3, 4}}, 1, DataFormat::DF_NONE);
}
TEST_F(ConcatOpTest, GPUFallbackToCPUChanNotDivisibleBy4) {
OpenCLRandomTest<float>({{1, 1, 4, 3}, {1, 1, 4, 3}}, 3,
DataFormat::DF_NONE);
}
TEST_F(ConcatOpTest, GPUFallbackToCPUAxis2) {
OpenCLRandomTest<float>({{1, 1, 4, 3}, {1, 1, 4, 3}}, 2,
DataFormat::DF_NONE);
}
} // namespace test
......
......@@ -78,7 +78,6 @@ MaceStatus BufferTransform<T>::Compute(OpContext *context,
const OpenCLBufferType type,
const int wino_blk_size,
Tensor *output) {
MACE_UNUSED(type);
MACE_UNUSED(wino_blk_size);
const DataType dt = DataTypeToEnum<T>::value;
switch (type) {
......@@ -92,8 +91,8 @@ MaceStatus BufferTransform<T>::Compute(OpContext *context,
if (input->dtype() != dt) {
return BufferTypeTransform(context, &kernel_, input, dt, output);
} else {
LOG(FATAL) << "Should not reach here. " << input->name()
<< "<" << type << "> to " << output->name();
SetFutureDefaultWaitFn(context->future());
output->ReuseTensorBuffer(*input);
return MaceStatus::MACE_SUCCESS;
}
}
......
......@@ -47,6 +47,7 @@ class OpenCLBufferTransformer {
const OpenCLBufferType type,
const MemoryType out_mem_type,
const int wino_blk_size,
const DataFormat data_format,
Tensor *output) {
Workspace *ws = context->workspace();
DataType dt = DataTypeToEnum<T>::value;
......@@ -65,7 +66,7 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform CPU Buffer " << input->name()
<< " to GPU Buffer " << internal_tensor->name()
<< " with data type " << dt;
if (input->shape().size() == 4) {
if (data_format == DataFormat::NHWC && input->shape().size() == 4) {
// 1. (NCHW -> NHWC)
std::vector<int> dst_dims = {0, 2, 3, 1};
std::vector<index_t> output_shape =
......@@ -103,7 +104,8 @@ class OpenCLBufferTransformer {
VLOG(2) << "Transform GPU Buffer " << internal_tensor.name()
<< " to CPU Buffer " << output->name()
<< " with data type " << dt;
if (internal_tensor.shape().size() == 4) {
if (data_format == DataFormat::NHWC &&
internal_tensor.shape().size() == 4) {
// NHWC -> NCHW
std::vector<int> dst_dims = {0, 3, 1, 2};
std::vector<index_t> output_shape =
......@@ -165,7 +167,7 @@ MaceStatus TransformFilter(
input->MarkUnused();
return OpenCLBufferTransformer<T>(input->memory_type(), mem_type).
Transform(&op_context, input, buffer_type, mem_type, wino_blk_size,
output);
DataFormat::DF_NONE, output);
}
} // namespace ops
......
......@@ -61,9 +61,6 @@ MaceStatus ChannelShuffleKernel<T>::Compute(
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channels_per_group = channels / groups_;
MACE_CHECK(channels_per_group % 4 == 0,
"channels per group must be multiple of 4");
MACE_CHECK(groups_ % 4 == 0, "groups must be multiple of 4");
const index_t group_channel_blocks = RoundUpDiv4(channels_per_group);
const uint32_t gws[3] = {static_cast<uint32_t>(group_channel_blocks),
......
......@@ -67,18 +67,14 @@ MaceStatus ConcatKernel<T>::Compute(
const std::vector<const Tensor *> &input_list,
Tensor *output) {
const int inputs_count = input_list.size();
MACE_CHECK(inputs_count >= 2 && axis_ == 3)
<< "Concat opencl kernel only support >=2 elements with axis == 3";
const Tensor *input0 = input_list[0];
bool divisible_four = input0->dim(axis_) % 4 == 0;
std::vector<index_t> output_shape(input0->shape());
for (int i = 1; i < inputs_count; ++i) {
const Tensor *input = input_list[i];
MACE_CHECK(input->dim_size() == input0->dim_size(),
"Ranks of all input tensors must be same.");
divisible_four &= input->dim(axis_) % 4 == 0;
for (int j = 0; j < input->dim_size(); ++j) {
if (j == axis_) {
continue;
......@@ -88,9 +84,6 @@ MaceStatus ConcatKernel<T>::Compute(
}
output_shape[axis_] += input->dim(axis_);
}
MACE_CHECK(
inputs_count == 2 || divisible_four,
"Dimensions of inputs should be divisible by 4 when inputs_count > 2.");
std::vector<size_t> image_shape;
OpenCLUtil::CalImage2DShape(output_shape,
OpenCLBufferType::IN_OUT_CHANNEL,
......@@ -103,15 +96,9 @@ MaceStatus ConcatKernel<T>::Compute(
context, &kernel_, input_list[0], input_list[1],
DataTypeToEnum<T>::value, &input_shape_, output, &kwg_size_);
default:
if (divisible_four) {
return concat::ConcatN(context, &kernel_, input_list,
DataTypeToEnum<T>::value, output, &kwg_size_);
} else {
MACE_NOT_IMPLEMENTED;
}
return concat::ConcatN(context, &kernel_, input_list,
DataTypeToEnum<T>::value, output, &kwg_size_);
}
return MaceStatus::MACE_SUCCESS;
}
} // namespace image
......
......@@ -34,9 +34,7 @@ namespace image {
template <typename T>
class SplitKernel : public OpenCLSplitKernel {
public:
explicit SplitKernel(const int32_t axis) : axis_(axis) {
MACE_CHECK(axis == 3) << "GPU only support channel-dimension split";
}
explicit SplitKernel(const int32_t axis) : axis_(axis) {}
MaceStatus Compute(
OpContext *context,
const Tensor *input,
......@@ -56,8 +54,6 @@ MaceStatus SplitKernel<T>::Compute(
const index_t input_channels = input->dim(3);
const size_t outputs_count = output_list.size();
const index_t output_channels = input_channels / outputs_count;
MACE_CHECK(output_channels % 4 == 0)
<< "output channels of split op must be divisible by 4";
std::vector<index_t> output_shape(
{input->dim(0), input->dim(1), input->dim(2), output_channels});
......
......@@ -167,6 +167,9 @@ bool OpsTestNet::Setup(mace::DeviceType device) {
!ws_.GetTensor(input)->is_weight()) {
auto input_info = net_def.add_input_info();
input_info->set_name(input);
auto data_format = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
op_def_, "data_format", DataFormat::DF_NONE);
input_info->set_data_format(data_format);
auto &shape = ws_.GetTensor(input)->shape();
for (auto d : shape) {
input_info->add_dims(static_cast<int>(d));
......
......@@ -35,7 +35,11 @@ class PadOp<DeviceType::CPU, T> : public Operation {
constant_value_(Operation::GetOptionalArg<float>(
"constant_value", 0.0)) {
MACE_CHECK(paddings_.size() == 8);
paddings_ = TransposeShape<int, int>(paddings_, {0, 1, 6, 7, 2, 3, 4, 5});
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC) {
paddings_ = TransposeShape<int, int>(paddings_, {0, 1, 6, 7, 2, 3, 4, 5});
}
}
MaceStatus Run(OpContext *context) override {
......
......@@ -34,6 +34,7 @@ void Simple() {
.Output("Output")
.AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0})
.AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
......@@ -46,6 +47,7 @@ void Simple() {
.Output("TOutput")
.AddIntsArg("paddings", {0, 0, 1, 2, 1, 2, 0, 0})
.AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
......@@ -84,6 +86,7 @@ TEST_F(PadTest, ComplexCPU) {
.Output("TOutput")
.AddIntsArg("paddings", {0, 0, 1, 1, 1, 1, 1, 1})
.AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
......@@ -120,6 +123,7 @@ void Complex(const std::vector<index_t> &input_shape,
.Output("TOutput")
.AddIntsArg("paddings", paddings)
.AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
......@@ -135,6 +139,7 @@ void Complex(const std::vector<index_t> &input_shape,
.Output("Output")
.AddIntsArg("paddings", paddings)
.AddFloatArg("constant_value", 1.0)
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
......
......@@ -90,7 +90,9 @@ class ReduceOp<DeviceType::CPU, T> : public ReduceOpBase {
int index = axis_[i] >= 0 ?
axis_[i] :
axis_[i] + input->dim_size();
if (input->dim_size() == 4) {
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC && input->dim_size() == 4) {
if (index == 1 || index == 2) index = index + 1;
else if (index == 3) index = 1;
}
......
......@@ -44,6 +44,7 @@ void Simple(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
......@@ -55,6 +56,7 @@ void Simple(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
......@@ -82,6 +84,7 @@ void Simple3D(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", keepdims ? 1 : 0)
.AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
......@@ -585,6 +588,7 @@ void RandomTest(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("OutputNCHW")
.Finalize(net.NewOperatorDef());
// Run
......@@ -596,6 +600,7 @@ void RandomTest(const std::vector<index_t> &input_shape,
.AddIntsArg("axis", axis)
.AddIntArg("keepdims", 1)
.AddIntArg("reduce_type", type)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("OPENCLOutput")
.Finalize(net.NewOperatorDef());
// Run
......
......@@ -99,6 +99,21 @@ void RegisterReshape(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Reshape", ReshapeOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Reshape")
.SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
if (op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
}
} // namespace ops
......
......@@ -410,6 +410,22 @@ void RegisterSoftmax(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Softmax", SoftmaxOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Softmax")
.SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
if (op->output_shape(0).dims_size() != 2 &&
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
}
} // namespace ops
......
......@@ -35,7 +35,9 @@ class SplitOp<DeviceType::CPU, T> : public Operation {
checked_(false) {}
void Validate() {
if (this->Input(0)->dim_size() == 4) {
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
if (axis_ == 3) axis_ = 1;
else if (axis_ == 2) axis_ = 3;
else if (axis_ == 1) axis_ = 2;
......@@ -139,6 +141,24 @@ void RegisterSplit(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Split", SplitOp,
DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Split")
.SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
int axis = ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "axis", 3);
if (axis != 3 || op->output_shape(0).dims_size() != 4 ||
(op->output_shape(0).dims()[3] % 4 != 0)) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
}
} // namespace ops
......
......@@ -54,6 +54,7 @@ void RandomTest(const int num_outputs, int axis) {
builder = builder.Output(MakeString("Output", i));
}
builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
......
......@@ -31,11 +31,14 @@ class SqueezeOp : public Operation {
MaceStatus Run(OpContext *context) override {
MACE_UNUSED(context);
if (!checked_ && D == DeviceType::CPU
&& DataTypeToEnum<T>::value != DT_UINT8
&& this->Input(0)->dim_size() == 4) {
if (axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2) {
axis_[0] = 2;
axis_[1] = 3;
&& DataTypeToEnum<T>::value != DT_UINT8) {
auto df = static_cast<DataFormat>(Operation::GetOptionalArg<int>(
"data_format", DataFormat::DF_NONE));
if (df == DataFormat::NHWC && this->Input(0)->dim_size() == 4) {
if (axis_.size() == 2 && axis_[0] == 1 && axis_[1] == 2) {
axis_[0] = 2;
axis_[1] = 3;
}
}
checked_ = true;
}
......@@ -70,6 +73,21 @@ void RegisterSqueeze(OpRegistryBase *op_registry) {
MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::GPU, float);
MACE_REGISTER_OP(op_registry, "Squeeze", SqueezeOp, DeviceType::GPU, half);
#endif // MACE_ENABLE_OPENCL
MACE_REGISTER_OP_CONDITION(
op_registry,
OpConditionBuilder("Squeeze")
.SetDevicePlacerFunc(
[](OpConstructContext *context) -> std::set<DeviceType> {
auto op = context->operator_def();
if (op->output_shape_size() != op->output_size()) {
return { DeviceType::CPU, DeviceType::GPU };
}
if (op->output_shape(0).dims_size() != 2 &&
op->output_shape(0).dims_size() != 4) {
return { DeviceType::CPU };
}
return { DeviceType::CPU, DeviceType::GPU };
}));
}
} // namespace ops
......
......@@ -30,6 +30,7 @@ void TestSqueeze(const std::vector<index_t> &org_shape,
OpDefBuilder("Squeeze", "SqueezeTest")
.Input("Input")
.AddIntsArg("axis", axis)
.AddIntArg("data_format", DataFormat::NHWC)
.Output("Output")
.Finalize(net.NewOperatorDef());
......
......@@ -259,6 +259,7 @@ class TransformerRule(Enum):
TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
FOLD_FC_RESHAPE = 37
TRANSFORM_CHANNEL_SHUFFLE = 38
UPDATE_DATA_FORMAT = 39
class ConverterInterface(object):
......@@ -479,6 +480,8 @@ class ConverterOption(object):
TransformerRule.ADD_OPENCL_INFORMATIONS,
# for quantization entropy calibration use
TransformerRule.SORT_BY_EXECUTION,
# update the data format of ops
TransformerRule.UPDATE_DATA_FORMAT,
# Need to be put after SORT_BY_EXECUTION
TransformerRule.ADD_QUANTIZE_TENSOR_RANGE,
]
......
......@@ -96,6 +96,7 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.ADD_OPENCL_INFORMATIONS:
self.add_opencl_informations,
TransformerRule.SORT_BY_EXECUTION: self.sort_by_execution,
TransformerRule.UPDATE_DATA_FORMAT: self.update_data_format,
TransformerRule.CHECK_QUANTIZE_INFO:
self.check_quantize_info,
TransformerRule.TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN:
......@@ -1358,6 +1359,24 @@ class Transformer(base_converter.ConverterInterface):
out_shape.dims for out_shape in op.output_shape]))
return False
def update_data_format(self):
data_format_flag = DataFormat.NHWC.value
for input_node in self._option.input_nodes.values():
if input_node.data_format.value == DataFormat.DF_NONE.value:
data_format_flag = DataFormat.DF_NONE.value
net = self._model
for op in net.op:
data_format_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_data_format_str)
if not data_format_arg:
data_format_arg = op.arg.add()
data_format_arg.name = MaceKeyword.mace_data_format_str
data_format_arg.i = data_format_flag
elif data_format_arg.i != data_format_flag:
data_format_arg.i = data_format_flag
return False
def quantize_nodes(self):
if not self._option.quantize:
return False
......