From f6006d5e3156ea32085db55ed023784502c3fea2 Mon Sep 17 00:00:00 2001 From: liuqi Date: Mon, 4 Dec 2017 18:05:18 +0800 Subject: [PATCH] Change the axis tensor of concat to an attribute. --- mace/kernels/concat.h | 24 ++++++++++++++++-------- mace/kernels/opencl/cl/concat.cl | 2 -- mace/kernels/opencl/concat.cc | 9 ++++----- mace/ops/BUILD | 16 ++++++++++++++++ mace/ops/concat.h | 14 +++++--------- mace/ops/concat_test.cc | 12 ++++-------- 6 files changed, 45 insertions(+), 32 deletions(-) diff --git a/mace/kernels/concat.h b/mace/kernels/concat.h index 0b1b4834..e70b4e73 100644 --- a/mace/kernels/concat.h +++ b/mace/kernels/concat.h @@ -13,17 +13,24 @@ namespace mace { namespace kernels { +struct ConcatFunctorBase { + ConcatFunctorBase(const int32_t axis): axis_(axis){} + + int32_t axis_; +}; + template -struct ConcatFunctor { +struct ConcatFunctor : ConcatFunctorBase { + ConcatFunctor(const int32_t axis): ConcatFunctorBase(axis){} + void operator()(const std::vector &input_list, - const int32_t axis, Tensor *output) { const Tensor *input0 = input_list.front(); - const int inputs_count = input_list.size() - 1; + const int inputs_count = input_list.size(); std::vector output_shape(input0->shape()); index_t inner_size = 1; - for (int i = 0; i < axis; ++i) { + for (int i = 0; i < axis_; ++i) { inner_size *= output_shape[i]; } std::vector outer_sizes(inputs_count, 0); @@ -33,14 +40,14 @@ struct ConcatFunctor { MACE_CHECK(input->dim_size() == input0->dim_size(), "Ranks of all input tensors must be same."); for (int j = 0; j < input->dim_size(); ++j) { - if (j == axis) { + if (j == axis_) { continue; } MACE_CHECK(input->dim(j) == input0->dim(j), "Dimensions of inputs should equal except axis."); } outer_sizes[i] = input->size() / inner_size; - output_shape[axis] += input->dim(axis); + output_shape[axis_] += input->dim(axis_); } output->Resize(output_shape); @@ -67,9 +74,10 @@ struct ConcatFunctor { }; template -struct ConcatFunctor { +struct ConcatFunctor : ConcatFunctorBase{ + ConcatFunctor(const int32_t axis): ConcatFunctorBase(axis){} + void operator()(const std::vector &input_list, - const int32_t axis, Tensor *output); }; diff --git a/mace/kernels/opencl/cl/concat.cl b/mace/kernels/opencl/cl/concat.cl index fdffde57..5ae4fb04 100644 --- a/mace/kernels/opencl/cl/concat.cl +++ b/mace/kernels/opencl/cl/concat.cl @@ -32,8 +32,6 @@ __kernel void concat_channel(__read_only image2d_t input0, const int hb_idx = get_global_id(2); const int input0_chan_blk = (input0_chan + 3) / 4; - const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - DATA_TYPE4 data = 0; #ifdef DIVISIBLE_FOUR if (chan_blk_idx + 1 <= input0_chan_blk) { diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc index 8f61e353..80a23a44 100644 --- a/mace/kernels/opencl/concat.cc +++ b/mace/kernels/opencl/concat.cc @@ -60,10 +60,9 @@ static void Concat2(const Tensor *input0, template void ConcatFunctor::operator()(const std::vector &input_list, - const int32_t axis, Tensor *output) { - const int inputs_count = input_list.size() - 1; - MACE_CHECK(inputs_count == 2 && axis == 3) + const int inputs_count = input_list.size(); + MACE_CHECK(inputs_count == 2 && axis_ == 3) << "Concat opencl kernel only support two elements with axis == 3"; const Tensor *input0 = input_list[0]; @@ -74,13 +73,13 @@ void ConcatFunctor::operator()(const std::vectordim_size() == input0->dim_size(), "Ranks of all input tensors must be same."); for (int j = 0; j < input->dim_size(); ++j) { - if (j == axis) { + if (j == axis_) { continue; } MACE_CHECK(input->dim(j) == input0->dim(j), "Dimensions of inputs should equal except axis."); } - output_shape[axis] += input->dim(axis); + output_shape[axis_] += input->dim(axis_); } std::vector image_shape; CalImage2DShape(output_shape, BufferType::IN_OUT, image_shape); diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 89f6b8b7..7d5e6832 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -61,6 +61,22 @@ cc_test( ], ) +cc_test( + name = "concat_test", + testonly = 1, + srcs = glob( + ["concat_test.cc"], + ), + copts = ["-std=c++11"], + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + ":ops", + ":test", + "@gtest//:gtest_main", + ], +) + cc_test( name = "ops_benchmark", testonly = 1, diff --git a/mace/ops/concat.h b/mace/ops/concat.h index b64060b2..77e43030 100644 --- a/mace/ops/concat.h +++ b/mace/ops/concat.h @@ -14,17 +14,13 @@ template class ConcatOp : public Operator { public: ConcatOp(const OperatorDef &op_def, Workspace *ws) - : Operator(op_def, ws) {} + : Operator(op_def, ws), + functor_(OperatorBase::GetSingleArgument("axis", 3)){} bool Run() override { - const int32_t inputs_count = this->InputSize() - 1; + MACE_CHECK(this->InputSize() >= 2) << "There must be at least two inputs to concat"; const std::vector input_list = this->Inputs(); - const Tensor *axis_tensor = this->Input(inputs_count); - MACE_CHECK(axis_tensor->dim_size() == 0, - "axis should be a scalar integer, but got shape: ", - axis_tensor->dim_size()); - Tensor::MappingGuard axis_mapper(axis_tensor); - const int32_t concat_axis = *(axis_tensor->data()); + const int32_t concat_axis = OperatorBase::GetSingleArgument("axis", 3); const int32_t input_dims = input_list[0]->dim_size(); const int32_t axis = concat_axis < 0 ? concat_axis + input_dims : concat_axis; @@ -34,7 +30,7 @@ class ConcatOp : public Operator { Tensor *output = this->Output(OUTPUT); - functor_(input_list, axis, output); + functor_(input_list, output); return true; } diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index bf6c65a9..8a42899e 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -16,7 +16,7 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { OpDefBuilder("Concat", "ConcatTest") .Input("Input0") .Input("Input1") - .Input("Axis") + .AddIntArg("axis", 0) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -28,7 +28,6 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { // Add inputs net.AddInputFromArray("Input0", input_shape, input0); net.AddInputFromArray("Input1", input_shape, input1); - net.AddInputFromArray("Axis", {}, {0}); // Run net.RunOp(); @@ -54,7 +53,7 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { OpDefBuilder("Concat", "ConcatTest") .Input("Input0") .Input("Input1") - .Input("Axis") + .AddIntArg("axis", 1) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -66,7 +65,6 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { // Add inputs net.AddInputFromArray("Input0", input_shape, input0); net.AddInputFromArray("Input1", input_shape, input1); - net.AddInputFromArray("Axis", {}, {1}); // Run net.RunOp(); @@ -99,7 +97,7 @@ TEST_F(ConcatOpTest, CPURandom) { for (int i = 0; i < num_inputs; ++i) { builder = builder.Input(("Input" + ToString(i)).c_str()); } - builder.Input("Axis").Output("Output").Finalize(net.NewOperatorDef()); + builder.AddIntArg("axis", axis).Output("Output").Finalize(net.NewOperatorDef()); std::vector shape_data; GenerateRandomIntTypeData({dim}, shape_data, 1, dim); @@ -115,7 +113,6 @@ TEST_F(ConcatOpTest, CPURandom) { net.AddInputFromArray(("Input" + ToString(i)).c_str(), input_shapes[i], inputs[i]); } - net.AddInputFromArray("Axis", {}, {axis}); // Run net.RunOp(); @@ -156,14 +153,13 @@ void OpenclRandomTest(const std::vector> &shapes, shapes[i]); BufferToImage(net, input_name, image_name, kernels::BufferType::IN_OUT); } - net.AddInputFromArray("Axis", {}, {axis}); auto builder = OpDefBuilder("Concat", "ConcatTest"); for (int i = 0; i < num_inputs; ++i) { const std::string image_name = ("InputImage" + ToString(i)).c_str(); builder = builder.Input(image_name); } - builder.Input("Axis") + builder.AddIntArg("axis", axis) .Output("OutputImage") .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef()); -- GitLab