From d4b4605fc9ebceaa0d9c11717e866947a0d71a6d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?=
Date: Wed, 6 Sep 2017 11:27:19 +0800
Subject: [PATCH] AddN op; Refactor op default template impl.

---
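A minimal sketch of exercising the new AddN op end to end, following the
Workspace / CreateNet API already used by mace/examples/helloworld.cc and by
the removed mace/ops/relu_test.cc; the tensor names, shapes, and main()
wrapper below are illustrative only and are not part of this patch.

#include "mace/core/net.h"
#include "mace/core/operator.h"

using namespace mace;

int main() {
  // One AddN node that sums two workspace tensors into a single output.
  OperatorDef op_def;
  op_def.add_input("Input0");
  op_def.add_input("Input1");
  op_def.add_output("Output0");
  op_def.set_name("AddNExample");
  op_def.set_type("AddN");

  NetDef net_def;
  net_def.set_name("AddNNet");
  net_def.add_op()->CopyFrom(op_def);

  // Allocate and fill the two inputs, mirroring the relu_test.cc setup.
  Workspace ws;
  const char* names[2] = {"Input0", "Input1"};
  for (const char* name : names) {
    Tensor* input = ws.CreateTensor(name, cpu_allocator(), DataType::DT_FLOAT);
    input->Resize({2, 3});
    float* data = input->mutable_data<float>();
    for (int i = 0; i < 6; ++i) {
      data[i] = static_cast<float>(i);
    }
  }

  // CreateNet picks up the AddNOp<DeviceType::CPU, float> registered in
  // mace/ops/addn.cc; its default Run() forwards to kernels::AddNFuntion<float>.
  auto net = CreateNet(net_def, &ws, DeviceType::CPU);
  net->Run();
  return 0;
}

When built with __ARM_NEON, constructing the net with DeviceType::NEON should
instead dispatch to the specialized Run() registered in mace/ops/addn.cc,
which calls kernels::NeonAddNFuntion_float.
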
 mace/examples/helloworld.cc |  5 ++--
 mace/kernels/addn.h         |  4 +--
 mace/kernels/relu.h         |  4 +--
 mace/ops/BUILD              | 17 +----------------
 mace/ops/addn.cc            | 25 ++++++++++++++++++
 mace/ops/addn.h             | 27 +++++++++++++++++++
 mace/ops/relu.cc            |  9 -------
 mace/ops/relu.h             |  8 +++++-
 mace/ops/relu_test.cc       | 52 -------------------------------------
 9 files changed, 67 insertions(+), 84 deletions(-)
 create mode 100644 mace/ops/addn.cc
 create mode 100644 mace/ops/addn.h
 delete mode 100644 mace/ops/relu_test.cc

diff --git a/mace/examples/helloworld.cc b/mace/examples/helloworld.cc
index 2e9eb1e2..25a2e2ea 100644
--- a/mace/examples/helloworld.cc
+++ b/mace/examples/helloworld.cc
@@ -27,10 +27,11 @@ int main() {
   arg_1->set_f(1.5);

   OperatorDef op_def_2;
+  op_def_2.add_input("Output0");
   op_def_2.add_input("Output1");
   op_def_2.add_output("Output2");
-  op_def_2.set_name("ReluTest2");
-  op_def_2.set_type("Relu");
+  op_def_2.set_name("AddNTest");
+  op_def_2.set_type("AddN");
   auto arg_2 = op_def_2.add_arg();
   arg_2->set_name("arg0");
   arg_2->set_f(2.5);
diff --git a/mace/kernels/addn.h b/mace/kernels/addn.h
index 70a0d584..d04885a6 100644
--- a/mace/kernels/addn.h
+++ b/mace/kernels/addn.h
@@ -18,10 +18,10 @@ void AddNFuntion(const vector<const Tensor*>& input_tensor, Tensor *output_tenso
   int64 size = input_tensor[0]->size();
   vector<const T*> inputs(n);
   for (int i = 0; i < n; ++i) {
-    inputs[i] = input_tensor[i]->data<float>();
+    inputs[i] = input_tensor[i]->data<T>();
   }
   output_tensor->ResizeLike(input_tensor[0]);
-  float* output = output_tensor->mutable_data<float>();
+  T* output = output_tensor->mutable_data<T>();

   for (int i = 0; i < n; ++i) {
     for (int64 j = 0; j < size; ++j) {
diff --git a/mace/kernels/relu.h b/mace/kernels/relu.h
index e2400e97..086f762b 100644
--- a/mace/kernels/relu.h
+++ b/mace/kernels/relu.h
@@ -14,8 +14,8 @@ template <typename T>
 void ReluFuntion(const Tensor *input_tensor, Tensor *output_tensor) {
   int64 size = input_tensor->size();
   output_tensor->ResizeLike(input_tensor);
-  const float *input = input_tensor->data<float>();
-  float *output = output_tensor->mutable_data<float>();
+  const T *input = input_tensor->data<T>();
+  T *output = output_tensor->mutable_data<T>();

   for (int64 i = 0; i < size; ++i) {
     output[i] = std::max(input[i], static_cast<T>(0));
diff --git a/mace/ops/BUILD b/mace/ops/BUILD
index 1acbc1fd..ea4791c4 100644
--- a/mace/ops/BUILD
+++ b/mace/ops/BUILD
@@ -12,7 +12,7 @@ load("//mace:mace.bzl", "if_android")

 cc_library(
     name = "ops",
-    srcs = ["relu.cc"],
+    srcs = glob(["*.cc"]),
     hdrs = glob(["*.h"]),
     deps = [
         "//mace/proto:cc_proto",
@@ -23,19 +23,4 @@ cc_library(
     alwayslink = 1,
 )

-cc_test(
-    name = "relu_test",
-    srcs = ["relu_test.cc",],
-    deps = [
-        "@gtest//:gtest_main",
-        ":ops",
-    ],
-    copts = ['-std=c++11'],
-    linkopts = if_android([
-        "-pie",
-        "-llog",
-        "-latomic",
-    ]),
-    linkstatic = 1,
-)

diff --git a/mace/ops/addn.cc b/mace/ops/addn.cc
new file mode 100644
index 00000000..94f506f7
--- /dev/null
+++ b/mace/ops/addn.cc
@@ -0,0 +1,25 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#include "mace/ops/addn.h"
+#include "mace/proto/mace.pb.h"
+#if __ARM_NEON
+#include "mace/kernels/neon/addn_neon.h"
+#endif // __ARM_NEON
+
+namespace mace {
+
+REGISTER_CPU_OPERATOR(AddN, AddNOp<DeviceType::CPU, float>);
+
+#if __ARM_NEON
+template <>
+bool AddNOp<DeviceType::NEON, float>::Run() {
+  Tensor* output_tensor = Output(0);
+  kernels::NeonAddNFuntion_float(Inputs(), output_tensor);
+  return true;
+}
+REGISTER_NEON_OPERATOR(AddN, AddNOp<DeviceType::NEON, float>);
+#endif // __ARM_NEON
+
+} // namespace mace
diff --git a/mace/ops/addn.h b/mace/ops/addn.h
new file mode 100644
index 00000000..66e1dba0
--- /dev/null
+++ b/mace/ops/addn.h
@@ -0,0 +1,27 @@
+//
+// Copyright (c) 2017 XiaoMi All rights reserved.
+//
+
+#ifndef MACE_OPS_ADDN_H_
+#define MACE_OPS_ADDN_H_
+
+#include "mace/core/operator.h"
+#include "mace/kernels/addn.h"
+
+namespace mace {
+
+template <DeviceType D, class T>
+class AddNOp : public Operator<D, T> {
+ public:
+  AddNOp(const OperatorDef &operator_def, Workspace *ws)
+      : Operator<D, T>(operator_def, ws) {}
+  bool Run() override {
+    Tensor* output_tensor = this->Output(0);
+    kernels::AddNFuntion<T>(this->Inputs(), output_tensor);
+    return true;
+  }
+};
+
+} // namespace mace
+
+#endif // MACE_OPS_ADDN_H_
diff --git a/mace/ops/relu.cc b/mace/ops/relu.cc
index 59d4e3b7..c2dab6e5 100644
--- a/mace/ops/relu.cc
+++ b/mace/ops/relu.cc
@@ -4,23 +4,14 @@

 #include "mace/ops/relu.h"
 #include "mace/proto/mace.pb.h"
-#include "mace/kernels/relu.h"
 #if __ARM_NEON
 #include "mace/kernels/neon/relu_neon.h"
 #endif // __ARM_NEON

 namespace mace {

-template <>
-bool ReluOp<DeviceType::CPU, float>::Run() {
-  const Tensor* input_tensor = Input(0);
-  Tensor* output_tensor = Output(0);
-  kernels::ReluFuntion<float>(input_tensor, output_tensor);
-  return true;
-}

 REGISTER_CPU_OPERATOR(Relu, ReluOp<DeviceType::CPU, float>);
-
 #if __ARM_NEON
 template <>
 bool ReluOp<DeviceType::NEON, float>::Run() {
diff --git a/mace/ops/relu.h b/mace/ops/relu.h
index 8a0ea34d..2965dc55 100644
--- a/mace/ops/relu.h
+++ b/mace/ops/relu.h
@@ -6,6 +6,7 @@
 #define MACE_OPS_RELU_H_

 #include "mace/core/operator.h"
+#include "mace/kernels/relu.h"

 namespace mace {

@@ -14,7 +15,12 @@ class ReluOp : public Operator<D, T> {
  public:
   ReluOp(const OperatorDef &operator_def, Workspace *ws)
       : Operator<D, T>(operator_def, ws) {}
-  bool Run() override;
+  bool Run() override {
+    const Tensor* input_tensor = this->Input(0);
+    Tensor* output_tensor = this->Output(0);
+    kernels::ReluFuntion<T>(input_tensor, output_tensor);
+    return true;
+  }
 };

 } // namespace mace
diff --git a/mace/ops/relu_test.cc b/mace/ops/relu_test.cc
deleted file mode 100644
index 209fe83a..00000000
--- a/mace/ops/relu_test.cc
+++ /dev/null
@@ -1,52 +0,0 @@
-//
-// Copyright (c) 2017 XiaoMi All rights reserved.
-//
-
-#include "gtest/gtest.h"
-
-#include "mace/core/operator.h"
-#include "mace/core/net.h"
-
-using namespace mace;
-
-TEST(ReluTest, Relu) {
-  OperatorRegistry* registry = gDeviceTypeRegistry()->at(DeviceType::CPU);
-  vector<string> registry_keys = registry->Keys();
-  for (auto& key: registry_keys) {
-    VLOG(0) << "registry_op: " << key;
-  }
-
-  // Construct graph
-  OperatorDef op_def;
-  op_def.add_input("Input0");
-  op_def.add_output("Output0");
-  op_def.set_name("ReluTest");
-  op_def.set_type("Relu");
-  auto arg = op_def.add_arg();
-  arg->set_name("arg0");
-  arg->set_f(1.5);
-
-  NetDef net_def;
-  net_def.set_name("NetTest");
-  net_def.add_op()->CopyFrom(op_def);
-
-  VLOG(0) << net_def.DebugString();
-
-  // Create workspace and input tensor
-  Workspace ws;
-  Tensor* input = ws.CreateTensor("Input0", cpu_allocator(), DataType::DT_FLOAT);
-  input->Resize({2,3});
-  float* input_data = input->mutable_data<float>();
-  for (int i = 0; i < 6; ++i) {
-    input_data[i] = i-3;
-  }
-
-  // Create Net & run
-  auto net = CreateNet(net_def, &ws, DeviceType::CPU);
-  net->Run();
-
-  // Create Op & run
-  auto op = CreateOperator(op_def, &ws, DeviceType::CPU);
-  ASSERT_FLOAT_EQ(1.5f, op->GetSingleArgument<float>("arg0", 1.0f));
-
-}
\ No newline at end of file
--
GitLab