提交 f6006d5e 编写于 作者: L liuqi

Change the axis tensor of concat to an attribute.

上级 4743a1e6
......@@ -13,17 +13,24 @@
namespace mace {
namespace kernels {
struct ConcatFunctorBase {
ConcatFunctorBase(const int32_t axis): axis_(axis){}
int32_t axis_;
};
template<DeviceType D, typename T>
struct ConcatFunctor {
struct ConcatFunctor : ConcatFunctorBase {
ConcatFunctor(const int32_t axis): ConcatFunctorBase(axis){}
void operator()(const std::vector<const Tensor *> &input_list,
const int32_t axis,
Tensor *output) {
const Tensor *input0 = input_list.front();
const int inputs_count = input_list.size() - 1;
const int inputs_count = input_list.size();
std::vector<index_t> output_shape(input0->shape());
index_t inner_size = 1;
for (int i = 0; i < axis; ++i) {
for (int i = 0; i < axis_; ++i) {
inner_size *= output_shape[i];
}
std::vector<index_t> outer_sizes(inputs_count, 0);
......@@ -33,14 +40,14 @@ struct ConcatFunctor {
MACE_CHECK(input->dim_size() == input0->dim_size(),
"Ranks of all input tensors must be same.");
for (int j = 0; j < input->dim_size(); ++j) {
if (j == axis) {
if (j == axis_) {
continue;
}
MACE_CHECK(input->dim(j) == input0->dim(j),
"Dimensions of inputs should equal except axis.");
}
outer_sizes[i] = input->size() / inner_size;
output_shape[axis] += input->dim(axis);
output_shape[axis_] += input->dim(axis_);
}
output->Resize(output_shape);
......@@ -67,9 +74,10 @@ struct ConcatFunctor {
};
template<typename T>
struct ConcatFunctor<DeviceType::OPENCL, T> {
struct ConcatFunctor<DeviceType::OPENCL, T> : ConcatFunctorBase{
ConcatFunctor(const int32_t axis): ConcatFunctorBase(axis){}
void operator()(const std::vector<const Tensor *> &input_list,
const int32_t axis,
Tensor *output);
};
......
......@@ -32,8 +32,6 @@ __kernel void concat_channel(__read_only image2d_t input0,
const int hb_idx = get_global_id(2);
const int input0_chan_blk = (input0_chan + 3) / 4;
const sampler_t SAMPLER = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;
DATA_TYPE4 data = 0;
#ifdef DIVISIBLE_FOUR
if (chan_blk_idx + 1 <= input0_chan_blk) {
......
......@@ -60,10 +60,9 @@ static void Concat2(const Tensor *input0,
template<typename T>
void ConcatFunctor<DeviceType::OPENCL, T>::operator()(const std::vector<const Tensor *> &input_list,
const int32_t axis,
Tensor *output) {
const int inputs_count = input_list.size() - 1;
MACE_CHECK(inputs_count == 2 && axis == 3)
const int inputs_count = input_list.size();
MACE_CHECK(inputs_count == 2 && axis_ == 3)
<< "Concat opencl kernel only support two elements with axis == 3";
const Tensor *input0 = input_list[0];
......@@ -74,13 +73,13 @@ void ConcatFunctor<DeviceType::OPENCL, T>::operator()(const std::vector<const Te
MACE_CHECK(input->dim_size() == input0->dim_size(),
"Ranks of all input tensors must be same.");
for (int j = 0; j < input->dim_size(); ++j) {
if (j == axis) {
if (j == axis_) {
continue;
}
MACE_CHECK(input->dim(j) == input0->dim(j),
"Dimensions of inputs should equal except axis.");
}
output_shape[axis] += input->dim(axis);
output_shape[axis_] += input->dim(axis_);
}
std::vector<size_t> image_shape;
CalImage2DShape(output_shape, BufferType::IN_OUT, image_shape);
......
......@@ -61,6 +61,22 @@ cc_test(
],
)
cc_test(
name = "concat_test",
testonly = 1,
srcs = glob(
["concat_test.cc"],
),
copts = ["-std=c++11"],
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
":ops",
":test",
"@gtest//:gtest_main",
],
)
cc_test(
name = "ops_benchmark",
testonly = 1,
......
......@@ -14,17 +14,13 @@ template <DeviceType D, typename T>
class ConcatOp : public Operator<D, T> {
public:
ConcatOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws) {}
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetSingleArgument<int>("axis", 3)){}
bool Run() override {
const int32_t inputs_count = this->InputSize() - 1;
MACE_CHECK(this->InputSize() >= 2) << "There must be at least two inputs to concat";
const std::vector<const Tensor *> input_list = this->Inputs();
const Tensor *axis_tensor = this->Input(inputs_count);
MACE_CHECK(axis_tensor->dim_size() == 0,
"axis should be a scalar integer, but got shape: ",
axis_tensor->dim_size());
Tensor::MappingGuard axis_mapper(axis_tensor);
const int32_t concat_axis = *(axis_tensor->data<int32_t>());
const int32_t concat_axis = OperatorBase::GetSingleArgument<int>("axis", 3);
const int32_t input_dims = input_list[0]->dim_size();
const int32_t axis =
concat_axis < 0 ? concat_axis + input_dims : concat_axis;
......@@ -34,7 +30,7 @@ class ConcatOp : public Operator<D, T> {
Tensor *output = this->Output(OUTPUT);
functor_(input_list, axis, output);
functor_(input_list, output);
return true;
}
......
......@@ -16,7 +16,7 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) {
OpDefBuilder("Concat", "ConcatTest")
.Input("Input0")
.Input("Input1")
.Input("Axis")
.AddIntArg("axis", 0)
.Output("Output")
.Finalize(net.NewOperatorDef());
......@@ -28,7 +28,6 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) {
// Add inputs
net.AddInputFromArray<DeviceType::CPU, float>("Input0", input_shape, input0);
net.AddInputFromArray<DeviceType::CPU, float>("Input1", input_shape, input1);
net.AddInputFromArray<DeviceType::CPU, int>("Axis", {}, {0});
// Run
net.RunOp();
......@@ -54,7 +53,7 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) {
OpDefBuilder("Concat", "ConcatTest")
.Input("Input0")
.Input("Input1")
.Input("Axis")
.AddIntArg("axis", 1)
.Output("Output")
.Finalize(net.NewOperatorDef());
......@@ -66,7 +65,6 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) {
// Add inputs
net.AddInputFromArray<DeviceType::CPU, float>("Input0", input_shape, input0);
net.AddInputFromArray<DeviceType::CPU, float>("Input1", input_shape, input1);
net.AddInputFromArray<DeviceType::CPU, int>("Axis", {}, {1});
// Run
net.RunOp();
......@@ -99,7 +97,7 @@ TEST_F(ConcatOpTest, CPURandom) {
for (int i = 0; i < num_inputs; ++i) {
builder = builder.Input(("Input" + ToString(i)).c_str());
}
builder.Input("Axis").Output("Output").Finalize(net.NewOperatorDef());
builder.AddIntArg("axis", axis).Output("Output").Finalize(net.NewOperatorDef());
std::vector<index_t> shape_data;
GenerateRandomIntTypeData<index_t>({dim}, shape_data, 1, dim);
......@@ -115,7 +113,6 @@ TEST_F(ConcatOpTest, CPURandom) {
net.AddInputFromArray<DeviceType::CPU, float>(("Input" + ToString(i)).c_str(),
input_shapes[i], inputs[i]);
}
net.AddInputFromArray<DeviceType::CPU, int>("Axis", {}, {axis});
// Run
net.RunOp();
......@@ -156,14 +153,13 @@ void OpenclRandomTest(const std::vector<std::vector<index_t>> &shapes,
shapes[i]);
BufferToImage<DeviceType::OPENCL, T>(net, input_name, image_name, kernels::BufferType::IN_OUT);
}
net.AddInputFromArray<DeviceType::OPENCL, int>("Axis", {}, {axis});
auto builder = OpDefBuilder("Concat", "ConcatTest");
for (int i = 0; i < num_inputs; ++i) {
const std::string image_name = ("InputImage" + ToString(i)).c_str();
builder = builder.Input(image_name);
}
builder.Input("Axis")
builder.AddIntArg("axis", axis)
.Output("OutputImage")
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册