提交 12c2e126 编写于 作者: W Wiktor Adamski

Refactored tests.

上级 6333a263
......@@ -27,7 +27,7 @@ class PadTest : public OpsTestBase {};
namespace {
template <DeviceType D>
void Simple() {
void SimpleConstant() {
// Construct graph
OpsTestNet net;
......@@ -72,11 +72,99 @@ void Simple() {
});
ExpectTensorNear<float>(*expected, *output, 1e-5);
}
template <DeviceType D, typename T>
void Result(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data,
const std::vector<int> &paddings,
const PadType pad_type) {
// Construct graph
OpsTestNet net;
std::string input("Input");
std::string t_input(input);
std::string output("Output");
std::string t_output(output);
// Add input data
net.AddInputFromArray<D, float>(input, input_shape, input_data);
if (D == DeviceType::CPU) {
t_input = "TInput";
t_output = "TOutput";
net.TransformDataFormat<DeviceType::CPU, T>(input, NHWC, t_input, NCHW);
}
OpDefBuilder("Pad", "PadTest")
.Input(t_input)
.Output(t_output)
.AddIntsArg("paddings", paddings)
.AddIntArg("pad_type", static_cast<int>(pad_type))
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, T>(t_output, NCHW, output, NHWC);
}
auto actual = net.GetTensor(output.c_str());
auto expected = net.CreateTensor<float>(expected_shape, expected_data);
ExpectTensorNear<float>(*expected, *actual, 1e-5);
}
void SimpleMirror(const std::vector<float> &expected_data,
const PadType pad_type) {
std::vector<index_t> input_shape{1, 3, 4, 1};
int size = std::accumulate(input_shape.begin(), input_shape.end(),
1, std::multiplies<index_t>());
std::vector<float> input_data;
std::vector<index_t> expected_shape{1, 6, 7, 1};
const std::vector<int> paddings{0, 0, 1, 2, 3, 0, 0, 0};
input_data.reserve(size);
for (int i = 1; i <= size; i++) {
input_data.push_back(i);
}
Result<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_data, paddings, pad_type);
Result<DeviceType::GPU, float>(input_shape, input_data, expected_shape,
expected_data, paddings, pad_type);
Result<DeviceType::GPU, half>(input_shape, input_data, expected_shape,
expected_data, paddings, pad_type);
}
} // namespace
TEST_F(PadTest, SimpleCPU) { Simple<DeviceType::CPU>(); }
TEST_F(PadTest, SimpleConstantCPU) { SimpleConstant<DeviceType::CPU>(); }
TEST_F(PadTest, SimpleConstantGPU) { SimpleConstant<DeviceType::GPU>(); }
TEST_F(PadTest, SimpleGPU) { Simple<DeviceType::GPU>(); }
TEST_F(PadTest, SimpleReflect) {
SimpleMirror({
8, 7, 6, 5, 6, 7, 8,
4, 3, 2, 1, 2, 3, 4,
8, 7, 6, 5, 6, 7, 8,
12, 11, 10, 9, 10, 11, 12,
8, 7, 6, 5, 6, 7, 8,
4, 3, 2, 1, 2, 3, 4,
}, PadType::REFLECT);
}
TEST_F(PadTest, SimpleSymmetric) {
SimpleMirror({
3, 2, 1, 1, 2, 3, 4,
3, 2, 1, 1, 2, 3, 4,
7, 6, 5, 5, 6, 7, 8,
11, 10, 9, 9, 10, 11, 12,
11, 10, 9, 9, 10, 11, 12,
7, 6, 5, 5, 6, 7, 8,
}, PadType::SYMMETRIC);
}
TEST_F(PadTest, ComplexCPU) {
// Construct graph
......@@ -178,52 +266,6 @@ TEST_F(PadTest, ComplexHalf) {
}
}
namespace {
template <DeviceType D, typename T>
void Result(const std::vector<index_t> &input_shape,
const std::vector<float> &input_data,
const std::vector<index_t> &expected_shape,
const std::vector<float> &expected_data,
const std::vector<int> &paddings,
const PadType pad_type) {
// Construct graph
OpsTestNet net;
std::string input("Input");
std::string t_input(input);
std::string output("Output");
std::string t_output(output);
// Add input data
net.AddInputFromArray<D, float>(input, input_shape, input_data);
if (D == DeviceType::CPU) {
t_input = "TInput";
t_output = "TOutput";
net.TransformDataFormat<DeviceType::CPU, T>(input, NHWC, t_input, NCHW);
}
OpDefBuilder("Pad", "PadTest")
.Input(t_input)
.Output(t_output)
.AddIntsArg("paddings", paddings)
.AddIntArg("pad_type", static_cast<int>(pad_type))
.AddIntArg("data_format", DataFormat::NHWC)
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
if (D == DeviceType::CPU) {
net.TransformDataFormat<DeviceType::CPU, T>(t_output, NCHW, output, NHWC);
}
auto actual = net.GetTensor(output.c_str());
auto expected = net.CreateTensor<float>(expected_shape, expected_data);
ExpectTensorNear<float>(*expected, *actual, 1e-5);
}
} // namespace
TEST_F(PadTest, ReflectCPU) {
std::vector<index_t> input_shape{2, 2, 2, 2};
int size = std::accumulate(input_shape.begin(), input_shape.end(),
......@@ -426,50 +468,6 @@ TEST_F(PadTest, SymmetricCPU) {
expected_data, paddings, PadType::SYMMETRIC);
}
TEST_F(PadTest, Result) {
std::vector<index_t> input_shape{1, 3, 4, 1};
int size = std::accumulate(input_shape.begin(), input_shape.end(),
1, std::multiplies<index_t>());
std::vector<float> input_data;
std::vector<index_t> expected_shape{1, 6, 7, 1};
std::vector<float> expected_reflect{
8, 7, 6, 5, 6, 7, 8,
4, 3, 2, 1, 2, 3, 4,
8, 7, 6, 5, 6, 7, 8,
12, 11, 10, 9, 10, 11, 12,
8, 7, 6, 5, 6, 7, 8,
4, 3, 2, 1, 2, 3, 4,
};
std::vector<float> expected_symmetric{
3, 2, 1, 1, 2, 3, 4,
3, 2, 1, 1, 2, 3, 4,
7, 6, 5, 5, 6, 7, 8,
11, 10, 9, 9, 10, 11, 12,
11, 10, 9, 9, 10, 11, 12,
7, 6, 5, 5, 6, 7, 8,
};
const std::vector<int> paddings{0, 0, 1, 2, 3, 0, 0, 0};
input_data.reserve(size);
for (int i = 1; i <= size; i++) {
input_data.push_back(i);
}
Result<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_reflect, paddings, PadType::REFLECT);
Result<DeviceType::GPU, float>(input_shape, input_data, expected_shape,
expected_reflect, paddings, PadType::REFLECT);
Result<DeviceType::GPU, half>(input_shape, input_data, expected_shape,
expected_reflect, paddings, PadType::REFLECT);
Result<DeviceType::CPU, float>(input_shape, input_data, expected_shape,
expected_symmetric, paddings, PadType::SYMMETRIC);
Result<DeviceType::GPU, float>(input_shape, input_data, expected_shape,
expected_symmetric, paddings, PadType::SYMMETRIC);
Result<DeviceType::GPU, half>(input_shape, input_data, expected_shape,
expected_symmetric, paddings, PadType::SYMMETRIC);
}
} // namespace test
} // namespace ops
} // namespace mace
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册