提交 60acbe05 编写于 作者: alinag's avatar alinag

add reduce sum test of 3 dimention case

test=develop
上级 62257e33
......@@ -52,12 +52,12 @@ inline int ConvOutputSize(int input_size,
return output_size;
}
void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
......
......@@ -196,7 +196,7 @@ class ReduceSumComputeTester : public arena::TestCase {
std::vector<int> dim_{0};
bool keep_dim_ = false;
bool reduce_all_ = false;
DDim x_dims_{{3, 2, 3, 4}};
DDim x_dims_;
public:
ReduceSumComputeTester(const Place& place,
......@@ -255,39 +255,76 @@ class ReduceSumComputeTester : public arena::TestCase {
out->Resize(DDim(out_dims));
auto* out_data = out->mutable_data<float>();
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (x_dims_.size() == 4) {
int in_n = x_dims_[0];
int in_c = x_dims_[1];
int in_h = x_dims_[2];
int in_w = x_dims_[3];
if (reduce_all_) {
reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_.size() == 1) {
switch (dim_[0]) {
case 0:
reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 1:
reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 2:
reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 3:
reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
if (reduce_all_) {
reduce_sum_all(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_.size() == 1) {
switch (dim_[0]) {
case 0:
reduce_sum_n(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 1:
reduce_sum_c(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 2:
reduce_sum_h(x_data, out_data, in_n, in_c, in_h, in_w);
break;
case 3:
reduce_sum_w(x_data, out_data, in_n, in_c, in_h, in_w);
break;
default:
LOG(FATAL) << "error!!!";
}
} else if (dim_.size() == 2) {
if (dim_[0] == 0 && dim_[1] == 1) {
reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 1 && dim_[1] == 2) {
reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!";
}
}
} else if (dim_.size() == 2) {
if (dim_[0] == 0 && dim_[1] == 1) {
reduce_sum_nc(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 1 && dim_[1] == 2) {
reduce_sum_ch(x_data, out_data, in_n, in_c, in_h, in_w);
} else if (dim_[0] == 2 && dim_[1] == 3) {
reduce_sum_hw(x_data, out_data, in_n, in_c, in_h, in_w);
} else {
LOG(FATAL) << "invalid dims_!!";
} else {
if (dim_.size() == 1 && !reduce_all_) {
switch (dim_[0]) {
case 0:
for (int i = 0; i < x_dims_[1]; i++) {
for (int j = 0; j < x_dims_[2]; j++) {
out_data[i * x_dims_[2] + j] =
i * x_dims_[0] * x_dims_[2] + j * x_dims_[0] +
x_dims_[0] * (x_dims_[0] - 1) * x_dims_[1] * x_dims_[2] / 2;
}
}
break;
case 1:
for (int i = 0; i < x_dims_[0]; i++) {
for (int j = 0; j < x_dims_[2]; j++) {
out_data[i * x_dims_[2] + j] =
i * x_dims_[1] * x_dims_[1] * x_dims_[2] + j * x_dims_[1] +
x_dims_[1] * (x_dims_[1] - 1) * x_dims_[2] / 2;
}
}
break;
case 2:
for (int i = 0; i < x_dims_[0]; i++) {
for (int j = 0; j < x_dims_[1]; j++) {
out_data[i * x_dims_[1] + j] =
i * x_dims_[1] * x_dims_[2] * x_dims_[2] +
j * x_dims_[2] * x_dims_[2] +
x_dims_[2] * (x_dims_[2] - 1) / 2;
}
}
break;
default:
LOG(FATAL) << "error!!!";
}
}
}
}
......@@ -333,6 +370,21 @@ void test_reduce_sum(Place place) {
}
}
}
std::vector<std::vector<int>> reduce_dimm{{0}, {1}, {2}};
for (auto dim : reduce_dimm) {
for (auto c : {1, 3}) {
for (auto h : {1, 3}) {
for (auto w : {1, 4}) {
auto x_dims = DDim(std::vector<int64_t>({c, h, w}));
std::unique_ptr<arena::TestCase> tester(new ReduceSumComputeTester(
place, "def", dim, false, false, x_dims));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
}
}
TEST(ReduceSum, precision) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册