diff --git a/mace/core/runtime/opencl/opencl_runtime.h b/mace/core/runtime/opencl/opencl_runtime.h
index f21ade57fa73cabd48a50e6baf4f7284cc51e40f..ed7d0c68d353f4a1bb10581657e2f7bb3830fc4b 100644
--- a/mace/core/runtime/opencl/opencl_runtime.h
+++ b/mace/core/runtime/opencl/opencl_runtime.h
@@ -33,8 +33,8 @@ class OpenCLRuntime {
 
  private:
   cl::Context context_;
-  cl::CommandQueue command_queue_;
   cl::Device device_;
+  cl::CommandQueue command_queue_;
   cl::Program program_;
   std::once_flag build_flag_;
 };
diff --git a/mace/kernels/addn.h b/mace/kernels/addn.h
index 4e30d314432c9cdf6cd18f98ecfd2fea2db9d94c..b47ef7e73f83a780fd4baf5aa729e980732da7ed 100644
--- a/mace/kernels/addn.h
+++ b/mace/kernels/addn.h
@@ -10,22 +10,31 @@
 namespace mace {
 namespace kernels {
 
-template <DeviceType D, typename T>
+template<DeviceType D, typename T>
 struct AddNFunctor {
-  void operator()(const vector<const T *> &inputs, T *output, index_t size) {
-    memset(output, 0, size * sizeof(T));
-    int n = inputs.size();
+  void operator()(std::vector<const Tensor *> &input_tensors, Tensor *output_tensor) {
+    Tensor::MappingGuard output_map(output_tensor);
+    index_t size = input_tensors[0]->size();
+    T *output_ptr = output_tensor->mutable_data<T>();
+    memset(output_ptr, 0, size * sizeof(T));
+    int n = input_tensors.size();
     for (int i = 0; i < n; ++i) {
+      Tensor::MappingGuard input_map(input_tensors[i]);
+      const T *input_ptr = input_tensors[i]->data<T>();
       for (index_t j = 0; j < size; ++j) {
-        output[j] += inputs[i][j];
+        output_ptr[j] += input_ptr[j];
       }
     }
   }
 };
 
-template <>
+template<>
 void AddNFunctor<DeviceType::NEON, float>::operator()(
-    const vector<const float *> &inputs, float *output, index_t size);
+    std::vector<const Tensor *> &input_tensors, Tensor *output_tensor);
+
+template<>
+void AddNFunctor<DeviceType::OPENCL, float>::operator()(
+    std::vector<const Tensor *> &inputs, Tensor *output);
 
 } // namespace kernels
 } // namespace mace
diff --git a/mace/kernels/neon/addn_neon.cc b/mace/kernels/neon/addn_neon.cc
index fed0c4e1e97726aaed46115de236515250b3c8ea..d7ff94864ea3ba7469cea561558e39b41624db1f 100644
--- a/mace/kernels/neon/addn_neon.cc
+++ b/mace/kernels/neon/addn_neon.cc
@@ -10,10 +10,12 @@ namespace kernels {
 
 template <>
 void AddNFunctor<DeviceType::NEON, float>::operator()(
-    const vector<const float *> &inputs, float *output, index_t size) {
+    std::vector<const Tensor *> &input_tensors, Tensor *output_tensor) {
   // TODO: neon mem copy
-  memset(output, 0, size * sizeof(float));
-  int n = inputs.size();
+  index_t size = output_tensor->size();
+  float *output_ptr = output_tensor->mutable_data<float>();
+  memset(output_ptr, 0, size * sizeof(float));
+  int n = input_tensors.size();
   int64_t cost = size * n;
   int64_t groups = 1;
   if (cost > kCostPerGroup) {
@@ -27,8 +29,9 @@ void AddNFunctor<DeviceType::NEON, float>::operator()(
     int nn = count >> 2;
     int remain = count - (nn << 2);
     for (int64_t j = 0; j < n; ++j) {
-      const float *inptr = inputs[j] + i;
-      float *outptr = output + i;
+      const float *input_base = input_tensors[j]->data<float>();
+      const float *inptr = input_base + i;
+      float *outptr = output_ptr + i;
       for (int k = 0; k < nn; ++k) {
         float32x4_t _inptr = vld1q_f32(inptr);
         float32x4_t _outptr = vld1q_f32(outptr);
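Note on the NEON loop above: each per-group range of count elements is split into
nn = count >> 2 full four-float blocks handled with float32x4_t loads/stores (the
add itself sits below the excerpt; vaddq_f32 is the natural candidate), plus a
scalar remainder of count - (nn << 2) elements. A standalone C++ sketch of the
same decomposition, with plain loops standing in for the intrinsics (illustrative
only, not part of the patch):

    #include <cstdint>

    // Accumulate in[0..count) into out[0..count): nn full 4-wide blocks,
    // then up to 3 scalar leftovers -- the same split as the NEON code.
    void add_inplace(const float *in, float *out, int64_t count) {
      int64_t nn = count >> 2;
      int64_t remain = count - (nn << 2);
      for (int64_t k = 0; k < nn; ++k) {
        for (int v = 0; v < 4; ++v) {
          out[v] += in[v];  // stands in for a vaddq_f32 over four lanes
        }
        in += 4;
        out += 4;
      }
      for (; remain > 0; --remain) {
        *out++ += *in++;
      }
    }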
diff --git a/mace/kernels/opencl/addn.cc b/mace/kernels/opencl/addn.cc
new file mode 100644
index 0000000000000000000000000000000000000000..24d084cab9f4844d7e2ec756007212234375865d
--- /dev/null
+++ b/mace/kernels/opencl/addn.cc
@@ -0,0 +1,56 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/kernels/addn.h"
+#include "mace/core/runtime/opencl/opencl_runtime.h"
+
+namespace mace {
+namespace kernels {
+
+static void Add2(const Tensor *input0, const Tensor *input1, Tensor *output) {
+  index_t element_size = input0->NumElements();
+  index_t blocks = (element_size + 3) / 4;
+
+  const uint32_t gws = blocks;
+
+  auto runtime = OpenCLRuntime::Get();
+  auto program = runtime->program();
+
+  auto addn_kernel = cl::Kernel(program, "add2");
+
+  const uint32_t lws = runtime->GetKernelMaxWorkGroupSize(addn_kernel);
+
+  uint32_t idx = 0;
+  addn_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input0->buffer())));
+  addn_kernel.setArg(idx++, *(static_cast<const cl::Buffer *>(input1->buffer())));
+  addn_kernel.setArg(idx++, static_cast<int32_t>(element_size));
+  addn_kernel.setArg(idx++, *(static_cast<cl::Buffer *>(output->buffer())));
+
+  cl_int error = runtime->command_queue().enqueueNDRangeKernel(
+      addn_kernel, cl::NullRange,
+      cl::NDRange(gws),
+      cl::NDRange(lws));
+  MACE_CHECK(error == CL_SUCCESS);
+}
+
+template<>
+void AddNFunctor<DeviceType::OPENCL, float>::operator()(std::vector<const Tensor *> &input_tensors,
+                                                        Tensor *output_tensor) {
+
+  if (input_tensors.empty() || input_tensors.front() == nullptr) {
+    return;
+  }
+  size_t size = input_tensors.size();
+
+  switch (size) {
+    case 2:
+      Add2(input_tensors[0], input_tensors[1], output_tensor);
+      break;
+    default:
+      MACE_NOT_IMPLEMENTED;
+  }
+}
+
+} // namespace kernels
+} // namespace mace
diff --git a/mace/kernels/opencl/cl/addn.cl b/mace/kernels/opencl/cl/addn.cl
new file mode 100644
index 0000000000000000000000000000000000000000..eb1be1cac7fa0222a01bc27d37f307e36f86942a
--- /dev/null
+++ b/mace/kernels/opencl/cl/addn.cl
@@ -0,0 +1,19 @@
+__kernel void add2(__global const float *input0,
+                   __global const float *input1,
+                   __private const int size,
+                   __global float *output) {
+  int idx = get_global_id(0);
+  int base = idx << 2;  /* first element of this work item's 4-float block */
+
+  if (base + 4 > size) {
+    /* Partial final block: scalar tail so we never read or write past size. */
+    for (int i = base; i < size; ++i) {
+      output[i] = input0[i] + input1[i];
+    }
+  } else {
+    float4 in_data0 = vload4(idx, input0);
+    float4 in_data1 = vload4(idx, input1);
+    vstore4(in_data0 + in_data1, idx, output);
+  }
+}
+
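Note on the kernel indexing above: vload4(idx, p) reads p[4*idx .. 4*idx+3], so
with gws = (element_size + 3) / 4 each work item owns exactly one four-element
block and only the last block can be partial; the base + 4 > size guard keeps
that final work item inside the buffers. A host-side C++ sketch of the same
decomposition (add2_reference is a hypothetical name, for illustration only):

    #include <algorithm>

    // Mirrors the kernel's per-work-item block decomposition on the host.
    void add2_reference(const float *in0, const float *in1, float *out, int size) {
      int blocks = (size + 3) / 4;              // same value as gws in Add2()
      for (int idx = 0; idx < blocks; ++idx) {  // idx plays get_global_id(0)
        int base = idx << 2;                    // first element of this block
        int end = std::min(base + 4, size);     // clip the final partial block
        for (int i = base; i < end; ++i) {
          out[i] = in0[i] + in1[i];
        }
      }
    }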
diff --git a/mace/ops/addn.cc b/mace/ops/addn.cc
index 0598d1cdc93e634a22aa9074798b18292f525d8b..b4b74b04b84d01ac4f6941c649acabc04f25c0d8 100644
--- a/mace/ops/addn.cc
+++ b/mace/ops/addn.cc
@@ -12,4 +12,6 @@ REGISTER_CPU_OPERATOR(AddN, AddNOp<DeviceType::CPU, float>);
 REGISTER_NEON_OPERATOR(AddN, AddNOp<DeviceType::NEON, float>);
 #endif // __ARM_NEON
 
+REGISTER_OPENCL_OPERATOR(AddN, AddNOp<DeviceType::OPENCL, float>);
+
 } // namespace mace
diff --git a/mace/ops/addn.h b/mace/ops/addn.h
index b626596306f8cb86715199c3b2cfebaa47bd1d30..a2ffefbbc54e846317415e653078706a2938f67b 100644
--- a/mace/ops/addn.h
+++ b/mace/ops/addn.h
@@ -10,7 +10,7 @@
 
 namespace mace {
 
-template <DeviceType D, class T>
+template<DeviceType D, class T>
 class AddNOp : public Operator<D, T> {
  public:
   AddNOp(const OperatorDef &operator_def, Workspace *ws)
@@ -19,16 +19,13 @@ class AddNOp : public Operator<D, T> {
   bool Run() override {
     Tensor *output_tensor = this->outputs_[0];
     output_tensor->ResizeLike(this->inputs_[0]);
-    T *output = output_tensor->mutable_data<T>();
-    index_t size = this->inputs_[0]->size();
     int n = this->inputs_.size();
-    vector<const T *> inputs(n);
+    vector<const Tensor *> inputs(n, nullptr);
     for (int i = 0; i < n; ++i) {
-      const Tensor *input_tensor = this->inputs_[i];
-      inputs[i] = input_tensor->data<T>();
+      inputs[i] = this->inputs_[i];
     }
 
-    functor_(inputs, output, size);
+    functor_(inputs, output_tensor);
     return true;
   }
diff --git a/mace/ops/addn_test.cc b/mace/ops/addn_test.cc
index a48d066235eec33f1465ffe6f74fce6bb97e0d37..3fc58011f623ebf5ff541c1ed2f48d2b9eb5a959 100644
--- a/mace/ops/addn_test.cc
+++ b/mace/ops/addn_test.cc
@@ -9,9 +9,44 @@ namespace mace {
 
 class AddnOpTest : public OpsTestBase {};
 
-TEST_F(AddnOpTest, AddnOp) {
+template <DeviceType D>
+void SimpleAdd2() {
   // Construct graph
-  auto &net = test_net();
+  OpsTestNet net;
+  OpDefBuilder("AddN", "AddNTest")
+      .Input("Input1")
+      .Input("Input2")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddInputFromArray<float>("Input1", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
+  net.AddInputFromArray<float>("Input2", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
+
+  // Run
+  net.RunOp(D);
+
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {2, 4, 6, 8, 10, 12});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+
+TEST_F(AddnOpTest, CPUSimpleAdd2) {
+  SimpleAdd2<DeviceType::CPU>();
+}
+
+TEST_F(AddnOpTest, NEONSimpleAdd2) {
+  SimpleAdd2<DeviceType::NEON>();
+}
+
+TEST_F(AddnOpTest, OPENCLSimpleAdd2) {
+  SimpleAdd2<DeviceType::OPENCL>();
+}
+
+template <DeviceType D>
+void SimpleAdd3() {
+  // Construct graph
+  OpsTestNet net;
   OpDefBuilder("AddN", "AddNTest")
       .Input("Input1")
       .Input("Input2")
@@ -20,20 +55,62 @@ TEST_F(AddnOpTest, AddnOp) {
       .Finalize(net.NewOperatorDef());
 
   // Add input data
-  net.AddRandomInput<float>("Input1", {1, 2, 3, 4});
-  net.AddRandomInput<float>("Input2", {1, 2, 3, 4});
-  net.AddRandomInput<float>("Input3", {1, 2, 3, 4});
+  net.AddInputFromArray<float>("Input1", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
+  net.AddInputFromArray<float>("Input2", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
+  net.AddInputFromArray<float>("Input3", {1, 1, 2, 3}, {1, 2, 3, 4, 5, 6});
 
   // Run
-  net.RunOp();
+  net.RunOp(D);
+
+  auto expected = CreateTensor<float>({1, 1, 2, 3}, {3, 6, 9, 12, 15, 18});
+
+  ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5);
+}
+
+TEST_F(AddnOpTest, CPUSimpleAdd3) {
+  SimpleAdd3<DeviceType::CPU>();
+}
 
-  Tensor expected;
-  expected.Copy(*net.GetOutput("Output"));
+TEST_F(AddnOpTest, NEONSimpleAdd3) {
+  SimpleAdd3<DeviceType::NEON>();
+}
+
+template <DeviceType D>
+void RandomTest() {
+  // Construct graph
+  OpsTestNet net;
+  OpDefBuilder("AddN", "AddNTest")
+      .Input("Input1")
+      .Input("Input2")
+      .Output("Output")
+      .Finalize(net.NewOperatorDef());
+
+  // Add input data
+  net.AddRandomInput<float>("Input1", {1, 2, 3, 4});
+  net.AddRandomInput<float>("Input2", {1, 2, 3, 4});
 
   // Check
-  net.RunOp(DeviceType::NEON);
+  net.RunOp(D);
+
+  Tensor result;
+  result.Copy(*net.GetOutput("Output"));
+
+  // Run
+  net.RunOp();
+
+  ExpectTensorNear<float>(*net.GetOutput("Output"), result, 1e-5);
+}
+
+TEST_F(AddnOpTest, CPURandom) {
+  RandomTest<DeviceType::CPU>();
+}
+
+TEST_F(AddnOpTest, NEONRandom) {
+  RandomTest<DeviceType::NEON>();
+}
 
-  ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
+TEST_F(AddnOpTest, OPENCLRandom) {
+  RandomTest<DeviceType::OPENCL>();
 }
 
 } // namespace mace
diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py
index 671230e51dfb01fa11d31e96e677817e2ac0ed88..0568bfcd82839c31c9ffcf2b56d1f1bde490089e 100644
--- a/mace/python/tools/tf_converter_lib.py
+++ b/mace/python/tools/tf_converter_lib.py
@@ -12,6 +12,33 @@ pooling_type_mode = {
   'MaxPool': 2
 }
 
+def convert_tensor(op, tensor):
+  tf_tensor = op.outputs[0].eval()
+  tensor.name = op.outputs[0].name
+
+  shape = list(tf_tensor.shape)
+  if (op.name.find('pointwise_kernel') != -1 or
+      op.name.find('depthwise_kernel') != -1 or
+      op.name.endswith('weights') or
+      op.name.endswith('kernel')) \
+      and op.outputs[0].consumers()[0].type.find('Conv') != -1:
+    if op.outputs[0].consumers()[0].get_attr('data_format') == 'NCHW':
+      tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1))
+      shape = [shape[3], shape[2], shape[0], shape[1]]
+  # print (tensor.name, shape)
+  tensor.dims.extend(shape)
+
+  tf_dt = op.get_attr('dtype')
+  if tf_dt == tf.float32:
+    tensor.data_type = mace_pb2.DT_FLOAT
+    tensor.float_data.extend(tf_tensor.astype(float).flat)
+  elif tf_dt == tf.int32:
+    tensor.data_type = mace_pb2.DT_INT32
+    tensor.int32_data.extend(tf_tensor.astype(np.int32).flat)
+  else:
+    raise Exception("Unsupported tensor type: " + tf_dt.name)
+
+
 def get_input_tensor(op, index):
   input_tensor = op.inputs[index]
   if input_tensor.op.type == 'Reshape':
@@ -24,26 +51,11 @@ def convert_ops(unresolved_ops, net_def):
 
   first_op = unresolved_ops[0]
 
-  if first_op.type == 'Placeholder' or first_op.type == 'Reshape':
+  if first_op.type in ['Placeholder', 'Reshape', 'Identity']:
     pass
   elif first_op.type == 'Const':
-    tf_tensor = first_op.outputs[0].eval()
     tensor = net_def.tensors.add()
-    tensor.name = first_op.outputs[0].name
-    # TODO: support other type than float
-    tensor.data_type = mace_pb2.DT_FLOAT
-
-    shape = list(tf_tensor.shape)
-    if (first_op.name.find('pointwise_kernel') != -1 or
-        first_op.name.find('depthwise_kernel') != -1 or
-        first_op.name.endswith('weights') or
-        first_op.name.endswith('kernel')) \
-        and first_op.outputs[0].consumers()[0].type.find('Conv') != -1:
-      tf_tensor = np.transpose(tf_tensor, axes=(3, 2, 0, 1))
-      shape = [shape[3], shape[2], shape[0], shape[1]]
-    # print (tensor.name, shape)
-    tensor.dims.extend(shape)
-    tensor.float_data.extend(tf_tensor.astype(float).flat)
+    convert_tensor(first_op, tensor)
   elif first_op.type == 'Conv2D' or first_op.type == 'DepthwiseConv2dNative':
     op_def = net_def.op.add()
     op_def.name = first_op.name
@@ -61,9 +73,7 @@ def convert_ops(unresolved_ops, net_def):
     strides_arg.ints.extend(first_op.get_attr('strides')[2:])
     data_format_arg = op_def.arg.add()
     data_format_arg.name = 'data_format'
-    data_format_arg.s = first_op.get_attr('data_format')
-    if first_op.get_attr('data_format') != 'NCHW':
-      raise Exception('only support NCHW now')
+    data_format_arg.s = 'NCHW'
 
     if ops_count >= 2 and unresolved_ops[1].type == 'BiasAdd':
       bias_add_op = unresolved_ops[1]
@@ -78,7 +88,8 @@ def convert_ops(unresolved_ops, net_def):
     sub_op = unresolved_ops[5]
     add_1_op = unresolved_ops[6]
     # print (mul_op.type, mul_2_op.type, mul_1_op.type, sub_op.type)
-    if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add':
+    if mul_op.type != 'Mul' or mul_2_op.type != 'Mul' or \
+        mul_1_op.type != 'Mul' or sub_op.type != 'Sub' or add_1_op.type != 'Add':
       raise Exception('Invalid BatchNorm Op')
 
     input_name = get_input_tensor(mul_1_op, 0).name
@@ -104,12 +115,6 @@ def convert_ops(unresolved_ops, net_def):
     max_limit_arg = op_def.arg.add()
     max_limit_arg.name = 'max_limit'
     max_limit_arg.f = 6
-  elif first_op.type == 'Relu':
-    op_def = net_def.op.add()
-    op_def.name = first_op.name
-    op_def.type = first_op.type
-    op_def.input.extend([input.name for input in first_op.inputs])
-    op_def.output.extend([output.name for output in first_op.outputs])
   elif first_op.type == 'AvgPool' or first_op.type == 'MaxPool':
     op_def = net_def.op.add()
     op_def.name = first_op.name
@@ -130,9 +135,19 @@ def convert_ops(unresolved_ops, net_def):
     kernels_arg.ints.extend(first_op.get_attr('ksize')[2:])
     data_format_arg = op_def.arg.add()
     data_format_arg.name = 'data_format'
-    data_format_arg.s = first_op.get_attr('data_format')
-    if first_op.get_attr('data_format') != 'NCHW':
-      raise Exception('only support NCHW now')
+    data_format_arg.s = 'NCHW'
+  elif first_op.type == 'Add':
+    op_def = net_def.op.add()
+    op_def.name = first_op.name
+    op_def.type = "AddN"
+    op_def.input.extend([input.name for input in first_op.inputs])
+    op_def.output.extend([output.name for output in first_op.outputs])
+  elif first_op.type in ['Relu', 'ResizeBilinear', 'SpaceToBatchND', 'BatchToSpaceND']:
+    op_def = net_def.op.add()
+    op_def.name = first_op.name
+    op_def.type = first_op.type
+    op_def.input.extend([input.name for input in first_op.inputs])
+    op_def.output.extend([output.name for output in first_op.outputs])
   else:
     raise Exception('Unknown Op: ' + first_op.name)
     pass
@@ -152,4 +167,6 @@ def convert_to_mace_pb(input_graph_def):
     while len(unresolved_ops) > 0:
       convert_ops(unresolved_ops, net_def)
 
+  print "Done."
+
   return net_def
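Note on the filter relayout in convert_tensor: np.transpose(tf_tensor,
axes=(3, 2, 0, 1)) turns TensorFlow's HWIO filter layout into the OIHW layout
that the NCHW convolution path expects, and dims is reordered to match. The
equivalent index mapping written out in C++ (hwio_to_oihw is a hypothetical
helper for illustration, not part of the converter):

    #include <vector>

    // dst is OIHW, src is HWIO; same mapping as np.transpose(src, (3, 2, 0, 1)).
    std::vector<float> hwio_to_oihw(const std::vector<float> &src,
                                    int H, int W, int I, int O) {
      std::vector<float> dst(src.size());
      for (int h = 0; h < H; ++h)
        for (int w = 0; w < W; ++w)
          for (int i = 0; i < I; ++i)
            for (int o = 0; o < O; ++o)
              dst[((o * I + i) * H + h) * W + w] =
                  src[((h * W + w) * I + i) * O + o];
      return dst;
    }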