Commit 103f81de authored by liuqi

Move test and benchmark of batch_norm and conv2d3x3r1 to ops directory.

Parent 5536cf7f
@@ -8,7 +8,7 @@
namespace mace {
namespace kernels {
-static const int REGISTER_SIZE = 4;
+static const int kRegisterSize = 4;
void Conv2dNeonK3x3S1(const float* input, // NCHW
const index_t* input_shape,
@@ -44,7 +44,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
float32x4_t filter3 = vld1q_f32(filter_ptr+3);
float32x4_t filter6 = vld1q_f32(filter_ptr+6);
-const float* row[REGISTER_SIZE] = {
+const float* row[kRegisterSize] = {
input_ptr, input_ptr + input_width,
input_ptr + 2 * input_width, input_ptr + 3 * input_width
};
@@ -61,7 +61,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
float32x4_t sum0 = vdupq_n_f32(.0f);
float32x4_t sum1 = vdupq_n_f32(.0f);
float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123
-float32x4_t row0_latter = vld1q_f32(row[0] + REGISTER_SIZE); //4567
+float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); //4567
float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234
float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345
@@ -70,7 +70,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123
-float32x4_t row1_latter = vld1q_f32(row[1] + REGISTER_SIZE); //4567
+float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); //4567
float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234
float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345
@@ -79,7 +79,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
row0_ext_0 = vld1q_f32(row[2]); //0123
-row0_latter = vld1q_f32(row[2] + REGISTER_SIZE); //4567
+row0_latter = vld1q_f32(row[2] + kRegisterSize); //4567
row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234
row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345
@@ -97,7 +97,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
sum1 = vfmaq_laneq_f32(sum1, row0_ext_2, filter3, 2);
row1_ext_0 = vld1q_f32(row[3]); //0123
-row1_latter = vld1q_f32(row[3] + REGISTER_SIZE); //4567
+row1_latter = vld1q_f32(row[3] + kRegisterSize); //4567
row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234
row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345
@@ -112,10 +112,10 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
vst1q_f32(output_ptr1, output_row0);
vst1q_f32(output_ptr2, output_row1);
-output_ptr1 += REGISTER_SIZE;
-output_ptr2 += REGISTER_SIZE;
-for(int i = 0; i < REGISTER_SIZE; ++i) {
-row[i] += REGISTER_SIZE;
+output_ptr1 += kRegisterSize;
+output_ptr2 += kRegisterSize;
+for(int i = 0; i < kRegisterSize; ++i) {
+row[i] += kRegisterSize;
}
}
for (; remain_count > 0; --remain_count) {
@@ -138,13 +138,13 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
++output_ptr1;
++output_ptr2;
-for(int i = 0; i < REGISTER_SIZE; ++i) {
+for(int i = 0; i < kRegisterSize; ++i) {
row[i] += 1;
}
}
output_ptr1 += width;
output_ptr2 += width;
-for(int i = 0; i < REGISTER_SIZE; ++i) {
+for(int i = 0; i < kRegisterSize; ++i) {
row[i] += 2 + input_width;
}
}
@@ -155,7 +155,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
for(; count > 0; --count) {
float32x4_t sum0 = vdupq_n_f32(.0f);
float32x4_t row0_ext_0 = vld1q_f32(row[0]); //0123
-float32x4_t row0_latter = vld1q_f32(row[0] + REGISTER_SIZE); //4567
+float32x4_t row0_latter = vld1q_f32(row[0] + kRegisterSize); //4567
float32x4_t row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234
float32x4_t row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345
@@ -164,7 +164,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
sum0 = vfmaq_laneq_f32(sum0, row0_ext_2, filter0, 2);
float32x4_t row1_ext_0 = vld1q_f32(row[1]); //0123
-float32x4_t row1_latter = vld1q_f32(row[1] + REGISTER_SIZE); //4567
+float32x4_t row1_latter = vld1q_f32(row[1] + kRegisterSize); //4567
float32x4_t row1_ext_1 = vextq_f32(row1_ext_0, row1_latter, 1); //1234
float32x4_t row1_ext_2 = vextq_f32(row1_ext_0, row1_latter, 2); //2345
@@ -173,7 +173,7 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
sum0 = vfmaq_laneq_f32(sum0, row1_ext_2, filter3, 2);
row0_ext_0 = vld1q_f32(row[2]); //0123
-row0_latter = vld1q_f32(row[2] + REGISTER_SIZE); //4567
+row0_latter = vld1q_f32(row[2] + kRegisterSize); //4567
row0_ext_1 = vextq_f32(row0_ext_0, row0_latter, 1); //1234
row0_ext_2 = vextq_f32(row0_ext_0, row0_latter, 2); //2345
@@ -184,9 +184,9 @@ void Conv2dNeonK3x3S1(const float* input, // NCHW
float32x4_t output_row0 = vld1q_f32(output_ptr1);
output_row0 = vaddq_f32(output_row0, sum0);
vst1q_f32(output_ptr1, output_row0);
-output_ptr1 += REGISTER_SIZE;
+output_ptr1 += kRegisterSize;
for(int i = 0; i < 3; ++i) {
-row[i] += REGISTER_SIZE;
+row[i] += kRegisterSize;
}
}
for (; remain_count > 0; --remain_count) {
......
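The kernel above uses the standard AArch64 sliding-window idiom: eight consecutive input floats are loaded as two quad registers, and vextq_f32 derives the shifted 1234/2345 views so the three taps of a filter row can be applied with vfmaq_laneq_f32 without extra loads (the filter rows themselves come from overlapping quad loads at offsets 0, 3, and 6). A minimal standalone sketch of one filter row, not part of this commit and assuming an ARMv8-A target (vfmaq_laneq_f32 is AArch64-only):

#include <arm_neon.h>

// Apply the three taps held in lanes 0..2 of `filter` to four adjacent
// output positions starting at `row`, accumulating into `sum`.
static inline float32x4_t FilterRow3x1(float32x4_t sum, const float *row,
                                       float32x4_t filter) {
  float32x4_t ext0 = vld1q_f32(row);        // elements 0 1 2 3
  float32x4_t latter = vld1q_f32(row + 4);  // elements 4 5 6 7
  float32x4_t ext1 = vextq_f32(ext0, latter, 1);  // 1 2 3 4
  float32x4_t ext2 = vextq_f32(ext0, latter, 2);  // 2 3 4 5
  sum = vfmaq_laneq_f32(sum, ext0, filter, 0);  // += tap0 * x[i]
  sum = vfmaq_laneq_f32(sum, ext1, filter, 1);  // += tap1 * x[i+1]
  sum = vfmaq_laneq_f32(sum, ext2, filter, 2);  // += tap2 * x[i+2]
  return sum;  // sum[i] += dot(filter taps, row[i..i+2])
}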
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include <random>
#include "gtest/gtest.h"
#include "mace/kernels/batch_norm.h"
namespace mace {
TEST(BatchNormNeonTest, Simple) {
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<float> nd(0, 1);
srand(time(NULL));
// generate random input
index_t batch = 1 + rand() % 128;
index_t channels = 3;
index_t height = 10 + rand() % 100;
index_t width = 10 + rand() % 100;
index_t input_size = batch * channels * height * width;
std::vector<float> input(input_size, 0.0);
std::vector<float> scale(channels, 0.0);
std::vector<float> offset(channels, 0.0);
std::vector<float> mean(channels, 0.0);
std::vector<float> var(channels, 0.0);
for (int i = 0; i < input_size; ++i) {
input[i] = nd(gen);
}
for (int i = 0; i < channels; ++i) {
scale[i] = nd(gen);
offset[i] = nd(gen);
mean[i] = nd(gen);
var[i] = std::abs(nd(gen));
}
// declare output
std::unique_ptr<float[]> output(new float[input_size]);
std::unique_ptr<float[]> output_neon(new float[input_size]);
kernels::BatchNormFunctor<DeviceType::CPU, float>(1e-5)(
input.data(),
scale.data(),
offset.data(),
mean.data(),
var.data(),
batch,
channels,
height * width,
output.get()
);
kernels::BatchNormFunctor<DeviceType::NEON, float>(1e-5)(
input.data(),
scale.data(),
offset.data(),
mean.data(),
var.data(),
batch,
channels,
height * width,
output_neon.get()
);
for (index_t i = 0; i < input_size; ++i) {
EXPECT_FLOAT_EQ(output[i], output_neon[i]);
}
}
} // namespace mace
\ No newline at end of file
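The test treats the CPU functor as ground truth for the NEON one. For orientation, a reference sketch of the inference-time computation the two functors are expected to agree on, assuming the standard formula y = scale * (x - mean) / sqrt(var + epsilon) + offset over NCHW data (the actual implementation lives in mace/kernels/batch_norm.h):

#include <cmath>
#include <cstdint>

// Reference batch norm over an NCHW buffer; sample_size = height * width.
void BatchNormReference(const float *input, const float *scale,
                        const float *offset, const float *mean,
                        const float *var, int64_t batch, int64_t channels,
                        int64_t sample_size, float epsilon, float *output) {
  for (int64_t n = 0; n < batch; ++n) {
    for (int64_t c = 0; c < channels; ++c) {
      // Fold the per-channel statistics into one multiply-add:
      // y = alpha * x + beta, with alpha = scale/sqrt(var+eps).
      const float inv_std = 1.0f / std::sqrt(var[c] + epsilon);
      const float alpha = scale[c] * inv_std;
      const float beta = offset[c] - alpha * mean[c];
      const float *in = input + (n * channels + c) * sample_size;
      float *out = output + (n * channels + c) * sample_size;
      for (int64_t i = 0; i < sample_size; ++i) {
        out[i] = alpha * in[i] + beta;
      }
    }
  }
}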
@@ -2,75 +2,67 @@
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/core/operator.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/kernels/batch_norm.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
template <DeviceType D, typename T>
static void BatchNorm(int iters, int batch, int channels, int height, int width) {
-std::random_device rd;
-std::mt19937 gen(rd());
-std::normal_distribution<float> nd(0, 1);
mace::testing::StopTiming();
-index_t input_size = batch * channels * height * width;
-std::vector<T> input(input_size, 0.0);
-std::vector<T> scale(channels, 0.0);
-std::vector<T> offset(channels, 0.0);
-std::vector<T> mean(channels, 0.0);
-std::vector<T> var(channels, 0.0);
+OpsTestNet net;
+OpDefBuilder("BatchNorm", "BatchNormBM")
+.Input("Input")
+.Input("Scale")
+.Input("Offset")
+.Input("Mean")
+.Input("Var")
+.Output("Output")
+.Finalize(net.operator_def());
-for (int i = 0; i < input_size; ++i) {
-input[i] = nd(gen);
-}
-for (int i = 0; i < channels; ++i) {
-scale[i] = nd(gen);
-offset[i] = nd(gen);
-mean[i] = nd(gen);
-var[i] = std::abs(nd(gen));
-}
+// Add input data
+net.AddRandomInput<T>("Input", {batch, channels, height, width});
+net.AddRandomInput<T>("Scale", {channels});
+net.AddRandomInput<T>("Offset", {channels});
+net.AddRandomInput<T>("Mean", {channels});
+net.AddRandomInput<T>("Var", {channels}, true);
-// declare output
-std::unique_ptr<T[]> output(new T[input_size]);
-auto functor = kernels::BatchNormFunctor<D, T>(1e-5);
+// Warm-up
+for (int i = 0; i < 5; ++i) {
+net.RunOp(D);
+}
mace::testing::StartTiming();
while(iters--) {
-functor(input.data(),
-scale.data(),
-offset.data(),
-mean.data(),
-var.data(),
-batch,
-channels,
-height * width,
-output.get());
+net.RunOp(D);
}
}
-#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \
+#define BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, DEVICE) \
static void BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
-const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
+const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::ItemsProcessed(tot); \
-mace::testing::BytesProcessed(tot * (sizeof(TYPE)));\
+mace::testing::BytesProcessed(tot * (sizeof(TYPE))); \
BatchNorm<DEVICE, TYPE>(iters, N, C, H, W); \
} \
BENCHMARK(BM_BATCH_NORM_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
-#define BM_BATCH_NORM(N, C, H, W, TYPE) \
+#define BM_BATCH_NORM(N, C, H, W, TYPE) \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, CPU); \
BM_BATCH_NORM_MACRO(N, C, H, W, TYPE, NEON);
BM_BATCH_NORM(1, 1, 128, 128, float);
BM_BATCH_NORM(1, 1, 512, 512, float);
BM_BATCH_NORM(1, 1, 1024, 1024, float);
BM_BATCH_NORM(16, 1, 256, 256, float);
BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(64, 1, 256, 256, float);
BM_BATCH_NORM(1, 3, 128, 128, float);
BM_BATCH_NORM(1, 3, 512, 512, float);
BM_BATCH_NORM(1, 3, 1024, 1024, float);
BM_BATCH_NORM(16, 3, 256, 256, float);
BM_BATCH_NORM(1, 64, 256, 256, float);
BM_BATCH_NORM(1, 64, 512, 512, float);
BM_BATCH_NORM(1, 128, 256, 256, float);
BM_BATCH_NORM(1, 128, 512, 512, float);
BM_BATCH_NORM(32, 1, 256, 256, float);
BM_BATCH_NORM(32, 3, 256, 256, float);
BM_BATCH_NORM(64, 3, 256, 256, float);
} // namespace mace
\ No newline at end of file
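Each BM_BATCH_NORM(N, C, H, W, TYPE) line above therefore stamps out and registers one CPU and one NEON benchmark. For example, BM_BATCH_NORM(1, 64, 256, 256, float) expands to roughly:

// Approximate expansion of the CPU half of BM_BATCH_NORM(1, 64, 256, 256, float).
static void BM_BATCH_NORM_1_64_256_256_float_CPU(int iters) {
  const int64_t tot = static_cast<int64_t>(iters) * 1 * 64 * 256 * 256;
  mace::testing::ItemsProcessed(tot);                  // items/sec counter
  mace::testing::BytesProcessed(tot * sizeof(float));  // bandwidth counter
  BatchNorm<CPU, float>(iters, 1, 64, 256, 256);
}
BENCHMARK(BM_BATCH_NORM_1_64_256_256_float_CPU);
// ...plus an identical _NEON variant from the second BM_BATCH_NORM_MACRO call.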
@@ -9,7 +9,7 @@ namespace mace {
class BatchNormOpTest : public OpsTestBase {};
-TEST_F(BatchNormOpTest, Simple) {
+TEST_F(BatchNormOpTest, SimpleCPU) {
// Construct graph
auto& net = test_net();
OpDefBuilder("BatchNorm", "BatchNormTest")
@@ -40,4 +40,42 @@ TEST_F(BatchNormOpTest, Simple) {
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 0.01);
}
+TEST_F(BatchNormOpTest, SimpleNeon) {
+srand(time(NULL));
+// generate random input
+index_t batch = 1 + rand() % 10;
+index_t channels = 3 + rand() % 50;
+index_t height = 10 + rand() % 50;
+index_t width = 10 + rand() % 50;
+// Construct graph
+auto& net = test_net();
+OpDefBuilder("BatchNorm", "BatchNormTest")
+.Input("Input")
+.Input("Scale")
+.Input("Offset")
+.Input("Mean")
+.Input("Var")
+.Output("Output")
+.Finalize(net.operator_def());
+// Add input data
+net.AddRandomInput<float>("Input", {batch, channels, height, width});
+net.AddRandomInput<float>("Scale", {channels});
+net.AddRandomInput<float>("Offset", {channels});
+net.AddRandomInput<float>("Mean", {channels});
+net.AddRandomInput<float>("Var", {channels}, true);
+// run cpu
+net.RunOp();
+// Check
+Tensor expected = *net.GetOutput("Output");
+// Run NEON
+net.RunOp(DeviceType::NEON);
+ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-5);
+}
}
@@ -62,5 +62,6 @@ static void Conv2d(int iters, int batch, int channels, int height, int width,
BM_CONV_2D(1, 64, 32, 32, 1, 1, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, VALID, 128, float);
BM_CONV_2D(1, 64, 32, 32, 3, 3, 1, SAME, 128, float);
} // namespace mace
@@ -192,3 +192,47 @@ TEST_F(Conv2dOpTest, Conv1x1) {
}
// TODO we need more tests
+TEST_F(Conv2dOpTest, Conv3x3R1) {
+auto func = [&](Padding type) {
+srand(time(NULL));
+// generate random input
+index_t batch = 1 + rand() % 5;
+index_t input_channels = 3 + rand() % 50;
+index_t height = 10 + rand() % 100;
+index_t width = 10 + rand() % 100;
+index_t output_channels = 3 + rand() % 50;
+// Construct graph
+auto& net = test_net();
+OpDefBuilder("Conv2d", "Conv2dTest")
+.Input("Input")
+.Input("Filter")
+.Input("Bias")
+.Output("Output")
+.Finalize(net.operator_def());
+// Add args
+net.AddIntsArg("strides", {1, 1});
+net.AddIntArg("padding", type);
+net.AddIntsArg("dilations", {1, 1});
+// Add input data
+net.AddRandomInput<float>("Input", {batch, input_channels, height, width});
+net.AddRandomInput<float>("Filter", {output_channels, input_channels, 3, 3});
+net.AddRandomInput<float>("Bias", {output_channels});
+// run cpu
+net.RunOp();
+// Check
+Tensor expected = *net.GetOutput("Output");
+// Run NEON
+net.RunOp(DeviceType::NEON);
+ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-5);
+};
+func(VALID);
+func(SAME);
+}
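The two padding modes exercise different output geometries. Assuming TensorFlow-style padding semantics, the expected output spatial size for the stride-1 3x3 case is a couple of lines of arithmetic; the helper below is hypothetical and not part of the MACE sources:

// out_valid = (in - kernel) / stride + 1  -> in - 2 for a stride-1 3x3
// out_same  = ceil(in / stride)           -> in     for stride 1
int ConvOutSize(int in, int kernel, int stride, bool same_padding) {
  return same_padding ? (in + stride - 1) / stride
                      : (in - kernel) / stride + 1;
}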
@@ -15,152 +15,167 @@
namespace mace {
class OpDefBuilder {
public:
OpDefBuilder(const char* type, const char* name) {
op_def_.set_type(type);
op_def_.set_name(name);
}
OpDefBuilder& Input(const char* input_name) {
op_def_.add_input(input_name);
return *this;
}
OpDefBuilder& Output(const char* output_name) {
op_def_.add_output(output_name);
return *this;
}
void Finalize(OperatorDef* op_def) const {
MACE_CHECK(op_def != nullptr, "input should not be null.");
*op_def = op_def_;
}
OperatorDef op_def_;
public:
OpDefBuilder(const char *type, const char *name) {
op_def_.set_type(type);
op_def_.set_name(name);
}
OpDefBuilder &Input(const char *input_name) {
op_def_.add_input(input_name);
return *this;
}
OpDefBuilder &Output(const char *output_name) {
op_def_.add_output(output_name);
return *this;
}
void Finalize(OperatorDef *op_def) const {
MACE_CHECK(op_def != nullptr, "input should not be null.");
*op_def = op_def_;
}
OperatorDef op_def_;
};
class OpsTestNet {
public:
OpsTestNet() {}
template <typename T>
void AddInputFromArray(const char* name,
const std::vector<index_t>& shape,
const std::vector<T>& data) {
Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
T* input_data = input->mutable_data<T>();
MACE_CHECK(input->size() == data.size());
memcpy(input_data, data.data(), data.size() * sizeof(T));
}
public:
OpsTestNet() {}
template<typename T>
void AddInputFromArray(const char *name,
const std::vector<index_t> &shape,
const std::vector<T> &data) {
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
T *input_data = input->mutable_data<T>();
MACE_CHECK(input->size() == data.size());
memcpy(input_data, data.data(), data.size() * sizeof(T));
}
template <typename T>
void AddRandomInput(const char* name, const std::vector<index_t>& shape) {
Tensor* input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
float* input_data = input->mutable_data<T>();
template<typename T>
void AddRepeatedInput(const char *name,
const std::vector<index_t> &shape,
const T data) {
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
T *input_data = input->mutable_data<T>();
std::fill(input_data, input_data + input->size(), data);
}
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<T> nd(0, 1);
template<typename T>
void AddRandomInput(const char *name, const std::vector<index_t> &shape, bool positive = false) {
Tensor *input = ws_.CreateTensor(name, cpu_allocator(), DataTypeToEnum<T>::v());
input->Resize(shape);
float *input_data = input->mutable_data<T>();
std::generate(input_data, input_data + input->size(),
[&gen, &nd]{ return nd(gen); });
}
std::random_device rd;
std::mt19937 gen(rd());
std::normal_distribution<T> nd(0, 1);
void AddIntArg(const char* name, const int value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
arg->set_i(value);
}
std::generate(input_data, input_data + input->size(),
[&gen, &nd, positive] { return positive ? std::abs(nd(gen)) : nd(gen); });
}
void AddFloatArg(const char* name, const float value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
arg->set_f(value);
}
void AddIntArg(const char *name, const int value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
arg->set_i(value);
}
void AddStringArg(const char* name, const char* value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
arg->set_s(value);
}
void AddFloatArg(const char *name, const float value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
arg->set_f(value);
}
void AddIntsArg(const char* name, const std::vector<int>& values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_ints(value);
}
void AddStringArg(const char *name, const char *value) {
auto arg = op_def_.add_arg();
arg->set_name(name);
arg->set_s(value);
}
void AddIntsArg(const char *name, const std::vector<int> &values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_ints(value);
}
}
void AddFloatsArg(const char* name, const std::vector<float>& values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_floats(value);
}
void AddFloatsArg(const char *name, const std::vector<float> &values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_floats(value);
}
}
void AddStringsArg(const char* name, const std::vector<const char*>& values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_strings(value);
}
void AddStringsArg(const char *name, const std::vector<const char *> &values) {
auto arg = op_def_.add_arg();
arg->set_name(name);
for (auto value : values) {
arg->add_strings(value);
}
}
OperatorDef* operator_def() { return &op_def_; }
OperatorDef *operator_def() { return &op_def_; }
Workspace* ws() { return &ws_; }
Workspace *ws() { return &ws_; }
bool RunOp(DeviceType device) {
if (!net_) {
NetDef net_def;
net_def.add_op()->CopyFrom(op_def_);
VLOG(3) << net_def.DebugString();
net_ = CreateNet(net_def, &ws_, device);
}
return net_->Run();
bool RunOp(DeviceType device) {
if (!net_) {
NetDef net_def;
net_def.add_op()->CopyFrom(op_def_);
VLOG(3) << net_def.DebugString();
net_ = CreateNet(net_def, &ws_, device);
}
return net_->Run();
}
bool RunOp() {
return RunOp(DeviceType::CPU);
}
bool RunOp() {
return RunOp(DeviceType::CPU);
}
Tensor* GetOutput(const char* output_name) {
return ws_.GetTensor(output_name);
}
Tensor *GetOutput(const char *output_name) {
return ws_.GetTensor(output_name);
}
public:
Workspace ws_;
OperatorDef op_def_;
std::unique_ptr<NetBase> net_;
public:
Workspace ws_;
OperatorDef op_def_;
std::unique_ptr<NetBase> net_;
};
class OpsTestBase : public ::testing::Test {
public:
OpsTestNet& test_net() { return test_net_; };
protected:
virtual void TearDown() {
auto ws = test_net_.ws();
auto tensor_names = ws->Tensors();
for (auto& name : tensor_names) {
ws->RemoveTensor(name);
}
public:
OpsTestNet &test_net() { return test_net_; };
protected:
virtual void TearDown() {
auto ws = test_net_.ws();
auto tensor_names = ws->Tensors();
for (auto &name : tensor_names) {
ws->RemoveTensor(name);
}
}
private:
OpsTestNet test_net_;
private:
OpsTestNet test_net_;
};
template <typename T>
Tensor CreateTensor(const std::vector<index_t>& shape, const std::vector<T>& data) {
template<typename T>
Tensor CreateTensor(const std::vector<index_t> &shape, const std::vector<T> &data) {
Tensor res(cpu_allocator(), DataTypeToEnum<T>::v());
res.Resize(shape);
float* input_data = res.mutable_data<float>();
float *input_data = res.mutable_data<float>();
memcpy(input_data, data.data(), data.size() * sizeof(T));
return res;
}
inline bool IsSameSize(const Tensor& x, const Tensor& y) {
inline bool IsSameSize(const Tensor &x, const Tensor &y) {
if (x.dim_size() != y.dim_size()) return false;
for (int d = 0; d < x.dim_size(); ++d) {
if (x.dim(d) != y.dim(d)) return false;
@@ -168,58 +183,59 @@ inline bool IsSameSize(const Tensor& x, const Tensor& y) {
return true;
}
inline std::string ShapeToString(const Tensor& x) {
inline std::string ShapeToString(const Tensor &x) {
std::stringstream stream;
for (int i = 0; i < x.dim_size(); i++) {
if (i > 0) stream<<",";
if (i > 0) stream << ",";
int64_t dim = x.dim(i);
if (dim < 0) {
stream<<"?";
stream << "?";
} else {
stream<<dim;
stream << dim;
}
}
stream<<"]";
stream << "]";
return std::string(stream.str());
}
template <typename T>
template<typename T>
struct is_floating_point_type {
static const bool value = std::is_same<T, float>::value ||
std::is_same<T, double>::value;
};
template <typename T>
inline void ExpectEqual(const T& a, const T& b) {
template<typename T>
inline void ExpectEqual(const T &a, const T &b) {
EXPECT_EQ(a, b);
}
template <>
inline void ExpectEqual<float>(const float& a, const float& b) {
template<>
inline void ExpectEqual<float>(const float &a, const float &b) {
EXPECT_FLOAT_EQ(a, b);
}
template <>
inline void ExpectEqual<double>(const double& a, const double& b) {
template<>
inline void ExpectEqual<double>(const double &a, const double &b) {
EXPECT_DOUBLE_EQ(a, b);
}
inline void AssertSameTypeDims(const Tensor& x, const Tensor& y) {
inline void AssertSameTypeDims(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), y.dtype());
ASSERT_TRUE(IsSameSize(x, y))
<< "x.shape [" << ShapeToString(x) << "] vs "
<< "y.shape [ " << ShapeToString(y) << "]";
<< "x.shape [" << ShapeToString(x) << "] vs "
<< "y.shape [ " << ShapeToString(y) << "]";
}
template <typename T, bool is_fp = is_floating_point_type<T>::value>
template<typename T, bool is_fp = is_floating_point_type<T>::value>
struct Expector;
// Partial specialization for float and double.
template <typename T>
template<typename T>
struct Expector<T, true> {
static void Equal(const T& a, const T& b) { ExpectEqual(a, b); }
static void Equal(const T &a, const T &b) { ExpectEqual(a, b); }
static void Equal(const Tensor& x, const Tensor& y) {
static void Equal(const Tensor &x, const Tensor &y) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
auto a = x.data<T>();
@@ -229,22 +245,22 @@ struct Expector<T, true> {
}
}
static void Near(const Tensor& x, const Tensor& y, const double abs_err) {
static void Near(const Tensor &x, const Tensor &y, const double abs_err) {
ASSERT_EQ(x.dtype(), DataTypeToEnum<T>::v());
AssertSameTypeDims(x, y);
auto a = x.data<T>();
auto b = y.data<T>();
for (int i = 0; i < x.size(); ++i) {
EXPECT_NEAR(a[i], b[i], abs_err)
<< "a = " << a << " b = " << b << " index = " << i;
<< "a = " << a << " b = " << b << " index = " << i;
}
}
};
template <typename T>
void ExpectTensorNear(const Tensor& x, const Tensor& y, const double abs_err) {
template<typename T>
void ExpectTensorNear(const Tensor &x, const Tensor &y, const double abs_err) {
static_assert(is_floating_point_type<T>::value, "T is not a floating point type");
Expector<T>::Near(x, y ,abs_err);
Expector<T>::Near(x, y, abs_err);
}
} // namespace mace
......
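Putting the header together, a condensed sketch of how these utilities combine in the new tests above (mirroring the batch_norm test in this commit, shown only to illustrate the flow):

OpsTestNet net;
OpDefBuilder("BatchNorm", "BatchNormTest")
    .Input("Input").Input("Scale").Input("Offset").Input("Mean").Input("Var")
    .Output("Output")
    .Finalize(net.operator_def());
net.AddRandomInput<float>("Input", {1, 3, 64, 64});
net.AddRandomInput<float>("Scale", {3});
net.AddRandomInput<float>("Offset", {3});
net.AddRandomInput<float>("Mean", {3});
net.AddRandomInput<float>("Var", {3}, true);  // positive=true keeps variances non-negative
net.RunOp();                                  // CPU reference run
Tensor expected = *net.GetOutput("Output");
net.RunOp(DeviceType::NEON);                  // device under test
ExpectTensorNear<float>(expected, *net.GetOutput("Output"), 1e-5);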