diff --git a/mace/kernels/concat.h b/mace/kernels/concat.h index 0b1b4834e4aa4fe4555faf47e59283d4c2dc205b..e70b4e73c977b9d8da0735784219739c5dbd468a 100644 --- a/mace/kernels/concat.h +++ b/mace/kernels/concat.h @@ -13,17 +13,24 @@ namespace mace { namespace kernels { +struct ConcatFunctorBase { + ConcatFunctorBase(const int32_t axis): axis_(axis){} + + int32_t axis_; +}; + template -struct ConcatFunctor { +struct ConcatFunctor : ConcatFunctorBase { + ConcatFunctor(const int32_t axis): ConcatFunctorBase(axis){} + void operator()(const std::vector &input_list, - const int32_t axis, Tensor *output) { const Tensor *input0 = input_list.front(); - const int inputs_count = input_list.size() - 1; + const int inputs_count = input_list.size(); std::vector output_shape(input0->shape()); index_t inner_size = 1; - for (int i = 0; i < axis; ++i) { + for (int i = 0; i < axis_; ++i) { inner_size *= output_shape[i]; } std::vector outer_sizes(inputs_count, 0); @@ -33,14 +40,14 @@ struct ConcatFunctor { MACE_CHECK(input->dim_size() == input0->dim_size(), "Ranks of all input tensors must be same."); for (int j = 0; j < input->dim_size(); ++j) { - if (j == axis) { + if (j == axis_) { continue; } MACE_CHECK(input->dim(j) == input0->dim(j), "Dimensions of inputs should equal except axis."); } outer_sizes[i] = input->size() / inner_size; - output_shape[axis] += input->dim(axis); + output_shape[axis_] += input->dim(axis_); } output->Resize(output_shape); @@ -67,9 +74,10 @@ struct ConcatFunctor { }; template -struct ConcatFunctor { +struct ConcatFunctor : ConcatFunctorBase{ + ConcatFunctor(const int32_t axis): ConcatFunctorBase(axis){} + void operator()(const std::vector &input_list, - const int32_t axis, Tensor *output); }; diff --git a/mace/kernels/opencl/cl/concat.cl b/mace/kernels/opencl/cl/concat.cl index fdffde574bca75b0fa2115c463e75e583f80a5bd..5ae4fb04e141f6510529083b2fee9e8826494b6d 100644 --- a/mace/kernels/opencl/cl/concat.cl +++ b/mace/kernels/opencl/cl/concat.cl @@ -32,8 +32,6 @@ __kernel void concat_channel(__read_only image2d_t input0, const int hb_idx = get_global_id(2); const int input0_chan_blk = (input0_chan + 3) / 4; - const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST; - DATA_TYPE4 data = 0; #ifdef DIVISIBLE_FOUR if (chan_blk_idx + 1 <= input0_chan_blk) { diff --git a/mace/kernels/opencl/concat.cc b/mace/kernels/opencl/concat.cc index 8f61e3536c071b8fec7ef2eb26ef03b07a8638a1..80a23a44ccdbb50cd2be9976f2dfecf8d0576a91 100644 --- a/mace/kernels/opencl/concat.cc +++ b/mace/kernels/opencl/concat.cc @@ -60,10 +60,9 @@ static void Concat2(const Tensor *input0, template void ConcatFunctor::operator()(const std::vector &input_list, - const int32_t axis, Tensor *output) { - const int inputs_count = input_list.size() - 1; - MACE_CHECK(inputs_count == 2 && axis == 3) + const int inputs_count = input_list.size(); + MACE_CHECK(inputs_count == 2 && axis_ == 3) << "Concat opencl kernel only support two elements with axis == 3"; const Tensor *input0 = input_list[0]; @@ -74,13 +73,13 @@ void ConcatFunctor::operator()(const std::vectordim_size() == input0->dim_size(), "Ranks of all input tensors must be same."); for (int j = 0; j < input->dim_size(); ++j) { - if (j == axis) { + if (j == axis_) { continue; } MACE_CHECK(input->dim(j) == input0->dim(j), "Dimensions of inputs should equal except axis."); } - output_shape[axis] += input->dim(axis); + output_shape[axis_] += input->dim(axis_); } std::vector image_shape; CalImage2DShape(output_shape, BufferType::IN_OUT, image_shape); diff --git a/mace/ops/BUILD b/mace/ops/BUILD index 89f6b8b78bb1f31a912817140756d562275a61ce..7d5e68322e52753e54858a9c6f2dd985bb96e466 100644 --- a/mace/ops/BUILD +++ b/mace/ops/BUILD @@ -61,6 +61,22 @@ cc_test( ], ) +cc_test( + name = "concat_test", + testonly = 1, + srcs = glob( + ["concat_test.cc"], + ), + copts = ["-std=c++11"], + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + ":ops", + ":test", + "@gtest//:gtest_main", + ], +) + cc_test( name = "ops_benchmark", testonly = 1, diff --git a/mace/ops/concat.h b/mace/ops/concat.h index b64060b21a2cdfa9e4873283819ebfcc8eb5c59d..77e430304c93341c176dd732c30559f0721e4f8a 100644 --- a/mace/ops/concat.h +++ b/mace/ops/concat.h @@ -14,17 +14,13 @@ template class ConcatOp : public Operator { public: ConcatOp(const OperatorDef &op_def, Workspace *ws) - : Operator(op_def, ws) {} + : Operator(op_def, ws), + functor_(OperatorBase::GetSingleArgument("axis", 3)){} bool Run() override { - const int32_t inputs_count = this->InputSize() - 1; + MACE_CHECK(this->InputSize() >= 2) << "There must be at least two inputs to concat"; const std::vector input_list = this->Inputs(); - const Tensor *axis_tensor = this->Input(inputs_count); - MACE_CHECK(axis_tensor->dim_size() == 0, - "axis should be a scalar integer, but got shape: ", - axis_tensor->dim_size()); - Tensor::MappingGuard axis_mapper(axis_tensor); - const int32_t concat_axis = *(axis_tensor->data()); + const int32_t concat_axis = OperatorBase::GetSingleArgument("axis", 3); const int32_t input_dims = input_list[0]->dim_size(); const int32_t axis = concat_axis < 0 ? concat_axis + input_dims : concat_axis; @@ -34,7 +30,7 @@ class ConcatOp : public Operator { Tensor *output = this->Output(OUTPUT); - functor_(input_list, axis, output); + functor_(input_list, output); return true; } diff --git a/mace/ops/concat_test.cc b/mace/ops/concat_test.cc index bf6c65a940a09f1d77d97b9e4ae9f829d90d2dd5..8a42899e49183dd35fcfaf804679b4807219688b 100644 --- a/mace/ops/concat_test.cc +++ b/mace/ops/concat_test.cc @@ -16,7 +16,7 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { OpDefBuilder("Concat", "ConcatTest") .Input("Input0") .Input("Input1") - .Input("Axis") + .AddIntArg("axis", 0) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -28,7 +28,6 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { // Add inputs net.AddInputFromArray("Input0", input_shape, input0); net.AddInputFromArray("Input1", input_shape, input1); - net.AddInputFromArray("Axis", {}, {0}); // Run net.RunOp(); @@ -54,7 +53,7 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { OpDefBuilder("Concat", "ConcatTest") .Input("Input0") .Input("Input1") - .Input("Axis") + .AddIntArg("axis", 1) .Output("Output") .Finalize(net.NewOperatorDef()); @@ -66,7 +65,6 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { // Add inputs net.AddInputFromArray("Input0", input_shape, input0); net.AddInputFromArray("Input1", input_shape, input1); - net.AddInputFromArray("Axis", {}, {1}); // Run net.RunOp(); @@ -99,7 +97,7 @@ TEST_F(ConcatOpTest, CPURandom) { for (int i = 0; i < num_inputs; ++i) { builder = builder.Input(("Input" + ToString(i)).c_str()); } - builder.Input("Axis").Output("Output").Finalize(net.NewOperatorDef()); + builder.AddIntArg("axis", axis).Output("Output").Finalize(net.NewOperatorDef()); std::vector shape_data; GenerateRandomIntTypeData({dim}, shape_data, 1, dim); @@ -115,7 +113,6 @@ TEST_F(ConcatOpTest, CPURandom) { net.AddInputFromArray(("Input" + ToString(i)).c_str(), input_shapes[i], inputs[i]); } - net.AddInputFromArray("Axis", {}, {axis}); // Run net.RunOp(); @@ -156,14 +153,13 @@ void OpenclRandomTest(const std::vector> &shapes, shapes[i]); BufferToImage(net, input_name, image_name, kernels::BufferType::IN_OUT); } - net.AddInputFromArray("Axis", {}, {axis}); auto builder = OpDefBuilder("Concat", "ConcatTest"); for (int i = 0; i < num_inputs; ++i) { const std::string image_name = ("InputImage" + ToString(i)).c_str(); builder = builder.Input(image_name); } - builder.Input("Axis") + builder.AddIntArg("axis", axis) .Output("OutputImage") .AddIntArg("T", static_cast(DataTypeToEnum::value)) .Finalize(net.NewOperatorDef());