提交 6a154ca8 编写于 作者: 吴承辉

Merge branch 'gemm' into 'master'

Optimize gemm v7

See merge request !577
此差异已折叠。
......@@ -38,7 +38,9 @@ void TestShapeOp(const std::vector<index_t> &input_shape) {
std::vector<int32_t> expected_input_shape(input_shape.begin(),
input_shape.end());
if (!expected_input_shape.empty()) {
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {input_shape.size()},
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput",
{static_cast<int32_t>(
input_shape.size())},
expected_input_shape);
} else {
net.AddInputFromArray<CPU, int32_t>("ExpectedOutput", {}, {0});
......
......@@ -37,11 +37,18 @@ void TestSlice(const std::vector<index_t> &input_shape,
const std::vector<float> &output) {
OpsTestNet net;
net.AddInputFromArray<CPU, float>("Input", input_shape, input);
net.AddInputFromArray<CPU, int32_t>("BeginIndices", {input_shape.size()},
net.AddInputFromArray<CPU, int32_t>("BeginIndices",
{static_cast<int32_t>(
input_shape.size())},
begin_indices);
net.AddInputFromArray<CPU, int32_t>("EndIndices", {input_shape.size()},
net.AddInputFromArray<CPU, int32_t>("EndIndices",
{static_cast<int32_t>(
input_shape.size())},
end_indices);
net.AddInputFromArray<CPU, int32_t>("Strides", {input_shape.size()}, strides);
net.AddInputFromArray<CPU, int32_t>("Strides",
{static_cast<int32_t>(
input_shape.size())},
strides);
OpDefBuilder("StridedSlice", "StridedSliceOpTest")
.Input("Input")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册