未验证 提交 856cb9c5 编写于 作者: J jakpiase 提交者: GitHub

Added matmul_v2+transpose+reshape fuse pass (#36481)

* added base changes for matmul_v2+trans+resh fuse pass

* added full matmul_v2+transpose+reshape pass

* removed a file added by mistake

* added reviewers suggestions

* Changed ops type in checking capatibility version

* Deteled one statement
上级 0ca2807c
......@@ -123,6 +123,7 @@ if(WITH_MKLDNN)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
pass_library(reshape_transpose_matmul_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(matmul_v2_transpose_reshape_fuse_pass inference DIR mkldnn)
pass_library(batch_norm_act_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_fuse_pass inference DIR mkldnn)
pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
......@@ -192,7 +193,7 @@ endif()
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
cc_test(test_reshape_transpose_matmul_mkldnn_fuse_pass SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass)
cc_test(test_matmul_transpose_reshape_fuse_pass SRCS mkldnn/matmul_transpose_reshape_fuse_pass_tester.cc DEPS matmul_transpose_reshape_fuse_pass matmul_v2_transpose_reshape_fuse_pass)
cc_test(test_cpu_bfloat16_placement_pass SRCS mkldnn/cpu_bfloat16_placement_pass_tester.cc DEPS cpu_bfloat16_placement_pass)
cc_test(test_cpu_bfloat16_pass SRCS mkldnn/cpu_bfloat16_pass_tester.cc DEPS cpu_bfloat16_pass)
cc_test(test_multi_gru_fuse_pass SRCS mkldnn/multi_gru_fuse_pass_tester.cc DEPS multi_gru_fuse_pass)
......
......@@ -2697,16 +2697,18 @@ PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
return matmul_out;
}
PDNode *patterns::MatmulTransposeReshapePattern::operator()() {
// shared function for matmul and matmul_v2
PDNode *patterns::MatmulTransposeReshapePattern::operator()(
const std::string &op_name) {
auto reshape_op =
pattern->NewNode(reshape_op_repr())->assert_is_op("reshape2");
auto transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op("transpose2");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op(op_name);
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsInput()
->assert_is_op_output("matmul", "Out")
->assert_is_op_output(op_name, "Out")
->assert_is_op_input("transpose2", "X");
auto transpose_out = pattern->NewNode(transpose_out_repr())
......
......@@ -1546,7 +1546,7 @@ struct MatmulTransposeReshapePattern : public PatternBase {
const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_transpose_reshape") {}
PDNode* operator()();
PDNode* operator()(const std::string& op_name);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
......
......@@ -23,7 +23,9 @@ namespace framework {
namespace ir {
MatmulTransposeReshapeMKLDNNPass::MatmulTransposeReshapeMKLDNNPass() {
AddOpCompat(OpCompat("matmul"))
op_name_ = "matmul";
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
......@@ -89,7 +91,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
patterns::MatmulTransposeReshapePattern mtrp(gpd.mutable_pattern(),
name_scope_);
mtrp();
mtrp(op_name_);
int found_matmul_transpose_reshape_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
......@@ -98,7 +100,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
LOG(WARNING) << "Pass in op compat failed.";
return;
}
VLOG(4) << "handle matmul_transpose_reshape fuse";
VLOG(4) << "handle " + op_name_ + "_transpose_reshape fuse";
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, mtrp);
GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, mtrp);
......@@ -118,17 +120,17 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
const bool supported_transpose_axis = std::equal(
transpose_axis.begin(), transpose_axis.end(), supported_axis.begin());
if (transpose_out_size != 4) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported rank is 4, received " << transpose_out_size;
return;
}
if (!supported_transpose_axis) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "supported transpose axis for the fuse are {0, 2, 1, 3}";
return;
}
if (reshape_out_size != 3) {
VLOG(3) << "do not perform matmul_transpose_reshape fuse: "
VLOG(3) << "do not perform " + op_name_ + "_transpose_reshape fuse: "
<< "reshape_out supported rank is 3, received "
<< reshape_out_size;
return;
......@@ -152,7 +154,7 @@ void MatmulTransposeReshapeMKLDNNPass::ApplyImpl(ir::Graph *graph) const {
if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
std::stringstream msg_ss;
msg_ss << "--- Fused " << found_matmul_transpose_reshape_count
<< " MatmulTransposeReshape patterns";
<< " MatmulTransposeReshape patterns for " + op_name_ + " Op";
paddle::string::PrettyLogDetail(msg_ss.str().c_str());
}
}
......
......@@ -31,6 +31,7 @@ class MatmulTransposeReshapeMKLDNNPass : public FusePassBase {
protected:
void ApplyImpl(Graph* graph) const override;
const std::string name_scope_{"matmul_transpose_reshape_fuse"};
std::string op_name_;
};
} // namespace ir
} // namespace framework
......
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
namespace paddle {
namespace framework {
......@@ -42,9 +42,15 @@ void SetOp(ProgramDesc *prog, const std::string &type,
op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true);
}
if (type == "matmul_v2") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
op->SetAttr("trans_x", true);
op->SetAttr("trans_y", true);
}
}
ProgramDesc BuildProgramDesc() {
ProgramDesc BuildProgramDesc(const std::string &op_name) {
ProgramDesc prog;
for (auto &v : std::initializer_list<std::string>(
{"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
......@@ -52,7 +58,7 @@ ProgramDesc BuildProgramDesc() {
var->SetType(proto::VarType::SELECTED_ROWS);
}
SetOp(&prog, "matmul", {"a1", "a2"}, {"b"});
SetOp(&prog, op_name, {"a1", "a2"}, {"b"});
SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
SetOp(&prog, "fc", {"d"}, {"e"});
......@@ -60,13 +66,13 @@ ProgramDesc BuildProgramDesc() {
return prog;
}
void MainTest(const ProgramDesc &prog) {
void MainTest(const ProgramDesc &prog, const std::string &op_name) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num = graph->Nodes().size();
auto pass =
PassRegistry::Instance().Get("matmul_transpose_reshape_fuse_pass");
PassRegistry::Instance().Get(op_name + "_transpose_reshape_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
......@@ -75,7 +81,7 @@ void MainTest(const ProgramDesc &prog) {
for (auto *node : graph->Nodes()) {
if (node->IsOp()) {
auto *op = node->Op();
if (op->Type() == "matmul") {
if (op->Type() == op_name) {
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
std::vector<int>({4, 5, 6}));
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
......@@ -85,12 +91,18 @@ void MainTest(const ProgramDesc &prog) {
}
}
TEST(MatmulTransposeReshapeFusePass, matmul_inputs) {
auto prog = BuildProgramDesc();
MainTest(prog);
TEST(MatmulTransposeReshapeFusePass, matmul_fuse_pass) {
auto prog = BuildProgramDesc("matmul");
MainTest(prog, "matmul");
}
TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
auto prog = BuildProgramDesc("matmul_v2");
MainTest(prog, "matmul_v2");
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(matmul_transpose_reshape_fuse_pass);
USE_PASS(matmul_v2_transpose_reshape_fuse_pass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/matmul_v2_transpose_reshape_fuse_pass.h"
#include <vector>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
MatmulV2TransposeReshapeMKLDNNPass::MatmulV2TransposeReshapeMKLDNNPass() {
op_name_ = "matmul_v2";
AddOpCompat(OpCompat(op_name_))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("transpose2"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
AddOpCompat(OpCompat("reshape2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Shape")
.IsTensor()
.IsOptional()
.End()
.AddInput("ShapeTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsTensor()
.End()
.AddAttr("shape")
.IsType<std::vector<int>>()
.End();
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_v2_transpose_reshape_fuse_pass,
paddle::framework::ir::MatmulV2TransposeReshapeMKLDNNPass);
REGISTER_PASS_CAPABILITY(matmul_v2_transpose_reshape_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("transpose2", 0)
.EQ("reshape2", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
class MatmulV2TransposeReshapeMKLDNNPass
: public MatmulTransposeReshapeMKLDNNPass {
public:
MatmulV2TransposeReshapeMKLDNNPass();
virtual ~MatmulV2TransposeReshapeMKLDNNPass() {}
protected:
const std::string name_scope_{"matmul_v2_transpose_reshape_fuse"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -249,6 +249,7 @@ void CpuPassStrategy::EnableMKLDNN() {
"scale_matmul_fuse_pass", //
"reshape_transpose_matmul_mkldnn_fuse_pass", //
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
......
......@@ -39,4 +39,12 @@ extra {
name: "op_device"
type: STRING
}
attrs {
name: "fused_reshape_Out"
type: INTS
}
attrs {
name: "fused_transpose_Out"
type: INTS
}
}
......@@ -90,8 +90,62 @@ class MatMulV2Op : public framework::OperatorWithKernel {
new_dims.push_back(1);
}
auto out_dims = framework::make_ddim(new_dims);
ctx->SetOutputDim("Out", out_dims);
auto ddim_out = framework::make_ddim(new_dims);
#ifdef PADDLE_WITH_MKLDNN
// if mkldnn matmul_v2+transpose+reshape fuse activated
auto reshape_out = ctx->Attrs().Get<std::vector<int>>("fused_reshape_Out");
auto transpose_out =
ctx->Attrs().Get<std::vector<int>>("fused_transpose_Out");
if (!reshape_out.empty() && !transpose_out.empty()) {
auto reshape_out_size = reshape_out.size();
auto transpose_out_size = transpose_out.size();
PADDLE_ENFORCE_EQ(transpose_out_size, 4,
platform::errors::InvalidArgument(
"transpose_out supported rank is 4, "
"received %d",
transpose_out_size));
const std::vector<int> supported_axis{0, 2, 1, 3};
const bool supported_transpose_axis = std::equal(
transpose_out.begin(), transpose_out.end(), supported_axis.begin());
PADDLE_ENFORCE_EQ(
supported_transpose_axis, true,
platform::errors::InvalidArgument(
"supported transpose axis for the fuse are {0, 2, 1, 3}"));
PADDLE_ENFORCE_EQ(
reshape_out_size, 3,
platform::errors::InvalidArgument("reshape_out supported rank is 3, "
"received %d",
reshape_out_size));
auto it = std::find(reshape_out.begin(), reshape_out.end(), -1);
// if "-1" is present then one of reshape dims must be infered
if (it != reshape_out.end()) {
int index = std::distance(reshape_out.begin(), it);
auto ddim_out_vec = framework::vectorize(ddim_out);
int ddim_out_product =
std::accumulate(ddim_out_vec.begin(), ddim_out_vec.end(), 1,
std::multiplies<int>());
int reshape_out_product = std::accumulate(
reshape_out.begin(), reshape_out.end(), -1, std::multiplies<int>());
reshape_out[index] = ddim_out_product / reshape_out_product;
}
framework::DDim shape_out =
ddim_out.transpose(transpose_out).reshape(reshape_out);
ctx->SetOutputDim("Out", shape_out);
} else {
ctx->SetOutputDim("Out", ddim_out);
}
#else
ctx->SetOutputDim("Out", ddim_out);
#endif
ctx->ShareLoD("X", /* --> */ "Out");
}
......@@ -139,6 +193,18 @@ class MatMulV2OpMaker : public framework::OpProtoAndCheckerMaker {
"Set true to transpose the last two dimensions of Y before "
"doing multiplication")
.SetDefault(false);
AddAttr<std::vector<int>>(
"fused_reshape_Out",
R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, "
"it's a shape atribute of fused reshape for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<std::vector<int>>(
"fused_transpose_Out",
R"DOC(When MKLDNN matmul_v2_transpose_reshape fuse activated, "
"it's a axis atribute of fused transpose for `Out` output.)DOC")
.SetDefault({})
.AsExtra();
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
......
......@@ -36,7 +36,8 @@ class MatMulV2MKLDNNHandler
MatMulV2MKLDNNHandler(const mkldnn::engine engine,
paddle::platform::Place cpu_place,
const std::vector<int64_t>& x_org_dims, bool trans_x,
const std::vector<int64_t>& y_org_dims, bool trans_y)
const std::vector<int64_t>& y_org_dims, bool trans_y,
bool is_output_fused)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::matmul>(engine,
cpu_place) {
// M X K * K X N
......@@ -86,6 +87,10 @@ class MatMulV2MKLDNNHandler
out_strides[i] = out_ddims[i + 1] * out_strides[i + 1];
}
if (is_output_fused) {
out_strides = FakeTransposeStrides(out_ddims);
}
auto x_md = memory::desc(x_dims, MKLDNNGetDataType<T>(), x_strides);
auto y_md = memory::desc(y_dims, MKLDNNGetDataType<T>(), y_strides);
auto out_md = memory::desc(out_ddims, MKLDNNGetDataType<T>(), out_strides);
......@@ -93,6 +98,24 @@ class MatMulV2MKLDNNHandler
this->AcquireForwardPrimitiveDescriptor(x_md, y_md, out_md);
}
std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& matmul_out_dims) const {
// fuse matmul_v2 + transpose + reshape guarantees that output is 4D and
// transpose axis are: {0, 2, 1, 3}
std::vector<int64_t> transpose_axis = {0, 2, 1, 3};
std::vector<int64_t> fake_strides(transpose_axis.size());
int ndims = static_cast<int>(transpose_axis.size());
int total_stride = 1;
for (int i = ndims - 1; i >= 0; --i) {
fake_strides[transpose_axis[i]] = total_stride;
total_stride *= matmul_out_dims[transpose_axis[i]];
}
return fake_strides;
}
std::shared_ptr<memory> AcquireWeightsMemory(const Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->weights_desc(),
......@@ -116,7 +139,8 @@ class MatMulV2MKLDNNKernel
bool trans_y, Tensor* out, std::vector<int64_t>& out_dims,
int execution_number = 0) const {
MatMulV2MKLDNNHandler<T> handler(onednn_engine, ctx.GetPlace(), x_dims,
trans_x, y_dims, trans_y);
trans_x, y_dims, trans_y,
IsOutputFused(ctx));
const auto src_memory_p = handler.AcquireSrcMemory(x);
const auto weights_memory_p = handler.AcquireWeightsMemory(y);
......@@ -133,9 +157,10 @@ class MatMulV2MKLDNNKernel
matmul_p->execute(astream, matmul_args);
astream.wait();
auto format = paddle::platform::MKLDNNFormatForSize(
out->dims().size(), dnnl::memory::format_tag::nchw);
out->set_layout(paddle::framework::DataLayout::kMKLDNN);
out->set_format(
GetMKLDNNFormat(dst_memory_p->get_desc().reshape(out_dims)));
out->set_format(format);
}
private:
......@@ -166,7 +191,8 @@ class MatMulV2MKLDNNKernel
}
}
if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2) {
if ((y_dims.size() == x_dims.size()) && y_dims.size() > 2 &&
!IsOutputFused(ctx)) {
for (size_t i = 0; i < x_dims.size() - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i] == y_dims[i] || x_dims[i] == 1 || y_dims[i] == 1, true,
......@@ -181,6 +207,13 @@ class MatMulV2MKLDNNKernel
}
}
bool IsOutputFused(const ExecutionContext& ctx) const {
auto& fused_reshape_Out = ctx.Attr<std::vector<int>>("fused_reshape_Out");
auto& fused_transpose_Out =
ctx.Attr<std::vector<int>>("fused_transpose_Out");
return !fused_reshape_Out.empty() && !fused_transpose_Out.empty();
}
void RunKernel(const ExecutionContext& ctx) const {
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import PassVersionChecker
class TestMatmulV2OneDNNTransposeReshapeFusePass(InferencePassTest):
def setUp(self):
self.set_params()
self.tranpose_perm = [0, 2, 1, 3]
self.pass_name = 'matmul_v2_transpose_reshape_fuse_pass'
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=self.data_shape, dtype="float32")
weight = fluid.layers.create_parameter(
shape=self.weight_shape, dtype="float32")
matmul = paddle.matmul(
data,
weight,
transpose_x=self.transpose_x,
transpose_y=self.transpose_y)
transpose = fluid.layers.transpose(matmul, self.tranpose_perm)
reshape = fluid.layers.reshape(transpose, shape=self.reshape_shape)
self.fetch_list = [reshape]
self.enable_mkldnn = True
def set_params(self):
self.data_shape = [-1, 3, 100, 110]
self.weight_shape = [1, 3, 110, 100]
self.feeds = {
"data": np.random.random((1, 3, 100, 110)).astype("float32")
}
self.transpose_x = False
self.transpose_y = False
self.reshape_shape = [3, 100, 100]
def test_check_output(self):
use_gpu = False
self.check_output_with_option(use_gpu)
def test_pass_compatible(self):
self.assertTrue(PassVersionChecker.IsCompatible(self.pass_name))
class TestMatmulV2OneDNNTransposeReshapeFusePassDifferentDims(
TestMatmulV2OneDNNTransposeReshapeFusePass):
def set_params(self):
self.data_shape = [-1, 4, 100, 80]
self.weight_shape = [1, 4, 80, 100]
self.feeds = {
"data": np.random.random((1, 4, 100, 80)).astype("float32")
}
self.transpose_x = True
self.transpose_y = True
self.reshape_shape = [8, 40, 80]
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
......@@ -440,9 +440,11 @@ class TestMatMulOpTransposeReshapeEmptyFloat(OpTest):
self.reshape_out = []
self.out = np.matmul(self.x, self.y)
def setUp(self):
os.environ["DNNL_MAX_CPU_ISA"] = "AVX"
def set_op_type(self):
self.op_type = "matmul"
def setUp(self):
self.set_op_type()
self._cpu_only = True
self.use_mkldnn = True
self.init_data_type()
......
......@@ -23,6 +23,13 @@ import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.tests.unittests.mkldnn.test_matmul_mkldnn_op import (
TestMatMulOpTransposeReshapeEmptyFloat,
TestMatMulOpTransposeReshapeBasicFloat,
TestMatMulOpTransposeReshapeOtherDimFloat,
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException,
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException,
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException)
def reference_matmul(X, Y, transpose_x=False, transpose_y=False):
......@@ -390,6 +397,43 @@ create_bf16_test_class(TestMatMulV2MatrixXMatrix5DTranposeYOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix6Dx2DOneDNNOp)
create_bf16_test_class(TestMatMulV2MatrixXMatrix2Dx5DOneDNNOp)
class TestMatMulV2OpTransposeReshapeEmptyFloat(
TestMatMulOpTransposeReshapeEmptyFloat):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeBasicFloat(
TestMatMulOpTransposeReshapeBasicFloat):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeOtherDimFloat(
TestMatMulOpTransposeReshapeOtherDimFloat):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeTransposeAxisNotSupportedException(
TestMatMulOpTransposeReshapeTransposeAxisNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeRankOfReshapeNotSupportedException(
TestMatMulOpTransposeReshapeRankOfReshapeNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
class TestMatMulV2OpTransposeReshapeTransposeRankNotSupportedException(
TestMatMulOpTransposeReshapeTransposeRankNotSupportedException):
def set_op_type(self):
self.op_type = "matmul_v2"
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册