Commit f61a938b authored by liuqi

Finish concat op.

Parent f3699f51
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/types.h"
namespace mace {
// Returns true if values of this data type can be copied byte-for-byte with
// memcpy, i.e. the underlying C++ type is trivially copyable.
bool DataTypeCanUseMemcpy(DataType dt) {
switch (dt) {
case DT_FLOAT:
case DT_DOUBLE:
case DT_INT32:
case DT_INT64:
case DT_UINT16:
case DT_UINT8:
case DT_INT16:
case DT_INT8:
case DT_BOOL:
return true;
default:
return false;
}
}
} // namespace mace
\ No newline at end of file
@@ -10,6 +10,8 @@
namespace mace {
bool DataTypeCanUseMemcpy(DataType dt);
template <class T>
struct IsValidDataType;
@@ -51,7 +53,6 @@ MATCH_TYPE_AND_ENUM(int64_t, DT_INT64);
MATCH_TYPE_AND_ENUM(bool, DT_BOOL);
static const int32_t kint32_tmax = ((int32_t)0x7FFFFFFF);
} // namespace mace
#endif // MACE_CORE_TYPES_H_
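For context, this hunk only shows the trailing MATCH_TYPE_AND_ENUM lines. Judging from the uses of DataTypeToEnum<T>::v() elsewhere in this commit, each macro invocation presumably expands to a TensorFlow-style type/enum specialization; a hedged sketch of that pattern (the real definition lives earlier in types.h and may differ):

// Sketch only, not the actual macro expansion from the commit:
template <>
struct DataTypeToEnum<float> {           // one specialization per MATCH_TYPE_AND_ENUM call
  static DataType v() { return DT_FLOAT; }  // v() is what the concat kernel calls
};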
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_KERNELS_CONCAT_H_
#define MACE_KERNELS_CONCAT_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/common.h"
#include "mace/core/types.h"
namespace mace {
namespace kernels {
template<DeviceType D, typename T>
struct ConcatFunctor {
  // Writes inner_dim interleaved chunks: for each inner step, outer_dims[i]
  // elements are copied from input_list[i], in input order. The pointers in
  // input_list are advanced in place as the copy proceeds.
  void operator()(std::vector<const T *> &input_list,
                  const index_t inner_dim,
                  const index_t *outer_dims,
                  T *output) {
    const size_t input_count = input_list.size();
    // The element type is fixed at compile time, so the memcpy eligibility
    // check is loop-invariant; hoist it out of the copy loops.
    const bool can_memcpy = DataTypeCanUseMemcpy(DataTypeToEnum<T>::v());
    for (index_t inner_idx = 0; inner_idx < inner_dim; ++inner_idx) {
      for (size_t i = 0; i < input_count; ++i) {
        if (can_memcpy) {
          memcpy(output, input_list[i], outer_dims[i] * sizeof(T));
          output += outer_dims[i];
          input_list[i] += outer_dims[i];  // advance this input's read cursor
        } else {
          for (index_t k = 0; k < outer_dims[i]; ++k) {
            *output++ = *input_list[i]++;
          }
        }
      }
    }
}
};
}  // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_CONCAT_H_
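To make the copy pattern concrete, here is a standalone sketch (not part of the commit) of what the functor does for two 2x3 float tensors concatenated along axis 1: inner_dim = 2 (the dims before the axis) and outer_dims = {3, 3} (elements contributed per inner step by each input).

#include <cstring>

int main() {
  float a[6] = {1, 2, 3, 4, 5, 6};     // input A, shape {2, 3}
  float b[6] = {7, 8, 9, 10, 11, 12};  // input B, shape {2, 3}
  float out[12];                       // output, shape {2, 6}
  const float *inputs[2] = {a, b};
  const int inner_dim = 2;             // product of dims before the axis
  const int outer_dims[2] = {3, 3};    // elements per input per inner step
  float *dst = out;
  for (int inner = 0; inner < inner_dim; ++inner) {
    for (int i = 0; i < 2; ++i) {
      std::memcpy(dst, inputs[i], outer_dims[i] * sizeof(float));
      dst += outer_dims[i];
      inputs[i] += outer_dims[i];      // advance this input's read cursor
    }
  }
  // out is now {1,2,3,7,8,9, 4,5,6,10,11,12}: each output row is A's row
  // followed by B's row.
  return 0;
}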
@@ -30,7 +30,7 @@ void CalcPaddingAndOutputSize(const index_t *input_shape, // NCHW
padding_size[0] = 0;
padding_size[1] = 0;
-index_t output_height, output_width;
index_t output_height = 0, output_width = 0;
index_t kernel_height = filter_shape[2];
index_t kernel_width = filter_shape[3];
index_t output_channels = filter_shape[0];
@@ -85,7 +85,7 @@ void CalPaddingSize(const index_t *input_shape, // NCHW
"If dilations > 1, strides should be 1");
MACE_CHECK_NOTNULL(padding_size);
-index_t output_height, output_width;
index_t output_height = 0, output_width = 0;
index_t k_extent_height = (filter_shape[2] - 1) * dilations[0] + 1;
index_t k_extent_width = (filter_shape[3] - 1) * dilations[1] + 1;
...
@@ -25,48 +25,67 @@ cc_library(
name = "ops",
srcs = glob(
["*.cc"],
exclude = ["*_test.cc", "*_benchmark.cc"],
exclude = [
"*_test.cc",
"*_benchmark.cc",
],
),
hdrs = glob(
["*.h"],
exclude = ["ops_test_util.h"],
),
copts = ["-std=c++11"],
deps = [
"//mace/core",
"//mace/kernels",
"//mace/proto:cc_proto",
],
copts = ["-std=c++11"],
alwayslink = 1,
)
cc_test(
name = "ops_test",
testonly = 1,
srcs = glob(
["*_test.cc"],
),
copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]),
linkstatic = 1,
deps = [
":ops",
":test",
"@gtest//:gtest_main",
],
)
cc_test(
name = "concat_test",
testonly = 1,
srcs = glob(
["concat_test.cc"],
),
copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]),
linkstatic = 1,
-    testonly = 1,
deps = [
":ops",
":test",
"@gtest//:gtest_main",
],
)
cc_test(
name = "ops_benchmark",
testonly = 1,
srcs = glob(["*_benchmark.cc"]),
copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]),
linkstatic = 1,
deps = [
":ops",
":test",
"//mace/core:core",
"//mace/core",
"//mace/core:test_benchmark_main",
],
-    copts = ['-std=c++11'],
-    linkopts = if_android(["-ldl"]),
-    linkstatic = 1,
-    testonly = 1,
)
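Assuming this BUILD file sits in the mace/ops package (consistent with the //mace/core and //mace/kernels labels above), the new targets can be exercised in the usual Bazel way:

bazel test //mace/ops:concat_test
bazel run //mace/ops:ops_benchmark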
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/concat.h"
namespace mace {
REGISTER_CPU_OPERATOR(Concat, ConcatOp<DeviceType::CPU, float>);
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_CONCAT_H_
#define MACE_OPS_CONCAT_H_
#include "mace/proto/mace.pb.h"
#include "mace/core/operator.h"
#include "mace/kernels/concat.h"
namespace mace {
template<DeviceType D, typename T>
class ConcatOp : public Operator<D, T> {
public:
ConcatOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws) {}
bool Run() override {
    int32_t values_count = this->InputSize() - 1;  // the last input is the axis scalar
const Tensor *input0 = this->Input(0);
const Tensor *axis_tensor = this->Input(values_count);
    MACE_CHECK(axis_tensor->dim_size() == 0,
               "axis should be a scalar integer, but got a tensor of rank ",
               axis_tensor->dim_size());
const int32_t concat_axis = *(axis_tensor->data<int32_t>());
const int32_t input_dims = input0->dim_size();
    const int32_t axis =
        concat_axis < 0 ? concat_axis + input_dims : concat_axis;
    MACE_CHECK(0 <= axis && axis < input_dims,
               "Expected concatenating axis in the range [", -input_dims, ", ",
               input_dims, "), but got ", concat_axis);
std::vector<index_t> output_shape(input0->shape());
    index_t inner_size = 1;  // product of dims before the concat axis
for (int i = 0; i < axis; ++i) {
inner_size *= output_shape[i];
}
std::vector<index_t> outer_sizes(values_count, 0);
std::vector<const T *> input_list(values_count, nullptr);
input_list[0] = input0->data<T>();
outer_sizes[0] = input0->size() / inner_size;
const Tensor *input = nullptr;
    for (int i = 1; i < values_count; ++i) {
      input = this->Input(i);
      MACE_CHECK(input->dim_size() == input0->dim_size(),
                 "Ranks of all input tensors must be the same.");
      // Check every dimension except the concat axis. Note: iterate over the
      // input's rank, not the axis tensor's (which is a scalar of rank 0).
      for (int j = 0; j < input_dims; ++j) {
        if (j == axis) { continue; }
        MACE_CHECK(input->dim(j) == input0->dim(j),
                   "Dimensions of inputs should be equal except along the concat axis.");
      }
input_list[i] = input->data<T>();
outer_sizes[i] = input->size() / inner_size;
output_shape[axis] += input->dim(axis);
}
Tensor *output = this->Output(OUTPUT);
output->Resize(output_shape);
functor_(input_list, inner_size, outer_sizes.data(), output->mutable_data<T>());
return true;
}
 private:
  kernels::ConcatFunctor<D, T> functor_;

  OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace mace
#endif // MACE_OPS_CONCAT_H_
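A quick worked example of the shape arithmetic in Run(), using the two {4, 4} inputs from the tests below:

// axis = 1 (or -1, which normalizes to -1 + 2 = 1 for rank-2 inputs):
//   inner_size      = 4                          (product of dims before the axis)
//   outer_sizes     = {16 / 4, 16 / 4} = {4, 4}  (elements per inner step)
//   output_shape[1] = 4 + 4 = 8                  -> output shape is {4, 8}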
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template<DeviceType D, typename T>
static void ConcatHelper(
int iters, int concat_dim, int dim1) {
mace::testing::StopTiming();
OpsTestNet net;
OpDefBuilder("Concat", "ConcatBM")
.Input("Input0")
.Input("Input1")
.Input("Axis")
.Output("Output")
.Finalize(net.operator_def());
// Add input data
const int kDim0 = 100;
net.AddRandomInput<T>("Input0", {kDim0, dim1});
net.AddRandomInput<T>("Input1", {kDim0, dim1});
net.AddInputFromArray<int32_t>("Axis", {}, {concat_dim});
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
  const int64_t tot = static_cast<int64_t>(iters) * kDim0 * dim1 * 2;
  mace::testing::ItemsProcessed(tot);
  mace::testing::BytesProcessed(tot * sizeof(T));
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
}
static void BM_ConcatDim0Float(int iters, int dim1) {
ConcatHelper<DeviceType::CPU, float>(iters, 0, dim1);
}
static void BM_ConcatDim1Float(int iters, int dim1) {
ConcatHelper<DeviceType::CPU, float>(iters, 1, dim1);
}
BENCHMARK(BM_ConcatDim0Float)->Arg(1000)->Arg(100000);
BENCHMARK(BM_ConcatDim1Float)->Arg(1000)->Arg(100000);
} // namespace mace
\ No newline at end of file
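For scale: with kDim0 = 100 and dim1 = 100000, each iteration moves

// items per iteration = 100 * 100000 * 2 = 2.0e7 elements (both inputs)
// bytes per iteration = 2.0e7 * sizeof(float) = 8.0e7 bytes (~76 MiB)

which is what the ItemsProcessed/BytesProcessed totals above accumulate across all iters iterations.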
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/concat.h"
#include "mace/ops/ops_test_util.h"
#include "gmock/gmock.h"
using namespace mace;
class ConcatOpTest : public OpsTestBase {};
TEST_F(ConcatOpTest, Simple_Horizon) {
// Construct graph
auto &net = test_net();
OpDefBuilder("Concat", "ConcatTest")
.Input("Input0")
.Input("Input1")
.Input("Axis")
.Output("Output")
.Finalize(net.operator_def());
std::vector<index_t> input_shape = {4, 4};
std::vector<float> input0;
GenerateRandomRealTypeData(input_shape, input0);
std::vector<float> input1;
GenerateRandomRealTypeData(input_shape, input1);
// Add inputs
net.AddInputFromArray<float>("Input0", input_shape, input0);
net.AddInputFromArray<float>("Input1", input_shape, input1);
net.AddInputFromArray<int>("Axis", {}, {0});
// Run
net.RunOp();
// Check
auto output = net.GetOutput("Output");
std::vector<index_t> expected_shape = {8, 4};
EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape));
const float *output_ptr = output->data<float>();
for (auto f : input0) {
ASSERT_EQ(f, *output_ptr++);
}
for (auto f : input1) {
ASSERT_EQ(f, *output_ptr++);
}
}
TEST_F(ConcatOpTest, Simple_Vertical) {
// Construct graph
auto &net = test_net();
OpDefBuilder("Concat", "ConcatTest")
.Input("Input0")
.Input("Input1")
.Input("Axis")
.Output("Output")
.Finalize(net.operator_def());
std::vector<index_t> input_shape = {4, 4};
std::vector<float> input0;
GenerateRandomRealTypeData(input_shape, input0);
std::vector<float> input1;
GenerateRandomRealTypeData(input_shape, input1);
// Add inputs
net.AddInputFromArray<float>("Input0", input_shape, input0);
net.AddInputFromArray<float>("Input1", input_shape, input1);
net.AddInputFromArray<int>("Axis", {}, {1});
// Run
net.RunOp();
// Check
auto output = net.GetOutput("Output");
std::vector<index_t> expected_shape = {4, 8};
EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape));
const float *output_ptr = output->data<float>();
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) {
ASSERT_EQ(input0[i * 4 + j], *output_ptr++);
}
for (int j = 0; j < 4; ++j) {
ASSERT_EQ(input1[i * 4 + j], *output_ptr++);
}
}
}
TEST_F(ConcatOpTest, Random) {
srand(time(nullptr));
int dim = 5;
int num_inputs = 2 + rand() % 10;
int axis = rand() % dim;
// Construct graph
auto &net = test_net();
auto builder = OpDefBuilder("Concat", "ConcatTest");
for (int i = 0; i < num_inputs; ++i) {
builder = builder.Input(("Input" + std::to_string(i)).c_str());
}
builder.Input("Axis")
.Output("Output")
.Finalize(net.operator_def());
std::vector<index_t> shape_data;
GenerateRandomIntTypeData<index_t>({dim}, shape_data, 1, dim);
std::vector<std::vector<index_t>> input_shapes(num_inputs, shape_data);
std::vector<std::vector<float>> inputs(num_inputs, std::vector<float>());
std::vector<float *> input_ptrs(num_inputs, nullptr);
index_t concat_axis_size = 0;
for (int i = 0; i < num_inputs; ++i) {
input_shapes[i][axis] = 1 + rand() % dim;
concat_axis_size += input_shapes[i][axis];
GenerateRandomRealTypeData(input_shapes[i], inputs[i]);
input_ptrs[i] = inputs[i].data();
net.AddInputFromArray<float>(("Input" + std::to_string(i)).c_str(), input_shapes[i], inputs[i]);
}
net.AddInputFromArray<int>("Axis", {}, {axis});
// Run
net.RunOp();
// Check
auto output = net.GetOutput("Output");
std::vector<index_t> expected_shape = input_shapes[0];
expected_shape[axis] = concat_axis_size;
EXPECT_THAT(output->shape(), ::testing::ContainerEq(expected_shape));
  const float *output_ptr = output->data<float>();
  const float *output_end = output_ptr + output->size();
  // Each pass consumes one inner step: a contiguous chunk from every input,
  // in input order (mirroring ConcatFunctor's write pattern).
  while (output_ptr != output_end) {
for (int i = 0; i < num_inputs; ++i) {
index_t num_elements = std::accumulate(input_shapes[i].begin() + axis,
input_shapes[i].end(), 1,
std::multiplies<index_t>());
for (int j = 0; j < num_elements; ++j) {
EXPECT_EQ(*input_ptrs[i]++, *output_ptr++);
}
}
}
}
@@ -37,7 +37,7 @@ class ConvPool2dOpBase : public Operator<D, T> {
* http://deeplearning.net/software/theano/tutorial/conv_arithmetic.html
*/
-index_t output_height, output_width;
index_t output_height = 0, output_width = 0;
switch (padding_) {
case VALID:
...
@@ -43,7 +43,7 @@ class OpsTestNet {
public:
OpsTestNet() {}
-template <typename T>
template<typename T>
void AddInputFromArray(const char *name,
const std::vector<index_t> &shape,
const std::vector<T> &data) {
@@ -51,11 +51,11 @@ class OpsTestNet {
ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
T *input_data = input->mutable_data<T>();
-MACE_CHECK(input->size() == data.size());
MACE_CHECK(static_cast<size_t>(input->size()) == data.size());
memcpy(input_data, data.data(), data.size() * sizeof(T));
}
-template <typename T>
template<typename T>
void AddRepeatedInput(const char *name,
const std::vector<index_t> &shape,
const T data) {
@@ -67,7 +67,7 @@ class OpsTestNet {
std::fill(input_data, input_data + input->size(), data);
}
-template <typename T>
template<typename T>
void AddRandomInput(const char *name,
const std::vector<index_t> &shape,
bool positive = false) {
@@ -86,18 +86,6 @@ class OpsTestNet {
});
}
-  template <typename T>
-  void AddFixedInput(const char *name,
-                     const std::vector<index_t> &shape,
-                     T value) {
-    Tensor *input =
-        ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
-    input->Resize(shape);
-    float *input_data = input->mutable_data<T>();
-    std::fill(input_data, input_data + input->size(), value);
-  }
void AddIntArg(const char *name, const int value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
@@ -186,7 +174,38 @@ class OpsTestBase : public ::testing::Test {
OpsTestNet test_net_;
};
-template <typename T>
template<typename T>
void GenerateRandomRealTypeData(const std::vector<index_t> &shape, std::vector<T> &res) {
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<T> nd(0, 1);
index_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<index_t>());
res.resize(size);
std::generate(res.begin(), res.end(),
[&gen, &nd] {
return nd(gen);
});
}
template<typename T>
void GenerateRandomIntTypeData(const std::vector<index_t> &shape, std::vector<T> &res,
const T a = 0, const T b = std::numeric_limits<T>::max()) {
std::random_device rd;
std::mt19937 gen(rd());
  // Distribute over T itself: the default template argument is int, which
  // could overflow for 64-bit bounds such as index_t's maximum.
  std::uniform_int_distribution<T> nd(a, b);
index_t size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<index_t>());
res.resize(size);
std::generate(res.begin(), res.end(),
[&gen, &nd] {
return nd(gen);
});
}
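The Random test in concat_test.cc above combines these two helpers; a minimal usage sketch (shape values are illustrative only):

std::vector<index_t> shape;
GenerateRandomIntTypeData<index_t>({3}, shape, 1, 5);  // a random rank-3 shape, dims in [1, 5]
std::vector<float> values;
GenerateRandomRealTypeData(shape, values);             // one N(0, 1) sample per element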
template<typename T>
unique_ptr<Tensor> CreateTensor(const std::vector<index_t> &shape,
const std::vector<T> &data) {
unique_ptr<Tensor> res(new Tensor(cpu_allocator(), DataTypeToEnum<T>::v()));
@@ -219,23 +238,23 @@ inline std::string ShapeToString(const Tensor &x) {
return std::string(stream.str());
}
-template <typename T>
template<typename T>
struct is_floating_point_type {
static const bool value =
std::is_same<T, float>::value || std::is_same<T, double>::value;
};
-template <typename T>
template<typename T>
inline void ExpectEqual(const T &a, const T &b) {
EXPECT_EQ(a, b);
}
-template <>
template<>
inline void ExpectEqual<float>(const float &a, const float &b) {
EXPECT_FLOAT_EQ(a, b);
}
-template <>
template<>
inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b);
}
@@ -246,11 +265,11 @@ inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) {
<< "y.shape [ " << ShapeToString(y) << "]";
}
-template <typename T, bool is_fp = is_floating_point_type<T>::value>
template<typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;
// Partial specialization for float and double.
-template <typename T>
template<typename T>
struct Expector<T, true> {
static void Equal(const T &a, const T &b) { ExpectEqual(a, b); }
@@ -276,7 +295,7 @@ struct Expector<T, true> {
}
};
-template <typename T>
template<typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
static_assert(is_floating_point_type<T>::value,
"T is not a floating point type");
...