Unverified commit e1a7a880 authored by Sylwester Fraczek, committed by GitHub

added reshape transpose matmul fuse pass (#23754)

Parent 61d19a8e
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/ddim.h"
#include <set>
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
@@ -158,6 +159,11 @@ DDim DDim::transpose(const std::vector<int>& axis) const {
size_t in_rank = in_dims.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank, axis_size,
platform::errors::InvalidArgument("The input dimension's size "
@@ -166,25 +172,9 @@ DDim DDim::transpose(const std::vector<int>& axis) const {
"axis's size is %d",
in_rank, axis_size));
-std::vector<int> count(axis_size, 0);
-for (size_t i = 0; i < axis_size; i++) {
-PADDLE_ENFORCE_LT(axis[i], static_cast<int>(axis_size),
-platform::errors::InvalidArgument(
-"ValueError: Each element of axis must appear "
-"exactly once in the range from 0 to (dims - 1), "
-"where the dims is the axis's size, "
-"but received axis[%d] is %d, axis_size is %d",
-i, axis[i], axis_size));
-PADDLE_ENFORCE_EQ(
-++count[axis[i]], 1,
-platform::errors::InvalidArgument(
-"ValueError: Each element of axis should "
-"be a unique value range from 0 to (dims - 1), "
-"where the dims is the axis's size, "
-"unique value means this axis value can appear only once. "
-"But received count[axis[%d]] is %d",
-i, count[axis[i]]));
-}
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
DDim out_dims(in_dims);
for (size_t i = 0; i < axis_size; i++) {
......
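The replacement collapses the old per-element counting loop into two aggregate checks: a std::set detects duplicate axis values and std::max_element bounds them from above. A minimal standalone sketch of the equivalent validation (function name hypothetical; note that, like the committed checks, it does not reject negative axis values):

#include <algorithm>
#include <set>
#include <stdexcept>
#include <vector>

// Passes only if `axis` has exactly in_rank unique entries, all below
// in_rank, mirroring the three PADDLE_ENFORCE checks added above.
void ValidateAxis(const std::vector<int>& axis, size_t in_rank) {
  std::set<int> axis_set(axis.begin(), axis.end());
  if (axis_set.size() != axis.size())
    throw std::invalid_argument("In an axis array, elements must be unique.");
  if (in_rank != axis.size())
    throw std::invalid_argument("Dimension size must equal axis size.");
  if (*std::max_element(axis.begin(), axis.end()) >=
      static_cast<int>(axis.size()))
    throw std::invalid_argument("Axis values must be in [0, dims - 1].");
}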
@@ -97,6 +97,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
endif()
@@ -145,5 +146,8 @@ if (WITH_MKLDNN)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
if(NOT WITH_COVERAGE)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
endif()
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
endif ()
@@ -33,7 +33,6 @@ namespace paddle {
namespace framework {
namespace ir {
-using string::PrettyLogEndl;
using string::PrettyLog;
using string::Style;
@@ -2148,6 +2147,57 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
any_op2->LinksFrom({quant_dequant_out});
}
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
bool with_reshape_xshape, bool with_transpose_xshape) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto reshape_in = pattern->NewNode(reshape_in_repr())
->AsInput()
->assert_is_op_input("reshape2", "X");
auto reshape_out = pattern->NewNode(reshape_out_repr())
->AsIntermediate()
->assert_is_op_input("transpose2", "X")
->assert_is_op_output("reshape2", "Out");
if (!with_reshape_xshape)
reshape_out->assert_is_only_output_of_op("reshape2");
auto reshape_xshape = with_reshape_xshape
? pattern->NewNode(reshape_xshape_repr())
->AsIntermediate()
->assert_is_op_output("reshape2", "XShape")
: nullptr;
auto transpose_out = pattern->NewNode(transpose_out_repr())
->AsIntermediate()
->assert_is_op_input("matmul")
->assert_is_op_output("transpose2", "Out");
if (!with_transpose_xshape)
transpose_out->assert_is_only_output_of_op("transpose2");
auto transpose_xshape =
with_transpose_xshape
? pattern->NewNode(transpose_xshape_repr())
->AsIntermediate()
->assert_is_op_output("transpose2", "XShape")
: nullptr;
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
reshape_op->LinksFrom({reshape_in}).LinksTo({reshape_out});
if (with_reshape_xshape) reshape_op->LinksTo({reshape_xshape});
transpose_op->LinksFrom({reshape_out}).LinksTo({transpose_out});
if (with_transpose_xshape) transpose_op->LinksTo({transpose_xshape});
matmul_op->LinksFrom({transpose_out}).LinksTo({matmul_out});
return matmul_out;
}
PDNode *patterns::MatmulTransposeReshapePattern::operator()() {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
......
@@ -1210,6 +1210,29 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
PATTERN_DECL_NODE(any_op2);
};
// Reshape + Transpose + Matmul
// named nodes:
// reshape_op, reshape_out, reshape_xshape,
// transpose_op, transpose_out, transpose_xshape,
// matmul_op, matmul_out
struct ReshapeTransposeMatmulPattern : public PatternBase {
ReshapeTransposeMatmulPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, "reshape_transpose_matmul") {}
PDNode* operator()(bool with_reshape_xshape, bool with_transpose_xshape);
PATTERN_DECL_NODE(reshape_in);
PATTERN_DECL_NODE(reshape_op);
PATTERN_DECL_NODE(reshape_out);
PATTERN_DECL_NODE(reshape_xshape);
PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(transpose_out);
PATTERN_DECL_NODE(transpose_xshape);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};
// Matmul + Transpose + Reshape
struct MatmulTransposeReshapePattern : public PatternBase {
MatmulTransposeReshapePattern(PDPattern* pattern,
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include <string>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
void ReshapeTransposeMatmulMkldnnFusePass::Fuse(
Graph *graph, bool with_reshape_xshape, bool with_transpose_xshape) const {
GraphPatternDetector gpd;
patterns::ReshapeTransposeMatmulPattern rtm_pattern(gpd.mutable_pattern(),
name_scope_);
rtm_pattern(with_reshape_xshape, with_transpose_xshape);
int found_reshape_transpose_matmul_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
VLOG(4) << "handle ReshapeTransposeMatmulMkldnn fuse";
GET_IR_NODE_FROM_SUBGRAPH(reshape_in, reshape_in, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_op, reshape_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape_out, reshape_out, rtm_pattern);
ir::Node *reshape_xshape{nullptr};
if (with_reshape_xshape) {
GET_IR_NODE_FROM_SUBGRAPH(reshape_xshape1, reshape_xshape, rtm_pattern);
reshape_xshape = reshape_xshape1;
}
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(transpose_out, transpose_out, rtm_pattern);
ir::Node *transpose_xshape{nullptr};
if (with_transpose_xshape) {
GET_IR_NODE_FROM_SUBGRAPH(transpose_xshape1, transpose_xshape,
rtm_pattern);
transpose_xshape = transpose_xshape1;
}
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, rtm_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, rtm_pattern);
auto reshape_shape =
boost::get<std::vector<int>>(reshape_op->Op()->GetAttr("shape"));
auto transpose_axis =
boost::get<std::vector<int>>(transpose_op->Op()->GetAttr("axis"));
OpDesc *matmul_desc = matmul_op->Op();
std::string input_var_name = transpose_out->Name();
auto UpdateMatmul = [&](std::string matmul_input_name) {
matmul_desc->SetInput(matmul_input_name, {(reshape_in)->Name()});
matmul_desc->SetAttr("fused_reshape_" + matmul_input_name, reshape_shape);
matmul_desc->SetAttr("fused_transpose_" + matmul_input_name,
transpose_axis);
};
if (matmul_desc->Inputs().at("X").at(0) == input_var_name) {
UpdateMatmul("X");
} else if (matmul_desc->Inputs().at("Y").at(0) == input_var_name) {
UpdateMatmul("Y");
} else {
throw platform::errors::InvalidArgument(
"Unexpected input to MatMul encountered.");
}
std::unordered_set<const ir::Node *> nodes_to_remove{
reshape_op, reshape_out, transpose_op, transpose_out};
if (with_reshape_xshape) nodes_to_remove.insert(reshape_xshape);
if (with_transpose_xshape) nodes_to_remove.insert(transpose_xshape);
GraphSafeRemoveNodes(graph, nodes_to_remove);
IR_NODE_LINK_TO(reshape_in, matmul_op);
++found_reshape_transpose_matmul_count;
};
gpd(graph, handler);
AddStatis(found_reshape_transpose_matmul_count);
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_reshape_transpose_matmul_count
<< " ReshapeTransposeMatmulMkldnn patterns";
if (with_reshape_xshape) msg_ss << " with reshape's xshape";
if (with_transpose_xshape) msg_ss << " with transpose's xshape";
string::PrettyLogDetail(msg_ss.str().c_str());
}
void ReshapeTransposeMatmulMkldnnFusePass::ApplyImpl(ir::Graph *graph) const {
PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL."));
FusePassBase::Init(name_scope_, graph);
Fuse(graph, false, false);
Fuse(graph, false, true);
Fuse(graph, true, false);
Fuse(graph, true, true);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(reshape_transpose_matmul_mkldnn_fuse_pass,
paddle::framework::ir::ReshapeTransposeMatmulMkldnnFusePass);
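Once registered, the pass can be looked up by name and applied to a graph, as the tester further below does. A minimal sketch (assumes an already-built ir::Graph; the helper name is hypothetical):

#include <memory>
#include "paddle/fluid/framework/ir/pass.h"

// Fetch the registered pass and run it. Apply() consumes the raw graph
// pointer and returns the transformed graph.
void ApplyReshapeTransposeMatmulFuse(
    std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "reshape_transpose_matmul_mkldnn_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}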
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* Fuse Reshape->Transpose->MatMul when MatMul uses mkldnn.
*/
class ReshapeTransposeMatmulMkldnnFusePass : public FusePassBase {
public:
virtual ~ReshapeTransposeMatmulMkldnnFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
const std::string name_scope_{"reshape_transpose_matmul_fuse"};
void Fuse(Graph* graph, bool with_reshape_xshape,
bool with_transpose_xshape) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope, const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<LoDTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(platform::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "w1", {768, 768});
AddVarToScope(param_scope, "bias1", {768});
AddVarToScope(param_scope, "w2", {768, 768});
AddVarToScope(param_scope, "bias2", {768});
return param_scope;
}
void TestMain(bool with_xshapes) {
// inputs operator output
// -----------------------------------------------
// a1,w1,bias1 fc -> b1
// b1 reshape -> c1
// c1 transpose -> d1
// a2,w2,bias2 fc -> b2
// b2 reshape -> c2
// c2 transpose -> d2
// (d1, d2) matmul -> (...)
Layers layers;
auto* a1 = layers.data("a1", {-1, 128, 768});
auto* w1 = layers.data("w1", {768, 768}, true);
auto* bias1 = layers.data("bias1", {768}, true);
auto* b1 = layers.fc(a1, w1, bias1, 2);
b1->SetShape({-1, 128, 768});
auto* c1 = layers.reshape2(b1, {0, 0, 12, 64}, with_xshapes);
c1->SetShape({-1, 128, 12, 64});
auto* d1 = layers.transpose2(c1, {0, 2, 1, 3}, with_xshapes);
d1->SetShape({-1, 12, 128, 64});
auto* a2 = layers.data("a2", {-1, 128, 768});
auto* w2 = layers.data("w2", {768, 768}, true);
auto* bias2 = layers.data("bias2", {768}, true);
auto* b2 = layers.fc(a2, w2, bias2, 2);
b2->SetShape({-1, 128, 768});
auto* c2 = layers.reshape2(b2, {0, 0, 12, 64});
c2->SetShape({-1, 128, 12, 64});
auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
d2->SetShape({-1, 12, 128, 64});
layers.matmul(d1, d2);
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
int num_reshape_nodes_before = GetNumOpNodes(graph, "reshape2");
int num_transpose_nodes_before = GetNumOpNodes(graph, "transpose2");
int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
auto pass =
PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
int num_transpose_nodes_after = GetNumOpNodes(graph, "transpose2");
int total_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
EXPECT_EQ(num_reshape_nodes_before, 2);
EXPECT_EQ(num_reshape_nodes_after, 0);
EXPECT_EQ(num_transpose_nodes_before, 2);
EXPECT_EQ(num_transpose_nodes_after, 0);
int removed = 8;  // 2 x (reshape op, reshape_out, transpose op, transpose_out)
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
auto* matmul_op_desc = GetOpNodes(graph, "matmul").at(0)->Op();
auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a;
EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str),
testing::ElementsAre(0, 0, 12, 64));
std::string axis_str = "fused_transpose_" + a;
EXPECT_THAT(matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str),
testing::ElementsAre(0, 2, 1, 3));
};
check("X");
check("Y");
}
TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose) {
TestMain(false);
}
TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose_one_with_xshapes) {
TestMain(true);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
@@ -258,23 +258,33 @@ struct Layers {
return out;
}
-VarDesc* transpose2(VarDesc* x, std::vector<int> axis) {
VarDesc* transpose2(VarDesc* x, std::vector<int> axis,
bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("transpose2");
op->SetInput("X", {x->Name()});
op->SetAttr("axis", axis);
op->SetOutput("Out", {out->Name()});
if (with_xshape) {
VarDesc* xshape = lod_tensor(unique_name());
op->SetOutput("XShape", {xshape->Name()});
}
return out;
}
-VarDesc* reshape2(VarDesc* x, std::vector<int> shape) {
VarDesc* reshape2(VarDesc* x, std::vector<int> shape,
bool with_xshape = false) {
VarDesc* out = lod_tensor(unique_name());
OpDesc* op = program_.MutableBlock(0)->AppendOp();
op->SetType("reshape2");
op->SetInput("X", {x->Name()});
op->SetAttr("shape", shape);
op->SetOutput("Out", {out->Name()});
if (with_xshape) {
VarDesc* xshape = lod_tensor(unique_name());
op->SetOutput("XShape", {xshape->Name()});
}
return out;
}
@@ -579,6 +589,17 @@ static std::string DebugString(const std::unique_ptr<Graph>& graph) {
return DebugString(graph.get());
}
static std::vector<ir::Node*> GetOpNodes(const std::unique_ptr<Graph>& graph,
std::string op_type) {
std::vector<ir::Node*> rc;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op() && node->Op()->Type() == op_type) {
rc.push_back(node);
}
}
return rc;
}
static int GetNumOpNodes(const std::unique_ptr<Graph>& graph,
std::string op_type) {
int num_nodes = 0;
......
@@ -191,12 +191,13 @@ void CpuPassStrategy::EnableMKLDNN() {
"conv3d_bias_mkldnn_fuse_pass", //
"conv_elementwise_add_mkldnn_fuse_pass",
"conv_concat_relu_mkldnn_fuse_pass",
"conv_relu_mkldnn_fuse_pass", //
"conv_leaky_relu_mkldnn_fuse_pass", //
"conv_relu6_mkldnn_fuse_pass", //
"conv_swish_mkldnn_fuse_pass", //
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
"mkldnn_inplace_pass", // This pass should be activated after
......
@@ -318,6 +318,36 @@ class MatMulGradKernel : public framework::OpKernel<T> {
}
};
framework::DDim GetDimForInput(const framework::InferShapeContext &ctx,
std::string input_name) {
auto shape = ctx.Attrs().Get<std::vector<int>>("fused_reshape_" + input_name);
auto axis =
ctx.Attrs().Get<std::vector<int>>("fused_transpose_" + input_name);
auto dim = ctx.GetInputDim(input_name);
if (!shape.empty() && !axis.empty()) {
PADDLE_ENFORCE_GE(
shape.size(), 2,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_LE(
shape.size(), 4,
platform::errors::InvalidArgument(
"shape_%s attribute of MatMulOp was implemented for 2, 3 "
"or 4 dimensions.",
input_name));
PADDLE_ENFORCE_EQ(
shape.size(), axis.size(),
platform::errors::InvalidArgument(
"Ranks of shape_%s and axis_%s attributes of MatMulOp "
"must be equal.",
input_name, input_name));
dim = dim.reshape(shape).transpose(axis);
}
return dim;
}
class MatMulOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
@@ -328,9 +358,8 @@ class MatMulOp : public framework::OperatorWithKernel {
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", "matmul");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "matmul");
-auto dim_x = context->GetInputDim("X");
-auto dim_y = context->GetInputDim("Y");
auto dim_x = GetDimForInput(*context, "X");
auto dim_y = GetDimForInput(*context, "Y");
auto mat_dim_x =
math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0,
context->Attrs().Get<bool>("transpose_X"));
@@ -484,6 +513,18 @@ class MatMulOpMaker : public framework::OpProtoAndCheckerMaker {
"use_mkldnn",
"(bool, default false) Indicates if MKL-DNN kernel will be used")
.SetDefault(false);
AddAttr<std::vector<int>>("fused_reshape_X",
R"DOC(Shape of fused reshape of `X` input.)DOC")
.SetDefault({});
AddAttr<std::vector<int>>("fused_reshape_Y",
R"DOC(Shape of fused reshape of `Y` input.)DOC")
.SetDefault({});
AddAttr<std::vector<int>>("fused_transpose_X",
R"DOC(Axis of fused transpose of `X` input.)DOC")
.SetDefault({});
AddAttr<std::vector<int>>("fused_transpose_Y",
R"DOC(Axis of fused transpose of `Y` input.)DOC")
.SetDefault({});
AddAttr<std::vector<int>>(
"fused_reshape_Out",
R"DOC(When MKLDNN MatMul_transpose_reshape fuse activated, "
......
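For reference, GetDimForInput applies the fused reshape first, where a 0 in the shape keeps the corresponding input dimension (reshape2 semantics), and then permutes the result by the fused transpose axis. A standalone sketch of that arithmetic on plain vectors, using the shapes from the fuse-pass tester above (FusedDims is a hypothetical illustration, not the Paddle API):

#include <cassert>
#include <cstdint>
#include <vector>

// Reshape-with-zeros followed by a transpose, mirroring
// dim.reshape(shape).transpose(axis) in GetDimForInput. The -1 wildcard
// supported by reshape2 is omitted for brevity.
std::vector<int64_t> FusedDims(const std::vector<int64_t>& dims,
                               const std::vector<int>& shape,
                               const std::vector<int>& axis) {
  std::vector<int64_t> reshaped(shape.size());
  for (size_t i = 0; i < shape.size(); ++i)
    reshaped[i] = shape[i] == 0 ? dims[i] : shape[i];  // 0 copies dims[i]
  std::vector<int64_t> transposed(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) transposed[i] = reshaped[axis[i]];
  return transposed;
}

int main() {
  // X: [2, 128, 768] with fused_reshape_X = {0, 0, 12, 64} and
  // fused_transpose_X = {0, 2, 1, 3} infers the shape [2, 12, 128, 64].
  auto d = FusedDims({2, 128, 768}, {0, 0, 12, 64}, {0, 2, 1, 3});
  assert((d == std::vector<int64_t>{2, 12, 128, 64}));
  return 0;
}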
@@ -23,12 +23,12 @@ namespace operators {
using dnnl::memory;
using dnnl::primitive;
-using platform::to_void_cast;
using framework::DataLayout;
using framework::ExecutionContext;
using platform::GetMKLDNNFormat;
-using platform::MKLDNNGetDataType;
using platform::MKLDNNDeviceContext;
-using framework::ExecutionContext;
using platform::MKLDNNGetDataType;
using platform::to_void_cast;
using Tensor = framework::Tensor;
template <typename T>
@@ -86,6 +86,74 @@ class MatMulFactory {
return dnnl::memory(md, engine_, to_void_cast(data));
}
std::vector<int64_t> Transpose(const std::vector<int64_t>& x,
const std::vector<int>& axis) {
size_t in_rank = x.size();
size_t axis_size = axis.size();
auto axis_set = std::set<int>(axis.begin(), axis.end());
PADDLE_ENFORCE_EQ(axis_set.size(), axis_size,
platform::errors::InvalidArgument(
"In an axis array, elements must be unique."));
PADDLE_ENFORCE_EQ(
in_rank, axis_size,
platform::errors::InvalidArgument("The input dimension's size "
"should be equal to the axis's size. "
"But received dimension is %d, "
"axis's size is %d",
in_rank, axis_size));
PADDLE_ENFORCE_LT(*std::max_element(axis.begin(), axis.end()), axis_size,
platform::errors::InvalidArgument(
"Axis values must be ranging from 0 to (dims - 1)."));
std::vector<int64_t> new_x(x.size());
for (size_t i = 0; i < x.size(); i++) {
new_x[i] = x[axis[i]];
}
return new_x;
}
std::pair<math::MatDescriptor, memory::dims> GetInputDimsAndStrides(
const ExecutionContext& ctx, std::string input_name) {
auto shape = ctx.Attr<std::vector<int>>("fused_reshape_" + input_name);
auto axis = ctx.Attr<std::vector<int>>("fused_transpose_" + input_name);
auto input_dims = ctx.Input<Tensor>(input_name)->dims();
auto new_dims = input_dims;
if (!shape.empty() && !axis.empty()) {
new_dims = input_dims.reshape(shape).transpose(axis);
}
auto& MatrixDimsFromVector = input_name == "X" ? RowMatrixDimsFromVector
: ColumnMatrixDimsFromVector;
math::MatDescriptor mat_dim =
math::CreateMatrixDescriptor(MatrixDimsFromVector(new_dims), 0,
ctx.Attr<bool>("transpose_" + input_name));
memory::dims strides;
if (!shape.empty()) {
auto shape2 = input_dims.reshape(shape);
strides.push_back(1);
for (auto i = shape2.size() - 1; i > 0; --i) {
strides.insert(strides.begin(), strides.front() * shape2[i]);
}
strides = Transpose(strides, axis);
if (shape.size() == 4)
strides.erase(strides.begin());
else if (shape.size() == 2)
strides.insert(strides.begin(), shape[0] * shape[1]);
mat_dim.stride_ = strides[0];
if (mat_dim.trans_) std::swap(*strides.rbegin(), *(++strides.rbegin()));
}
return std::make_pair(mat_dim, strides);
}
bool IsInputFused(const ExecutionContext& ctx) const {
return !(ctx.Attr<std::vector<int>>("fused_reshape_X").empty() &&
ctx.Attr<std::vector<int>>("fused_reshape_Y").empty());
}
bool IsOutputFused(const ExecutionContext& ctx) const {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out =
@@ -100,12 +168,12 @@ class MatMulFactory {
}
MatMulDims GetMatmulDims(const ExecutionContext& ctx) {
-auto mat_dim_x = math::CreateMatrixDescriptor(
-RowMatrixDimsFromVector(ctx.Input<Tensor>("X")->dims()), 0,
-ctx.Attr<bool>("transpose_X"));
-auto mat_dim_y = math::CreateMatrixDescriptor(
-ColumnMatrixDimsFromVector(ctx.Input<Tensor>("Y")->dims()), 0,
-ctx.Attr<bool>("transpose_Y"));
math::MatDescriptor mat_dim_x;
memory::dims strides_x;
std::tie(mat_dim_x, strides_x) = GetInputDimsAndStrides(ctx, "X");
math::MatDescriptor mat_dim_y;
memory::dims strides_y;
std::tie(mat_dim_y, strides_y) = GetInputDimsAndStrides(ctx, "Y");
const auto x_bs = mat_dim_x.batch_size_;
const auto y_bs = mat_dim_y.batch_size_;
@@ -122,26 +190,27 @@ class MatMulFactory {
batch_size_ = 1;
auto b = BS;
-if (BS > 1 && IsOutputFused(ctx)) {
-batch_size_ = ctx.Input<Tensor>("X")->dims()[0];
if (BS > 1 && (IsOutputFused(ctx) || IsInputFused(ctx))) {
auto& x_dims = ctx.Input<Tensor>("X")->dims();
auto& y_dims = ctx.Input<Tensor>("Y")->dims();
batch_size_ = x_bs > y_bs ? x_dims[0] : y_dims[0];
b = BS / batch_size_;
}
memory::dims x_dims = {b, M, K};
memory::dims y_dims = {b, K, N};
memory::dims out_dims = {b, M, N};
-size_t x_size = b * M * K * sizeof(XT);
-size_t y_size = b * K * N * sizeof(YT);
-size_t out_size = b * M * N * sizeof(OT);
-offsets_ = {x_size, y_size, out_size};
x_offset_ = b * M * K * sizeof(XT);
y_offset_ = b * K * N * sizeof(YT);
out_offset_ = b * M * N * sizeof(OT);
// Translate transA and transB
-memory::dims strides_x = !ctx.Attr<bool>("transpose_X")
-? memory::dims{M * K, K, 1}
-: memory::dims{M * K, 1, M};
-memory::dims strides_y = !ctx.Attr<bool>("transpose_Y")
-? memory::dims{N * K, N, 1}
-: memory::dims{N * K, 1, K};
if (strides_x.empty())
strides_x = !ctx.Attr<bool>("transpose_X") ? memory::dims{M * K, K, 1}
: memory::dims{M * K, 1, M};
if (strides_y.empty())
strides_y = !ctx.Attr<bool>("transpose_Y") ? memory::dims{N * K, N, 1}
: memory::dims{N * K, 1, K};
memory::dims out_strides = memory::dims{M * N, N, 1};
CorrectStridesWhenFloatOutputFused(ctx, N, b, &out_strides);
@@ -187,12 +256,10 @@ class MatMulFactory {
void Execute() {
dnnl::stream stream(engine_);
-auto offsets = offsets_;
-unsigned bs = batch_size_;
void* x_ptr = x_mem_.get_data_handle();
void* y_ptr = y_mem_.get_data_handle();
void* out_ptr = out_mem_.get_data_handle();
-for (unsigned i = 0; i < bs; i++) {
for (uint16_t i = 0; i < batch_size_; i++) {
x_mem_.set_data_handle(x_ptr);
y_mem_.set_data_handle(y_ptr);
out_mem_.set_data_handle(out_ptr);
@@ -201,9 +268,9 @@ class MatMulFactory {
{MKLDNN_ARG_WEIGHTS, y_mem_},
{MKLDNN_ARG_DST, out_mem_},
});
-x_ptr = static_cast<char*>(x_ptr) + offsets.x_offset;
-y_ptr = static_cast<char*>(y_ptr) + offsets.y_offset;
-out_ptr = static_cast<char*>(out_ptr) + offsets.out_offset;
x_ptr = static_cast<char*>(x_ptr) + x_offset_;
y_ptr = static_cast<char*>(y_ptr) + y_offset_;
out_ptr = static_cast<char*>(out_ptr) + out_offset_;
}
stream.wait();
}
@@ -243,21 +310,21 @@ class MatMulFactory {
dnnl::memory y_mem_;
dnnl::memory out_mem_;
dnnl::matmul matmul_prim_;
-memory_offsets offsets_;
-unsigned batch_size_;
uint32_t x_offset_;
uint32_t y_offset_;
uint32_t out_offset_;
uint16_t batch_size_;
bool initialized_ = false;
};
template <typename XT, typename YT, typename OT>
static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const ExecutionContext& ctx) {
-const auto x_dims = framework::vectorize<int>(ctx.Input<Tensor>("X")->dims());
-const auto y_dims = framework::vectorize<int>(ctx.Input<Tensor>("Y")->dims());
const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string key =
-platform::CreateKey(platform::ThreadIDasStr(), x_dims, y_dims, out_name);
platform::CreateKey(platform::ThreadIDasStr(), out_name);
auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......
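The stride computation in GetInputDimsAndStrides above is what lets oneDNN read the reshaped-and-transposed input in place instead of materializing it: row-major strides are built for the reshaped tensor and then permuted by the transpose axis. A standalone sketch of that arithmetic for the [2, 128, 768] -> reshape to [2, 128, 12, 64] -> transpose by [0, 2, 1, 3] case exercised by the tests (illustrative only):

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Reshaped dims and fused transpose axis, as in the tests above.
  std::vector<int64_t> shape2 = {2, 128, 12, 64};
  std::vector<int> axis = {0, 2, 1, 3};
  // Row-major strides of the reshaped tensor: {98304, 768, 64, 1}.
  std::vector<int64_t> strides{1};
  for (size_t i = shape2.size() - 1; i > 0; --i)
    strides.insert(strides.begin(), strides.front() * shape2[i]);
  // Permute by the transpose axis: {98304, 64, 768, 1}.
  std::vector<int64_t> permuted(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) permuted[i] = strides[axis[i]];
  // 4-D case: drop the outermost (batch) stride, leaving {64, 768, 1},
  // i.e. 128x64 matrices whose rows are 768 elements apart in memory.
  permuted.erase(permuted.begin());
  for (auto s : permuted) std::cout << s << ' ';
  std::cout << '\n';
  return 0;
}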
@@ -408,7 +408,10 @@ framework::DataLayout get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}
-void MKLDNNDeviceContext::ResetBlobMap() const { p_blobmap_->clear(); }
void MKLDNNDeviceContext::ResetBlobMap() const {
VLOG(3) << "Clearing DNNL cache.";
p_blobmap_->clear();
}
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
std::lock_guard<std::mutex> lock(*p_mutex_);
......
@@ -500,6 +500,8 @@ class Qat2Int8MkldnnPass(object):
graph.draw('.', 'qat_int8_{}'.format(ir_pass.type()),
graph.all_op_nodes())
graph = self._apply_pass(graph, 'scale_matmul_fuse_pass')
graph = self._apply_pass(graph,
'reshape_transpose_matmul_mkldnn_fuse_pass')
graph = self._apply_pass(
graph, 'cpu_quantize_pass', ['quant_var_scales', 'data_layout'],
[self._var_quant_scales, self._get_data_layout()])
......
@@ -161,6 +161,180 @@ class TestDnnlMatMulOpInt8ForceFP32BasicScales(TestDnnlMatMulOp):
self.attrs = {'force_fp32_output': True}
@skip_check_grad_ci(reason="DNNL's MatMul doesn't implement grad kernel.")
class TestMatMulOpReshapeTranspose(OpTest):
def init_data_type(self):
self.data_type_ = 'float32'
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32").reshape(
[2, 128, 12, 64]).transpose([0, 2, 1, 3])
self.y = np.random.random([2, 128, 768]).astype("float32").reshape(
[2, 128, 12, 64]).transpose([0, 2, 1, 3])
self.out = np.matmul(self.x, self.y.transpose([0, 1, 3, 2]))
self.fused_reshape_X = []
self.fused_transpose_X = []
self.fused_reshape_Y = []
self.fused_transpose_Y = []
def setUp(self):
# Set max isa, otherwise fails on SKX and earlier
os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
self.op_type = "matmul"
self._cpu_only = True
self.use_mkldnn = True
self.transpose_y = True
self.init_data_type()
self.generate_data()
self.inputs = {'X': self.x, 'Y': self.y}
self.attrs = {
'use_mkldnn': self.use_mkldnn,
'transpose_Y': self.transpose_y
}
if len(self.fused_transpose_X) > 0:
self.attrs['fused_transpose_X'] = self.fused_transpose_X
if len(self.fused_transpose_Y) > 0:
self.attrs['fused_transpose_Y'] = self.fused_transpose_Y
if len(self.fused_reshape_X) > 0:
self.attrs['fused_reshape_X'] = self.fused_reshape_X
if len(self.fused_reshape_Y) > 0:
self.attrs['fused_reshape_Y'] = self.fused_reshape_Y
self.outputs = {'Out': self.out}
def test_check_output(self):
self.check_output()
class TestMatMulOpReshapeTranspose4DXFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32").reshape(
[2, 128, 12, 64]).transpose([0, 2, 1, 3])
self.fused_transpose_X = [0, 2, 1, 3]
self.fused_reshape_X = [0, 0, 12, 64]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]),
self.y.transpose([0, 1, 3, 2]))
class TestMatMulOpReshapeTranspose4DXInt8(TestMatMulOpReshapeTranspose4DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose4DYFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32").reshape(
[2, 128, 12, 64]).transpose([0, 2, 1, 3])
self.y = np.random.random([2, 128, 768]).astype("float32")
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [0, 2, 1, 3]
self.fused_reshape_Y = [0, 0, 12, 64]
self.out = np.matmul(
self.x, self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]))
class TestMatMulOpReshapeTranspose4DYInt8(TestMatMulOpReshapeTranspose4DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose4DXYFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 128, 768]).astype("float32")
self.y = np.random.random([2, 128, 768]).astype("float32")
self.fused_transpose_X = [0, 2, 1, 3]
self.fused_reshape_X = [0, 0, 12, 64]
self.fused_transpose_Y = [0, 2, 1, 3]
self.fused_reshape_Y = [0, 0, 12, 64]
self.out = np.matmul(
self.x.reshape([2, 128, 12, 64]).transpose([0, 2, 1, 3]),
self.y.reshape([2, 128, 12, 64]).transpose([0, 2, 3, 1]))
class TestMatMulOpReshapeTranspose4DXYInt8(
TestMatMulOpReshapeTranspose4DXYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose2DXFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32")
self.y = np.random.random([2, 5, 10]).astype("float32").reshape(
[10, 10]).transpose([1, 0])
self.fused_transpose_X = [1, 0]
self.fused_reshape_X = [10, 10]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([10, 10]).transpose([1, 0]),
self.y.transpose([1, 0]))
class TestMatMulOpReshapeTranspose2DXInt8(TestMatMulOpReshapeTranspose2DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose2DYFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 5, 10]).astype("float32").reshape(
[10, 10]).transpose([1, 0])
self.y = np.random.random([2, 5, 10]).astype("float32")
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [1, 0]
self.fused_reshape_Y = [10, 10]
self.out = np.matmul(self.x, self.y.reshape([10, 10]))
class TestMatMulOpReshapeTranspose2DYInt8(TestMatMulOpReshapeTranspose2DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose3DXFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype("float32")
self.y = np.random.random([2, 2, 5, 5]).astype("float32").reshape(
[2, 10, 5]).transpose([0, 2, 1])
self.fused_transpose_X = [0, 2, 1]
self.fused_reshape_X = [2, 10, 5]
self.fused_transpose_Y = []
self.fused_reshape_Y = []
self.out = np.matmul(
self.x.reshape([2, 10, 5]).transpose(0, 2, 1),
self.y.transpose(0, 2, 1))
class TestMatMulOpReshapeTranspose3DXInt8(TestMatMulOpReshapeTranspose3DXFloat):
def init_data_type(self):
self.data_type_ = 'int8'
class TestMatMulOpReshapeTranspose3DYFloat(TestMatMulOpReshapeTranspose):
def generate_data(self):
self.x = np.random.random([2, 2, 5, 5]).astype(self.data_type_).reshape(
[2, 10, 5]).transpose([0, 2, 1])
self.y = np.random.random([2, 2, 5, 5]).astype(self.data_type_)
self.fused_transpose_X = []
self.fused_reshape_X = []
self.fused_transpose_Y = [0, 2, 1]
self.fused_reshape_Y = [2, 10, 5]
self.out = np.matmul(self.x, self.y.reshape([2, 10, 5]))
class TestMatMulOpReshapeTranspose3DYInt8(TestMatMulOpReshapeTranspose3DYFloat):
def init_data_type(self):
self.data_type_ = 'int8'
@skip_check_grad_ci(reason="Tests inference only optimization.") @skip_check_grad_ci(reason="Tests inference only optimization.")
class TestMatMulOpTransposeReshapeEmptyFloat(OpTest): class TestMatMulOpTransposeReshapeEmptyFloat(OpTest):
def init_data_type(self): def init_data_type(self):
......