Unverified commit 02296977, authored by Sławomir Siwek, committed by GitHub

Extract fused_transpose op dedicated for oneDNN fuse passes (#50021)

* extract common methods to reuse

* add header for transpose ops

* fused_transpose

* Split big function

* transpose2 tests

* fused_transpose

* Apply extra attributes

* add pbtxt file

* update pbtxt

* Merge develop

* add more strict op compats

* code style

* remove mkldnn_data_type

* unify SetOutMemDescWithReshape2FuseSupport

* adjust quantize-dequantize for transpose

* remove appendact

* transpose2 quantization

* fix int8 tests

* adjust transpose_op to current develop

* delete fusion code from transpose_kernel

* add fused transpose to NHWC unittest

* change order
Parent 06cb6553
...@@ -979,7 +979,8 @@ PDNode *patterns::OperatorActivation::operator()( ...@@ -979,7 +979,8 @@ PDNode *patterns::OperatorActivation::operator()(
return activation_out; return activation_out;
} }
PDNode *patterns::QuantTranspose2::operator()() { PDNode *patterns::QuantTranspose::operator()(
const std::string &transpose_type) {
auto *quant_in = pattern->NewNode(quant_in_repr()) auto *quant_in = pattern->NewNode(quant_in_repr())
->AsInput() ->AsInput()
->assert_is_op_input("quantize", "Input"); ->assert_is_op_input("quantize", "Input");
...@@ -989,19 +990,20 @@ PDNode *patterns::QuantTranspose2::operator()() { ...@@ -989,19 +990,20 @@ PDNode *patterns::QuantTranspose2::operator()() {
->AsIntermediate() ->AsIntermediate()
->assert_has_n_outputs(1) ->assert_has_n_outputs(1)
->assert_is_op_output("quantize") ->assert_is_op_output("quantize")
->assert_is_op_input("transpose2", "X"); ->assert_is_op_input(transpose_type, "X");
auto *transpose2_op = auto *transpose_op =
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2"); pattern->NewNode(transpose_op_repr())->assert_is_op(transpose_type);
quant_op->LinksFrom({quant_in}).LinksTo({quant_out}); quant_op->LinksFrom({quant_in}).LinksTo({quant_out});
transpose2_op->LinksFrom({quant_out}); transpose_op->LinksFrom({quant_out});
return transpose2_op; return transpose_op;
} }
PDNode *patterns::Transpose2Dequant::operator()() { PDNode *patterns::TransposeDequant::operator()(
auto *transpose2_op = const std::string &transpose_type) {
pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2"); auto *transpose_op =
pattern->NewNode(transpose_op_repr())->assert_is_op(transpose_type);
auto dequant_in = pattern->NewNode(dequant_in_repr()) auto dequant_in = pattern->NewNode(dequant_in_repr())
->AsIntermediate() ->AsIntermediate()
->assert_has_n_inputs(1) ->assert_has_n_inputs(1)
...@@ -1012,7 +1014,7 @@ PDNode *patterns::Transpose2Dequant::operator()() { ...@@ -1012,7 +1014,7 @@ PDNode *patterns::Transpose2Dequant::operator()() {
->AsOutput() ->AsOutput()
->assert_is_op_output("dequantize", "Output"); ->assert_is_op_output("dequantize", "Output");
transpose2_op->LinksTo({dequant_in}); transpose_op->LinksTo({dequant_in});
dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out}); dequant_op->LinksFrom({dequant_in}).LinksTo({dequant_out});
return dequant_out; return dequant_out;
} }
......
...@@ -552,24 +552,24 @@ struct OperatorActivation : public PatternBase { ...@@ -552,24 +552,24 @@ struct OperatorActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out); PATTERN_DECL_NODE(activation_out);
}; };
struct QuantTranspose2 : public PatternBase { struct QuantTranspose : public PatternBase {
QuantTranspose2(PDPattern* pattern, const std::string& name_scope) QuantTranspose(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "quant_transpose2") {} : PatternBase(pattern, name_scope, "quant_transpose") {}
PDNode* operator()(); PDNode* operator()(const std::string& transpose_type);
PATTERN_DECL_NODE(quant_in); PATTERN_DECL_NODE(quant_in);
PATTERN_DECL_NODE(quant_op); PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out); PATTERN_DECL_NODE(quant_out);
PATTERN_DECL_NODE(transpose2_op); PATTERN_DECL_NODE(transpose_op);
}; };
struct Transpose2Dequant : public PatternBase { struct TransposeDequant : public PatternBase {
Transpose2Dequant(PDPattern* pattern, const std::string& name_scope) TransposeDequant(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "transpose2_dequant") {} : PatternBase(pattern, name_scope, "transpose_dequant") {}
PDNode* operator()(); PDNode* operator()(const std::string& transpose_type);
PATTERN_DECL_NODE(transpose2_op); PATTERN_DECL_NODE(transpose_op);
PATTERN_DECL_NODE(dequant_in); PATTERN_DECL_NODE(dequant_in);
PATTERN_DECL_NODE(dequant_op); PATTERN_DECL_NODE(dequant_op);
PATTERN_DECL_NODE(dequant_out); PATTERN_DECL_NODE(dequant_out);
......
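How the renamed, parameterized patterns are consumed can be seen in quant_transpose2_dequant_onednn_fuse_pass.cc further down in this commit; the fragment below is a condensed, non-compilable sketch of that call sequence (only the transpose_type argument is new):

// Sketch only -- mirrors the pass code later in this diff, not a standalone program.
GraphPatternDetector gpd;
patterns::QuantTranspose quant_transpose_pattern(gpd.mutable_pattern(), name_scope);
quant_transpose_pattern(transpose_type);  // "fused_transpose" or "transpose2"

auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, Graph *g) {
  GET_IR_NODE_FROM_SUBGRAPH(transpose_op, transpose_op, quant_transpose_pattern);
  // ... rewrite the matched quantize -> transpose subgraph ...
};
gpd(graph, handler);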
...@@ -492,6 +492,7 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const { ...@@ -492,6 +492,7 @@ void ComputePropagateScalesMkldnnPass::ApplyImpl(ir::Graph* graph) const {
FusePassBase::Init(pattern_name, graph); FusePassBase::Init(pattern_name, graph);
const std::unordered_set<std::string> scale_immutable_ops = { const std::unordered_set<std::string> scale_immutable_ops = {
"fused_transpose"
"transpose2", "transpose2",
"reshape2", "reshape2",
"pool2d", "pool2d",
......
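"Scale-immutable" here means the op only rearranges data, so the quantization scale computed for its input is reused unchanged for its output. A minimal standalone illustration of that propagation rule (variable names are made up; the real pass walks graph edges):

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, float> var_scales = {{"transpose_in", 0.043f}};
  // fused_transpose / transpose2 only permute elements, so the output
  // inherits the input scale instead of getting its own statistics.
  var_scales["transpose_out"] = var_scales["transpose_in"];
  std::cout << var_scales["transpose_out"] << '\n';  // prints 0.043
  return 0;
}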
...@@ -1305,7 +1305,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const { ...@@ -1305,7 +1305,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeMatmul(graph, false /* with_residual_data */); QuantizeMatmul(graph, false /* with_residual_data */);
QuantizeMatmul(graph, true /* with_residual_data */); QuantizeMatmul(graph, true /* with_residual_data */);
QuantizeImmutable(graph, "reshape2", "X"); QuantizeImmutable(graph, "reshape2", "X");
QuantizeImmutable(graph, "transpose2", "X"); QuantizeImmutable(graph, "fused_transpose", "X");
QuantizeImmutable(graph, "slice", "Input"); QuantizeImmutable(graph, "slice", "Input");
QuantizeImmutable(graph, "nearest_interp", "X"); QuantizeImmutable(graph, "nearest_interp", "X");
QuantizeImmutable(graph, "nearest_interp_v2", "X"); QuantizeImmutable(graph, "nearest_interp_v2", "X");
......
...@@ -59,8 +59,9 @@ void SetOp(ProgramDesc* prog, ...@@ -59,8 +59,9 @@ void SetOp(ProgramDesc* prog,
op->SetAttr("fuse_residual_connection", false); op->SetAttr("fuse_residual_connection", false);
} }
op->SetOutput("Output", {outputs[0]}); op->SetOutput("Output", {outputs[0]});
} else if (type == "pool2d" || type == "transpose2" || type == "reshape2" || } else if (type == "pool2d" || type == "fused_transpose" ||
type == "nearest_interp" || type == "nearest_interp_v2") { type == "reshape2" || type == "nearest_interp" ||
type == "nearest_interp_v2") {
op->SetInput("X", {inputs[0]}); op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]}); op->SetOutput("Out", {outputs[0]});
} else if (type == "slice") { } else if (type == "slice") {
...@@ -558,7 +559,7 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) { ...@@ -558,7 +559,7 @@ void TestImmutableOpWithManyOutputs(const std::string tested_op) {
} }
const std::vector<std::string> immutables = {"reshape2", const std::vector<std::string> immutables = {"reshape2",
"transpose2", "fused_transpose",
"slice", "slice",
"nearest_interp", "nearest_interp",
"nearest_interp_v2", "nearest_interp_v2",
......
...@@ -43,6 +43,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const { ...@@ -43,6 +43,7 @@ void CPUQuantizePlacementPass::ApplyImpl(ir::Graph* graph) const {
"pool2d", "pool2d",
"prior_box", "prior_box",
"reshape2", "reshape2",
"fused_transpose",
"transpose2", "transpose2",
"fusion_gru", "fusion_gru",
"fusion_lstm", "fusion_lstm",
......
...@@ -161,7 +161,8 @@ inline void ConvertToFusedOp(OpDesc* op) { ...@@ -161,7 +161,8 @@ inline void ConvertToFusedOp(OpDesc* op) {
{"depthwise_conv2d", "fused_conv2d"}, {"depthwise_conv2d", "fused_conv2d"},
{"matmul", "fused_matmul"}, {"matmul", "fused_matmul"},
{"matmul_v2", "fused_matmul"}, {"matmul_v2", "fused_matmul"},
{"softplus", "fused_softplus"}}; {"softplus", "fused_softplus"},
{"transpose2", "fused_transpose"}};
if (op->Type() == "matmul") { if (op->Type() == "matmul") {
op->SetAttr("trans_x", op->GetAttr("transpose_X")); op->SetAttr("trans_x", op->GetAttr("transpose_X"));
...@@ -173,6 +174,8 @@ inline void ConvertToFusedOp(OpDesc* op) { ...@@ -173,6 +174,8 @@ inline void ConvertToFusedOp(OpDesc* op) {
if (it != fused_ops.end()) { if (it != fused_ops.end()) {
op->SetType(it->second); op->SetType(it->second);
VLOG(3) << "Converted " << it->first << " to " << it->second; VLOG(3) << "Converted " << it->first << " to " << it->second;
} else {
VLOG(3) << "Fused op for " << op->Type() << " is not implemented yet.";
} }
} }
......
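The rename table above is the single place where a base op type is swapped for its fused counterpart during graph passes. A standalone sketch of that lookup (entries copied from the hunk above; the real helper works on an OpDesc and also remaps matmul attributes):

#include <iostream>
#include <string>
#include <unordered_map>

std::string ToFusedType(const std::string& op_type) {
  static const std::unordered_map<std::string, std::string> fused_ops = {
      {"depthwise_conv2d", "fused_conv2d"},
      {"matmul", "fused_matmul"},
      {"matmul_v2", "fused_matmul"},
      {"softplus", "fused_softplus"},
      {"transpose2", "fused_transpose"}};  // entry added by this commit
  auto it = fused_ops.find(op_type);
  return it != fused_ops.end() ? it->second : op_type;  // others pass through
}

int main() {
  std::cout << ToFusedType("transpose2") << '\n';  // fused_transpose
  std::cout << ToFusedType("relu") << '\n';        // relu (no fused variant yet)
  return 0;
}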
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
...@@ -28,7 +29,7 @@ void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const { ...@@ -28,7 +29,7 @@ void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const {
// THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUTS PLAIN MEMORY, F.E. // THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUTS PLAIN MEMORY, F.E.
// ABCD FOR 4D! BE AWARE OF THAT! // ABCD FOR 4D! BE AWARE OF THAT!
std::vector<std::pair<std::string, int>> ops_and_outputs = { std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"fc", 1}, {"transpose2", 2}}; {"fc", 1}, {"fused_transpose", 2}, {"transpose2", 2}};
for (const auto &op_and_outputs : ops_and_outputs) for (const auto &op_and_outputs : ops_and_outputs)
FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second); FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second);
...@@ -114,6 +115,7 @@ void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph, ...@@ -114,6 +115,7 @@ void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph,
return; return;
} }
ConvertToFusedOp(operator_op->Op());
operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape); operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape);
operator_op->Op()->SetOutput("Out", {reshape2_out->Name()}); operator_op->Op()->SetOutput("Out", {reshape2_out->Name()});
...@@ -140,5 +142,7 @@ REGISTER_PASS(operator_reshape2_onednn_fuse_pass, ...@@ -140,5 +142,7 @@ REGISTER_PASS(operator_reshape2_onednn_fuse_pass,
REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass) REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.GE("reshape2", 0) .EQ("fused_transpose", 0)
.GE("fc", 0)); .EQ("transpose2", 0)
.EQ("reshape2", 0)
.EQ("fc", 0));
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
...@@ -26,7 +27,7 @@ using string::PrettyLogDetail; ...@@ -26,7 +27,7 @@ using string::PrettyLogDetail;
void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const { void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::pair<std::string, int>> ops_and_outputs = { std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"transpose2", 2}, {"elementwise_mul", 1}}; {"fused_transpose", 2}, {"transpose2", 2}, {"elementwise_mul", 1}};
for (const auto &op_and_outputs : ops_and_outputs) for (const auto &op_and_outputs : ops_and_outputs)
FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second); FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second);
...@@ -89,6 +90,7 @@ void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2( ...@@ -89,6 +90,7 @@ void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2(
return; return;
} }
ConvertToFusedOp(operator_op->Op());
operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes); operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes);
operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()}); operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()});
...@@ -115,5 +117,6 @@ REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass, ...@@ -115,5 +117,6 @@ REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass,
REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass) REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.GE("unsqueeze2", 0) .EQ("unsqueeze2", 0)
.GE("transpose2", 0)); .EQ("fused_transpose", 0)
.EQ("transpose2", 0));
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/quant_transpose2_dequant_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h" #include "paddle/fluid/string/pretty_log.h"
...@@ -21,15 +22,15 @@ namespace framework { ...@@ -21,15 +22,15 @@ namespace framework {
namespace ir { namespace ir {
void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2( void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
Graph *graph) const { Graph *graph, const std::string &transpose_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope, graph); FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::QuantTranspose2 quant_transpose2_pattern(gpd.mutable_pattern(), patterns::QuantTranspose quant_transpose2_pattern(gpd.mutable_pattern(),
name_scope); name_scope);
quant_transpose2_pattern(); quant_transpose2_pattern(transpose_type);
int found_patterns_count = 0; int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
...@@ -42,10 +43,10 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2( ...@@ -42,10 +43,10 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_transpose2_pattern); GET_IR_NODE_FROM_SUBGRAPH(quant_op, quant_op, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_transpose2_pattern); GET_IR_NODE_FROM_SUBGRAPH(quant_out, quant_out, quant_transpose2_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, quant_transpose2_pattern); transpose_op, transpose_op, quant_transpose2_pattern);
if (!transpose2_op->Op()->HasAttr("use_mkldnn") || if (!transpose_op->Op()->HasAttr("use_mkldnn") ||
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) { !(PADDLE_GET_CONST(bool, transpose_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4) VLOG(4)
<< "Only oneDNN version of transpose2 can be fused with quantize."; << "Only oneDNN version of transpose2 can be fused with quantize.";
return; return;
...@@ -60,8 +61,9 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2( ...@@ -60,8 +61,9 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Shift")) ? PADDLE_GET_CONST(float, quant_op->Op()->GetAttr("Shift"))
: 0; : 0;
transpose2_op->Op()->SetAttr("scale", scale); ConvertToFusedOp(transpose_op->Op());
transpose2_op->Op()->SetAttr("shift", shift); transpose_op->Op()->SetAttr("scale", scale);
transpose_op->Op()->SetAttr("shift", shift);
bool is_negative_output = bool is_negative_output =
quant_op->Op()->HasAttr("is_negative_input") quant_op->Op()->HasAttr("is_negative_input")
...@@ -81,11 +83,11 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2( ...@@ -81,11 +83,11 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
} else { } else {
output_dtype = "uint8"; output_dtype = "uint8";
} }
transpose2_op->Op()->SetAttr("output_data_type", output_dtype); transpose_op->Op()->SetAttr("output_data_type", output_dtype);
transpose2_op->Op()->SetInput("X", transpose_op->Op()->SetInput("X",
std::vector<std::string>({quant_in->Name()})); std::vector<std::string>({quant_in->Name()}));
IR_NODE_LINK_TO(quant_in, transpose2_op); IR_NODE_LINK_TO(quant_in, transpose_op);
GraphSafeRemoveNodes(graph, {quant_op, quant_out}); GraphSafeRemoveNodes(graph, {quant_op, quant_out});
found_patterns_count++; found_patterns_count++;
}; };
...@@ -98,15 +100,15 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2( ...@@ -98,15 +100,15 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseQuantizeTranspose2(
} }
void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
Graph *graph) const { Graph *graph, const std::string &transpose_type) const {
PADDLE_ENFORCE_NOT_NULL( PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr.")); graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(name_scope, graph); FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::Transpose2Dequant transpose2_dequant_pattern(gpd.mutable_pattern(), patterns::TransposeDequant transpose2_dequant_pattern(gpd.mutable_pattern(),
name_scope); name_scope);
transpose2_dequant_pattern(); transpose2_dequant_pattern(transpose_type);
int found_patterns_count = 0; int found_patterns_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
...@@ -116,7 +118,7 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( ...@@ -116,7 +118,7 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
return; return;
} }
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
transpose2_op, transpose2_op, transpose2_dequant_pattern); transpose_op, transpose_op, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
dequant_in, dequant_in, transpose2_dequant_pattern); dequant_in, dequant_in, transpose2_dequant_pattern);
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
...@@ -124,8 +126,8 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( ...@@ -124,8 +126,8 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
GET_IR_NODE_FROM_SUBGRAPH( GET_IR_NODE_FROM_SUBGRAPH(
dequant_out, dequant_out, transpose2_dequant_pattern); dequant_out, dequant_out, transpose2_dequant_pattern);
if (!transpose2_op->Op()->HasAttr("use_mkldnn") || if (!transpose_op->Op()->HasAttr("use_mkldnn") ||
!(PADDLE_GET_CONST(bool, transpose2_op->Op()->GetAttr("use_mkldnn")))) { !(PADDLE_GET_CONST(bool, transpose_op->Op()->GetAttr("use_mkldnn")))) {
VLOG(4) VLOG(4)
<< "Only oneDNN version of transpose2 can be fused with dequantize."; << "Only oneDNN version of transpose2 can be fused with dequantize.";
return; return;
...@@ -141,13 +143,14 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( ...@@ -141,13 +143,14 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift")) ? PADDLE_GET_CONST(float, dequant_op->Op()->GetAttr("Shift"))
: 0; : 0;
transpose2_op->Op()->SetAttr("scale", reorder_scale); ConvertToFusedOp(transpose_op->Op());
transpose2_op->Op()->SetAttr("shift", shift); transpose_op->Op()->SetAttr("scale", reorder_scale);
transpose2_op->Op()->SetAttr("output_data_type", std::string("fp32")); transpose_op->Op()->SetAttr("shift", shift);
transpose2_op->Op()->SetOutput( transpose_op->Op()->SetAttr("output_data_type", std::string("fp32"));
transpose_op->Op()->SetOutput(
"Out", std::vector<std::string>({dequant_out->Name()})); "Out", std::vector<std::string>({dequant_out->Name()}));
IR_NODE_LINK_TO(transpose2_op, dequant_out); IR_NODE_LINK_TO(transpose_op, dequant_out);
GraphSafeRemoveNodes(graph, {dequant_in, dequant_op}); GraphSafeRemoveNodes(graph, {dequant_in, dequant_op});
found_patterns_count++; found_patterns_count++;
}; };
...@@ -161,8 +164,10 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize( ...@@ -161,8 +164,10 @@ void FuseQuantTranspose2DequantOneDNNPass::FuseTranspose2Dequantize(
} }
void FuseQuantTranspose2DequantOneDNNPass::ApplyImpl(Graph *graph) const { void FuseQuantTranspose2DequantOneDNNPass::ApplyImpl(Graph *graph) const {
FuseQuantizeTranspose2(graph); FuseQuantizeTranspose2(graph, "fused_transpose");
FuseTranspose2Dequantize(graph); FuseTranspose2Dequantize(graph, "fused_transpose");
FuseQuantizeTranspose2(graph, "transpose2");
FuseTranspose2Dequantize(graph, "transpose2");
} }
FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() { FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() {
...@@ -180,6 +185,21 @@ FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() { ...@@ -180,6 +185,21 @@ FuseQuantTranspose2DequantOneDNNPass::FuseQuantTranspose2DequantOneDNNPass() {
.AddAttr("axis") .AddAttr("axis")
.IsType<std::vector<int>>() .IsType<std::vector<int>>()
.End(); .End();
AddOpCompat(OpCompat("fused_transpose"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddOutput("XShape")
.IsOptional()
.IsTensor()
.End()
.AddAttr("axis")
.IsType<std::vector<int>>()
.End();
} }
} // namespace ir } // namespace ir
......
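For readers less familiar with the quantize semantics being folded in: the scale and shift carried onto fused_transpose drive a reorder that roughly computes q = round(x * scale) + shift, saturated to the requested output_data_type. A hedged, standalone illustration with made-up values (not the pass or kernel code):

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  const float scale = 127.0f;  // would come from the removed quantize op's "Scale"
  const float shift = 128.0f;  // a non-zero shift implies unsigned ("uint8") output
  const float x = -0.5f;       // an fp32 input element

  float q = std::round(x * scale) + shift;
  auto u8 = static_cast<uint8_t>(std::min(std::max(q, 0.0f), 255.0f));
  std::cout << static_cast<int>(u8) << '\n';  // prints 64
  return 0;
}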
...@@ -28,8 +28,10 @@ class FuseQuantTranspose2DequantOneDNNPass : public FusePassBase { ...@@ -28,8 +28,10 @@ class FuseQuantTranspose2DequantOneDNNPass : public FusePassBase {
protected: protected:
void ApplyImpl(Graph *graph) const override; void ApplyImpl(Graph *graph) const override;
void FuseQuantizeTranspose2(Graph *graph) const; void FuseQuantizeTranspose2(Graph *graph,
void FuseTranspose2Dequantize(Graph *graph) const; const std::string &transpose_type) const;
void FuseTranspose2Dequantize(Graph *graph,
const std::string &transpose_type) const;
private: private:
std::string name_scope = "quant_transpose2_dequant_onednn_fuse_pass"; std::string name_scope = "quant_transpose2_dequant_onednn_fuse_pass";
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h" #include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h" #include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/utils/string/pretty_log.h" #include "paddle/utils/string/pretty_log.h"
...@@ -59,6 +60,8 @@ void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const { ...@@ -59,6 +60,8 @@ void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<int> squeeze2_axes = std::vector<int> squeeze2_axes =
PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes")); PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
ConvertToFusedOp(transpose2_op->Op());
transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes); transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()}); transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
...@@ -83,5 +86,5 @@ REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass, ...@@ -83,5 +86,5 @@ REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass) REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
.AddCombination( .AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination() paddle::framework::compatible::OpVersionComparatorCombination()
.GE("squeeze2", 0) .EQ("squeeze2", 0)
.GE("transpose2", 0)); .EQ("transpose2", 0));
...@@ -129,7 +129,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs( ...@@ -129,7 +129,8 @@ void AnalysisPredictor::MkldnnQuantizer::CalculateScalesForOpOutputs(
is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6"); is_unsigned = (fuse_activation == "relu" || fuse_activation == "relu6");
} else if (op->Type() == "relu") { } else if (op->Type() == "relu") {
is_unsigned = true; is_unsigned = true;
} else if (op->Type() == "transpose2" || op->Type() == "reshape2" || } else if (op->Type() == "transpose2" ||
op->Type() == "fused_transpose" || op->Type() == "reshape2" ||
op->Type() == "pool2d" || op->Type() == "nearest_interp" || op->Type() == "pool2d" || op->Type() == "nearest_interp" ||
op->Type() == "nearest_interp_v2" || op->Type() == "split") { op->Type() == "nearest_interp_v2" || op->Type() == "split") {
auto input_var_name = op->Input("X")[0]; auto input_var_name = op->Input("X")[0];
......
...@@ -47,6 +47,8 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() { ...@@ -47,6 +47,8 @@ MkldnnQuantizerConfig::MkldnnQuantizerConfig() {
// input data and assign to Quantize and Dequantize scale. // input data and assign to Quantize and Dequantize scale.
rules_["transpose2"]["X"] = ScaleAlgo::KL; rules_["transpose2"]["X"] = ScaleAlgo::KL;
rules_["transpose2"]["Out"] = ScaleAlgo::NONE; rules_["transpose2"]["Out"] = ScaleAlgo::NONE;
rules_["fused_transpose"]["X"] = ScaleAlgo::KL;
rules_["fused_transpose"]["Out"] = ScaleAlgo::NONE;
rules_["slice"]["Input"] = ScaleAlgo::KL; rules_["slice"]["Input"] = ScaleAlgo::KL;
rules_["slice"]["Out"] = ScaleAlgo::NONE; rules_["slice"]["Out"] = ScaleAlgo::NONE;
......
...@@ -280,8 +280,12 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) { ...@@ -280,8 +280,12 @@ TEST(Analyzer_int8_mobilenet_ssd, quantization) {
if (FLAGS_enable_mkldnn) { if (FLAGS_enable_mkldnn) {
q_cfg.EnableMkldnnQuantizer(); q_cfg.EnableMkldnnQuantizer();
q_cfg.mkldnn_quantizer_config(); q_cfg.mkldnn_quantizer_config();
std::unordered_set<std::string> quantize_operators( std::unordered_set<std::string> quantize_operators({"conv2d",
{"conv2d", "depthwise_conv2d", "prior_box", "transpose2", "reshape2"}); "depthwise_conv2d",
"prior_box",
"fused_transpose",
"transpose2",
"reshape2"});
q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators); q_cfg.mkldnn_quantizer_config()->SetEnabledOpTypes(quantize_operators);
q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data); q_cfg.mkldnn_quantizer_config()->SetWarmupData(warmup_data);
q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize( q_cfg.mkldnn_quantizer_config()->SetWarmupBatchSize(
......
type: "fused_transpose"
def {
inputs {
name: "X"
}
outputs {
name: "Out"
}
outputs {
name: "XShape"
}
attrs {
name: "axis"
type: INTS
}
}
extra {
attrs{
name: "fused_squeeze2_axes"
type: INTS
}
attrs{
name: "fused_unsqueeze2_axes"
type: INTS
}
attrs{
name: "fused_reshape2_shape"
type: INTS
}
attrs{
name: "scale"
type: FLOAT
}
attrs{
name: "shift"
type: FLOAT
}
attrs{
name: "output_data_type"
type: STRING
}
}
...@@ -20,10 +20,6 @@ extra { ...@@ -20,10 +20,6 @@ extra {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs { attrs {
name: "mkldnn_data_type" name: "mkldnn_data_type"
type: STRING type: STRING
...@@ -49,4 +45,3 @@ extra { ...@@ -49,4 +45,3 @@ extra {
type: STRING type: STRING
} }
} }
...@@ -23,10 +23,6 @@ extra { ...@@ -23,10 +23,6 @@ extra {
name: "use_mkldnn" name: "use_mkldnn"
type: BOOLEAN type: BOOLEAN
} }
attrs {
name: "use_quantizer"
type: BOOLEAN
}
attrs { attrs {
name: "mkldnn_data_type" name: "mkldnn_data_type"
type: STRING type: STRING
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
class FusedTransposeOpMaker : public Transpose2OpMaker {
protected:
void Apply() override {
AddAttr<std::vector<int>>("fused_squeeze2_axes",
"Axes from squeeze2 operator obtained from "
"squeeze2_transpose2_onednn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_unsqueeze2_axes",
"Axes from unsqueeze2 operator obtained from "
"operator_unsqueeze2_onednn_fuse_pass")
.SetDefault({});
AddAttr<std::vector<int>>("fused_reshape2_shape",
"Shape from reshape2 operator obtained from "
"operator_reshape2_onednn_fuse_pass")
.SetDefault({});
AddAttr<float>("scale",
"Obtained from quant_transpose2_dequant_onednn_fuse_pass")
.SetDefault(1.0f);
AddAttr<float>("shift",
"Obtained from quant_transpose2_dequant_onednn_fuse_pass")
.SetDefault(0.0f);
AddAttr<std::string>(
"output_data_type",
"Obtained from quant_transpose2_dequant_onednn_fuse_pass")
.SetDefault("");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
fused_transpose,
ops::Transpose2Op,
ops::FusedTransposeOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
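Since every new attribute defaults to a no-op value, a plain fused_transpose behaves like transpose2. A hedged sketch of constructing one directly, modeled on the OpRegistry::CreateOp call used in the NHWC unit test later in this commit (the extra attributes shown are optional and purely illustrative):

// Sketch only -- framework headers, scope and variable setup are omitted.
std::vector<int> axis = {0, 2, 3, 1};
auto op = paddle::framework::OpRegistry::CreateOp(
    "fused_transpose",
    {{"X", {"x"}}},
    {{"Out", {"out"}}},
    {{"axis", {axis}},
     {"use_mkldnn", {true}},
     // set by quant_transpose2_dequant_onednn_fuse_pass when quantize is folded in:
     {"scale", {127.0f}},
     {"shift", {128.0f}},
     {"output_data_type", {std::string("uint8")}}});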
...@@ -11,6 +11,7 @@ cc_test_old( ...@@ -11,6 +11,7 @@ cc_test_old(
generated_op generated_op
pooling pooling
transpose_op transpose_op
fused_transpose_op
scope scope
device_context device_context
enforce enforce
......
...@@ -32,6 +32,8 @@ USE_OP_ITSELF(relu); ...@@ -32,6 +32,8 @@ USE_OP_ITSELF(relu);
PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN); PD_DECLARE_KERNEL(relu, OneDNN, ONEDNN);
USE_OP_ITSELF(transpose); USE_OP_ITSELF(transpose);
PD_DECLARE_KERNEL(transpose, OneDNN, ONEDNN); PD_DECLARE_KERNEL(transpose, OneDNN, ONEDNN);
USE_OP_ITSELF(fused_transpose);
PD_DECLARE_KERNEL(fused_transpose, OneDNN, ONEDNN);
USE_OP_ITSELF(shape); USE_OP_ITSELF(shape);
PD_DECLARE_KERNEL(shape, OneDNN, ONEDNN); PD_DECLARE_KERNEL(shape, OneDNN, ONEDNN);
USE_OP_ITSELF(crop); USE_OP_ITSELF(crop);
...@@ -49,7 +51,7 @@ struct InputVars { ...@@ -49,7 +51,7 @@ struct InputVars {
phi::DenseTensor *tensor; phi::DenseTensor *tensor;
}; };
TEST(test_pool2d_transpose_nhwc, cpu_place) { void Test_Pool2d_Transpose_NHWC(const std::string &transpose_type) {
framework::DDim dims({1, 4, 8, 512}); // NHWC shape framework::DDim dims({1, 4, 8, 512}); // NHWC shape
framework::DDim expected_dims({1, 7, 512, 3}); // NHWC expected shape framework::DDim expected_dims({1, 7, 512, 3}); // NHWC expected shape
phi::CPUPlace p; phi::CPUPlace p;
...@@ -89,7 +91,7 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) { ...@@ -89,7 +91,7 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) {
axis[2] = 3; axis[2] = 3;
axis[3] = 1; axis[3] = 1;
auto op_transpose = framework::OpRegistry::CreateOp( auto op_transpose = framework::OpRegistry::CreateOp(
"transpose", transpose_type,
{{"X", {"y"}}}, {{"X", {"y"}}},
{{"Out", {"z"}}}, {{"Out", {"z"}}},
{{"axis", {axis}}, {"use_mkldnn", {true}}}); {{"axis", {axis}}, {"use_mkldnn", {true}}});
...@@ -105,6 +107,11 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) { ...@@ -105,6 +107,11 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) {
"Computed shape does not match expected shape")); "Computed shape does not match expected shape"));
} }
TEST(test_pool2d_transpose_nhwc, cpu_place) {
Test_Pool2d_Transpose_NHWC({"transpose"});
Test_Pool2d_Transpose_NHWC({"fused_transpose"});
}
TEST(test_pool2d_relu_relu_nhwc, cpu_place) { TEST(test_pool2d_relu_relu_nhwc, cpu_place) {
framework::DDim dims({1, 4, 8, 512}); // NHWC shape framework::DDim dims({1, 4, 8, 512}); // NHWC shape
framework::DDim expected_dims({1, 512, 3, 7}); // NCHW expected shape framework::DDim expected_dims({1, 512, 3, 7}); // NCHW expected shape
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using phi::DataLayout;
using phi::OneDNNContext; using phi::OneDNNContext;
template <typename T> template <typename T>
...@@ -28,7 +27,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -28,7 +27,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true, true,
paddle::platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Operator DNNL Transpose must use CPUPlace")); "Operator DNNL Transpose must use CPUPlace"));
auto& dev_ctx = ctx.template device_context<OneDNNContext>(); auto& dev_ctx = ctx.template device_context<OneDNNContext>();
const auto& dnnl_engine = dev_ctx.GetEngine(); const auto& dnnl_engine = dev_ctx.GetEngine();
...@@ -58,8 +57,6 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -58,8 +57,6 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dnnl::memory::desc(x_vec_dims, dnnl::memory::desc(x_vec_dims,
x->mem_desc().data_type(), x->mem_desc().data_type(),
phi::funcs::GetPlainOneDNNFormat(x_vec_dims.size())); phi::funcs::GetPlainOneDNNFormat(x_vec_dims.size()));
// a trick is used here to fake transpose of out_md, so later it will be
// "untransposed", leaving output data in plain format tag
auto dst_strides = auto dst_strides =
phi::funcs::FakeTransposeStrides(dst_md.dims(), transpose_axis); phi::funcs::FakeTransposeStrides(dst_md.dims(), transpose_axis);
...@@ -88,7 +85,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -88,7 +85,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()),
true, true,
paddle::platform::errors::PreconditionNotMet( phi::errors::PreconditionNotMet(
"Operator DNNL TransposeGrad must use CPUPlace")); "Operator DNNL TransposeGrad must use CPUPlace"));
const auto* dout = const auto* dout =
......
...@@ -109,7 +109,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet> ...@@ -109,7 +109,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"Scale_weights", ExtraAttrProperty::ONEDNN}, {"Scale_weights", ExtraAttrProperty::ONEDNN},
{"x_data_format", ExtraAttrProperty::ONEDNN}, {"x_data_format", ExtraAttrProperty::ONEDNN},
{"y_data_format", ExtraAttrProperty::ONEDNN}, {"y_data_format", ExtraAttrProperty::ONEDNN},
{"fused_squeeze2_axes", ExtraAttrProperty::ONEDNN},
{"fused_unsqueeze2_axes", ExtraAttrProperty::ONEDNN}, {"fused_unsqueeze2_axes", ExtraAttrProperty::ONEDNN},
{"fused_reshape2_shape", ExtraAttrProperty::ONEDNN}, {"fused_reshape2_shape", ExtraAttrProperty::ONEDNN},
// ONEDNN pass dedicated attributes // ONEDNN pass dedicated attributes
...@@ -117,9 +116,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet> ...@@ -117,9 +116,6 @@ const std::unordered_map<std::string, ExtraAttrPropertySet>
{"Bias_scales", ExtraAttrProperty::ONEDNN}, {"Bias_scales", ExtraAttrProperty::ONEDNN},
{"Output_shift_scale", ExtraAttrProperty::ONEDNN}, {"Output_shift_scale", ExtraAttrProperty::ONEDNN},
{"Sum_scale", ExtraAttrProperty::ONEDNN}, {"Sum_scale", ExtraAttrProperty::ONEDNN},
{"scale", ExtraAttrProperty::ONEDNN},
{"shift", ExtraAttrProperty::ONEDNN},
{"output_data_type", ExtraAttrProperty::ONEDNN},
// GPUDNN dedicated attributes // GPUDNN dedicated attributes
{"exhaustive_search", ExtraAttrProperty::GPUDNN}, {"exhaustive_search", ExtraAttrProperty::GPUDNN},
{"fuse_relu_before_depthwise_conv", ExtraAttrProperty::GPUDNN}, {"fuse_relu_before_depthwise_conv", ExtraAttrProperty::GPUDNN},
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -16,34 +16,19 @@ limitations under the License. */ ...@@ -16,34 +16,19 @@ limitations under the License. */
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/operators/transpose_op.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class TransposeOp : public framework::OperatorWithKernel { phi::KernelKey TransposeOp::GetExpectedKernelType(
public: const framework::ExecutionContext &ctx) const {
using framework::OperatorWithKernel::OperatorWithKernel; auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
auto &data_format = ctx.Attr<std::string>("data_format");
protected: phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
phi::KernelKey GetExpectedKernelType( return phi::KernelKey(
const framework::ExecutionContext &ctx) const override { ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X"); }
auto &data_format = ctx.Attr<std::string>("data_format");
phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
return phi::KernelKey(
ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
}
};
class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
...@@ -69,19 +54,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -69,19 +54,12 @@ class TransposeOpMaker : public framework::OpProtoAndCheckerMaker {
"the input will be transformed automatically. ") "the input will be transformed automatically. ")
.SetDefault("AnyLayout") .SetDefault("AnyLayout")
.AsExtra(); .AsExtra();
AddAttr<bool>(
"use_quantizer",
"(bool, default false) "
"This parameter is no longer used. Use 'mkldnn_data_type' instead.")
.SetDefault(false)
.AsExtra();
AddAttr<std::string>( AddAttr<std::string>(
"mkldnn_data_type", "mkldnn_data_type",
"(string, default \"float32\"). Data type of mkldnn kernel") "(string, default \"float32\"). Data type of mkldnn kernel")
.SetDefault("float32") .SetDefault("float32")
.InEnum({"float32", "int8", "bfloat16"}) .InEnum({"float32", "int8", "bfloat16"})
.AsExtra(); .AsExtra();
/* int8 parameters */
AddComment(R"DOC( AddComment(R"DOC(
Transpose Operator. Transpose Operator.
...@@ -129,65 +107,45 @@ class TransposeOpGrad : public framework::OperatorWithKernel { ...@@ -129,65 +107,45 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
} }
}; };
// FIXME(zcd): transpose2 adds an intermediate output(XShape) based on void Transpose2Op::InferShape(framework::InferShapeContext *ctx) const {
// transpose, the XShape is used to carry the shape and lod of X which using CompatMetaTensor = framework::CompatMetaTensor;
// will be used in transpose_grad, in this way, the framework can reuse CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
// the memory of X immediately the transpose2_op is finished. CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
// Considering compatibility issues, we could not fix transpose2_op std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
class Transpose2Op : public TransposeOp { phi::TransposeInferMeta(x, axis, &out);
public:
Transpose2Op(const std::string &type, if (!ctx->HasOutput("XShape")) return;
const framework::VariableNameMap &inputs, const auto &in_dims = ctx->GetInputDim("X");
const framework::VariableNameMap &outputs, std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
const framework::AttributeMap &attrs) x_shape_dim[0] = 0;
: TransposeOp(type, inputs, outputs, attrs) {} for (int i = 0; i < in_dims.size(); ++i) {
x_shape_dim[i + 1] = in_dims[i];
void InferShape(framework::InferShapeContext *ctx) const override {
using CompatMetaTensor = framework::CompatMetaTensor;
CompatMetaTensor x(ctx->GetInputVarPtrs("X")[0], ctx->IsRuntime());
CompatMetaTensor out(ctx->GetOutputVarPtrs("Out")[0], ctx->IsRuntime());
std::vector<int> axis = ctx->Attrs().Get<std::vector<int>>("axis");
phi::TransposeInferMeta(x, axis, &out);
if (!ctx->HasOutput("XShape")) return;
const auto &in_dims = ctx->GetInputDim("X");
std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
x_shape_dim[0] = 0;
for (int i = 0; i < in_dims.size(); ++i) {
x_shape_dim[i + 1] = in_dims[i];
}
ctx->SetOutputDim("XShape", phi::make_ddim(x_shape_dim));
ctx->ShareLoD("X", /*->*/ "XShape");
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::proto::VarType::Type data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "X");
std::string data_format = ctx.Attr<std::string>("data_format");
phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
return phi::KernelKey(
ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
} }
}; ctx->SetOutputDim("XShape", phi::make_ddim(x_shape_dim));
ctx->ShareLoD("X", /*->*/ "XShape");
class Transpose2OpMaker : public framework::OpProtoAndCheckerMaker { }
public:
void Make() override { phi::KernelKey Transpose2Op::GetExpectedKernelType(
AddInput( const framework::ExecutionContext &ctx) const {
"X", auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
"(Tensor) The input tensor, tensors with rank up to 6 are supported."); auto &data_format = ctx.Attr<std::string>("data_format");
AddOutput("Out", "(Tensor)The output tensor."); phi::DataLayout layout_ = phi::StringToDataLayout(data_format);
AddAttr<std::vector<int>>( return phi::KernelKey(
"axis", ctx.GetPlace(), layout_, phi::TransToPhiDataType(data_type));
"(vector<int>) A list of values, and the size of the list should be " }
"the same with the input tensor rank. This operator permutes the input "
"tensor's axes according to the values given."); void Transpose2OpMaker::Make() {
AddOutput("XShape", "(Tensor)The output tensor.") AddInput(
.AsIntermediate() "X",
.AsExtra(); "(Tensor) The input tensor, tensors with rank up to 6 are supported.");
AddComment(R"DOC( AddOutput("Out", "(Tensor)The output tensor.");
AddAttr<std::vector<int>>(
"axis",
"(vector<int>) A list of values, and the size of the list should be "
"the same with the input tensor rank. This operator permutes the input "
"tensor's axes according to the values given.");
AddOutput("XShape", "(Tensor)The output tensor.").AsIntermediate().AsExtra();
AddComment(R"DOC(
Transpose Operator. Transpose Operator.
The input tensor will be permuted according to the axes given. The input tensor will be permuted according to the axes given.
...@@ -215,8 +173,8 @@ The behavior of this operator is similar to how `numpy.transpose` works. ...@@ -215,8 +173,8 @@ The behavior of this operator is similar to how `numpy.transpose` works.
$[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$. $[0, 2, 3, 1]$, then shape of the output tensor will be: $(N, H, W, C)$.
)DOC"); )DOC");
} Apply();
}; }
template <typename T> template <typename T>
class Transpose2GradMaker : public framework::SingleGradOpMaker<T> { class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
......
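The InferShape moved above still builds XShape by prepending a 0 to X's dimensions (the leading 0 marks it as a shape/LoD carrier rather than real data). A standalone illustration of that loop:

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> in_dims = {2, 3, 4, 5};  // dims of X
  std::vector<int64_t> x_shape_dim(in_dims.size() + 1);
  x_shape_dim[0] = 0;  // leading 0, as in Transpose2Op::InferShape above
  for (size_t i = 0; i < in_dims.size(); ++i) x_shape_dim[i + 1] = in_dims[i];
  for (auto d : x_shape_dim) std::cout << d << ' ';  // prints: 0 2 3 4 5
  std::cout << '\n';
  return 0;
}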
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
namespace paddle {
namespace operators {
class TransposeOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
// FIXME(zcd): transpose2 adds an intermediate output(XShape) based on
// transpose, the XShape is used to carry the shape and lod of X which
// will be used in transpose_grad, in this way, the framework can reuse
// the memory of X immediately the transpose2_op is finished.
// Considering compatibility issues, we could not fix transpose2_op
class Transpose2Op : public TransposeOp {
public:
using TransposeOp::TransposeOp;
Transpose2Op(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: TransposeOp(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override;
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class Transpose2OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() final;
protected:
virtual void Apply() {}
};
} // namespace operators
} // namespace paddle
...@@ -759,6 +759,10 @@ ...@@ -759,6 +759,10 @@
attrs : [bool use_cudnn = false, float fuse_alpha = 0.0f, float fuse_beta = 0.0f, float Scale_in = 1.0f, attrs : [bool use_cudnn = false, float fuse_alpha = 0.0f, float fuse_beta = 0.0f, float Scale_in = 1.0f,
float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}'] float Scale_out = 1.0f, float Scale_in_eltwise = 1.0f, 'float[] Scale_weights = {1.0f}']
- op : fused_transpose
extra :
attrs : [str data_format = "AnyLayout"]
- op : gather - op : gather
backward : gather_grad backward : gather_grad
extra : extra :
...@@ -1853,8 +1857,7 @@ ...@@ -1853,8 +1857,7 @@
perm : axis perm : axis
extra : extra :
outputs : [XShape] outputs : [XShape]
attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", bool use_quantizer = false, attrs : [bool use_mkldnn = false, str data_format = "AnyLayout", str mkldnn_data_type = "float32"]
str mkldnn_data_type = "float32"]
- op : trilinear_interp (trilinear_interp_v2) - op : trilinear_interp (trilinear_interp_v2)
backward : trilinear_interp_grad (trilinear_interp_v2_grad) backward : trilinear_interp_grad (trilinear_interp_v2_grad)
......
...@@ -60,6 +60,8 @@ static std::vector<int> TransposeToPermuteAxes(const std::vector<int>& axis) { ...@@ -60,6 +60,8 @@ static std::vector<int> TransposeToPermuteAxes(const std::vector<int>& axis) {
return permute_axis; return permute_axis;
} }
// a trick is used here to fake transpose of out_md, so later it will be
// "untransposed", leaving output data in plain format tag
static std::vector<int64_t> FakeTransposeStrides( static std::vector<int64_t> FakeTransposeStrides(
const std::vector<int64_t>& out_dims, const std::vector<int>& axis) { const std::vector<int64_t>& out_dims, const std::vector<int>& axis) {
std::vector<int64_t> fake_strides(axis.size()); std::vector<int64_t> fake_strides(axis.size());
......
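The comment relocated above describes the trick used by both transpose kernels: the destination memory descriptor is built with strides of the transposed tensor, so that permuting it back afterwards yields a plain-format descriptor while the reorder has already moved the data. A standalone sketch of one plausible stride computation (an assumption about the helper's body, not a copy of it):

#include <cstdint>
#include <iostream>
#include <vector>

// Assumed behaviour of funcs::FakeTransposeStrides: assign strides in the
// transposed order, innermost to outermost, so the permuted-back descriptor
// ends up contiguous in plain (abcd...) layout.
std::vector<int64_t> FakeStrides(const std::vector<int64_t>& out_dims,
                                 const std::vector<int>& axis) {
  std::vector<int64_t> strides(axis.size());
  int64_t total = 1;
  for (int i = static_cast<int>(axis.size()) - 1; i >= 0; --i) {
    strides[axis[i]] = total;
    total *= out_dims[axis[i]];
  }
  return strides;
}

int main() {
  // dims of the source tensor and the transpose axis
  auto s = FakeStrides({2, 3, 4, 5}, {0, 2, 3, 1});
  for (auto v : s) std::cout << v << ' ';  // prints: 60 1 15 3
  std::cout << '\n';
  return 0;
}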
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/onednn/onednn_reuse.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/transpose_kernel.h"
namespace phi {
void SetInMemDescWithSqueeze2FuseSupport(
const std::vector<int> fused_squeeze2_axes,
DenseTensor* in,
const dnnl::memory::desc& in_md) {
const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
fused_squeeze2_axes.end());
const std::vector<int64_t>& x_vec_dims = in_md.dims();
std::vector<int64_t> squeezed_op_tz(
x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
int j = 0;
for (size_t i = 0; i < x_vec_dims.size(); ++i) {
if (squeeze2_axes_set.count(i) ||
squeeze2_axes_set.count(i - x_vec_dims.size())) {
PADDLE_ENFORCE_EQ(
x_vec_dims[i],
1,
errors::InvalidArgument(
"Squeeze2 input dim %d should be equal to one, but get %d.",
i,
x_vec_dims[i]));
continue;
}
squeezed_op_tz[j++] = x_vec_dims[i];
}
in->set_mem_desc(in_md.reshape(squeezed_op_tz));
in->Resize(make_ddim(squeezed_op_tz));
}
template <typename T, typename Context>
void FusedTransposeKernel(const Context& dev_ctx,
const DenseTensor& x,
const std::vector<int>& axis,
const std::vector<int>& fused_squeeze2_axes,
const std::vector<int>& fused_unsqueeze2_axes,
const std::vector<int>& fused_reshape2_shape,
const float scale,
const float shift,
const std::string& output_data_type,
DenseTensor* out) {
// Here we need to match dims to paddle layout
// as we are producing non-oneDNN result
auto x_dims = x.dims();
if ((x_dims.size() >= 3) &&
(phi::OneDNNContext::tls().get_cur_paddle_data_layout() ==
phi::DataLayout::kNHWC)) {
int axis_size = axis.size();
std::vector<int> formated_axis = axis;
std::vector<int> count(axis_size, 0);
for (int i = 0; i < axis_size; i++) {
if (axis[i] < 0) {
formated_axis[i] = axis[i] + axis_size;
}
}
auto dims = phi::vectorize<int>(x_dims);
std::rotate(dims.begin() + 1, dims.begin() + 2, dims.end());
x_dims = x_dims.reshape(dims);
VLOG(3)
<< "Rotating Shape in Transpose from: kMKLDNN to: kNHWC output_shape";
phi::DDim out_dims(x_dims);
for (size_t i = 0; i < axis.size(); i++) {
out_dims[i] = x_dims[formated_axis[i]];
}
out->Resize(out_dims);
}
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
AllocationType::CPU,
errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace"));
if (!(fused_squeeze2_axes.empty())) {
SetInMemDescWithSqueeze2FuseSupport(
fused_squeeze2_axes, const_cast<DenseTensor*>(&x), x.mem_desc());
}
if (axis.size() == 1) {
Copy<Context>(dev_ctx, x, x.place(), false, out);
out->set_mem_desc(x.mem_desc());
return;
}
auto x_vec_dims = vectorize(x.dims());
auto x_type = funcs::ToOneDNNDataType(x.dtype());
dnnl::primitive_attr attrs;
const int32_t mask = 0;
if (scale != 1.0f) {
attrs.set_output_scales(mask, {scale});
}
if (shift != 0.0f) {
auto dst = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST;
attrs.set_zero_points(dst, mask, {static_cast<int32_t>(shift)});
}
DataType out_dtype;
if (output_data_type == "bf16") {
out_dtype = DataType::BFLOAT16;
} else if (output_data_type == "int8") {
out_dtype = DataType::INT8;
} else if (output_data_type == "uint8") {
out_dtype = DataType::UINT8;
} else if (output_data_type == "fp32") {
out_dtype = DataType::FLOAT32;
} else {
out_dtype = x.dtype();
}
auto out_type = funcs::ToOneDNNDataType(out_dtype);
funcs::ReorderOneDNNHandler reorder_handler(
x_vec_dims, x.dtype(), x_type, out_dtype, out_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>()));
auto fake_strides = funcs::FakeTransposeStrides(x_vec_dims, axis);
auto dst_md = dnnl::memory::desc(x_vec_dims, out_type, fake_strides);
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(
reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
auto out_md = reorder_dst_memory_p->get_desc().permute_axes(
funcs::TransposeToPermuteAxes(axis));
if (!fused_unsqueeze2_axes.empty()) {
funcs::SetOutMemDescWithUnsqueeze2FuseSupport(
fused_unsqueeze2_axes, out, out_md);
} else if (!fused_reshape2_shape.empty()) {
funcs::SetOutMemDescWithReshape2FuseSupport(
fused_reshape2_shape, out, out_md);
} else if (!fused_squeeze2_axes.empty()) {
out->set_mem_desc(out_md);
out->Resize(make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}
} // namespace phi
PD_REGISTER_KERNEL(fused_transpose,
OneDNN,
ONEDNN,
phi::FusedTransposeKernel,
float,
uint8_t,
int8_t,
phi::dtype::bfloat16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -18,87 +18,6 @@ ...@@ -18,87 +18,6 @@
namespace phi { namespace phi {
void SetOutMemDescWithLogicalLayoutFusesSupport(
const OneDNNContext& dev_ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
const auto fused_unsqueeze2_axes =
dev_ctx.HasDnnAttr("fused_unsqueeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_unsqueeze2_axes"))
: std::vector<int>();
const auto fused_reshape2_shape =
dev_ctx.HasDnnAttr("fused_reshape2_shape")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_reshape2_shape"))
: std::vector<int>();
const auto fused_squeeze2_axes =
dev_ctx.HasDnnAttr("fused_squeeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_squeeze2_axes"))
: std::vector<int>();
if (!fused_unsqueeze2_axes.empty()) {
funcs::SetOutMemDescWithUnsqueeze2FuseSupport(
fused_unsqueeze2_axes, out, out_md);
} else if (!fused_reshape2_shape.empty()) {
funcs::SetOutMemDescWithReshape2FuseSupport(
fused_reshape2_shape, out, out_md);
} else if (!fused_squeeze2_axes.empty()) {
out->set_mem_desc(out_md);
out->Resize(make_ddim(out_md.dims()));
} else {
out->set_mem_desc(out_md);
}
}
void SetInMemDescWithSqueeze2FuseSupport(
const std::vector<int> fused_squeeze2_axes,
DenseTensor* in,
const dnnl::memory::desc& in_md) {
const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
fused_squeeze2_axes.end());
const std::vector<int64_t>& x_vec_dims = in_md.dims();
std::vector<int64_t> squeezed_op_tz(
x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
int j = 0;
for (size_t i = 0; i < x_vec_dims.size(); ++i) {
if (squeeze2_axes_set.count(i) ||
squeeze2_axes_set.count(i - x_vec_dims.size())) {
PADDLE_ENFORCE_EQ(
x_vec_dims[i],
1,
errors::InvalidArgument(
"Squeeze2 input dim %d should be equal to one, but get %d.",
i,
x_vec_dims[i]));
continue;
}
squeezed_op_tz[j++] = x_vec_dims[i];
}
in->set_mem_desc(in_md.reshape(squeezed_op_tz));
in->Resize(make_ddim(squeezed_op_tz));
}
void SetInMemDescWithLogicalLayoutFusesSupport(
const OneDNNContext& dev_ctx,
DenseTensor* in,
const dnnl::memory::desc& in_md) {
const auto fused_squeeze2_axes =
dev_ctx.HasDnnAttr("fused_squeeze2_axes")
? PADDLE_GET_CONST(std::vector<int>,
dev_ctx.GetDnnAttr("fused_squeeze2_axes"))
: std::vector<int>();
if (fused_squeeze2_axes.empty()) {
in->set_mem_desc(in_md);
in->Resize(make_ddim(in_md.dims()));
} else {
SetInMemDescWithSqueeze2FuseSupport(fused_squeeze2_axes, in, in_md);
}
}
template <typename T, typename Context> template <typename T, typename Context>
void TransposeKernel(const Context& dev_ctx, void TransposeKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
...@@ -137,9 +56,6 @@ void TransposeKernel(const Context& dev_ctx, ...@@ -137,9 +56,6 @@ void TransposeKernel(const Context& dev_ctx,
AllocationType::CPU, AllocationType::CPU,
errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace")); errors::PreconditionNotMet("oneDNN Transpose kernel must use CPUPlace"));
SetInMemDescWithLogicalLayoutFusesSupport(
dev_ctx, const_cast<DenseTensor*>(&x), x.mem_desc());
if (axis.size() == 1 || axis.size() == 0) { if (axis.size() == 1 || axis.size() == 0) {
Copy<Context>(dev_ctx, x, x.place(), false, out); Copy<Context>(dev_ctx, x, x.place(), false, out);
out->set_mem_desc(x.mem_desc()); out->set_mem_desc(x.mem_desc());
...@@ -148,79 +64,24 @@ void TransposeKernel(const Context& dev_ctx, ...@@ -148,79 +64,24 @@ void TransposeKernel(const Context& dev_ctx,
auto x_vec_dims = vectorize(x.dims()); auto x_vec_dims = vectorize(x.dims());
auto x_type = funcs::ToOneDNNDataType(x.dtype()); auto x_type = funcs::ToOneDNNDataType(x.dtype());
dnnl::primitive_attr attrs;
const int32_t mask = 0;
const auto quantization_scale =
dev_ctx.HasDnnAttr("scale")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("scale"))
: 1.0f;
const auto quantization_shift =
dev_ctx.HasDnnAttr("shift")
? PADDLE_GET_CONST(float, dev_ctx.GetDnnAttr("shift"))
: 0.0f;
const auto output_data_type =
dev_ctx.HasDnnAttr("output_data_type")
? PADDLE_GET_CONST(std::string,
dev_ctx.GetDnnAttr("output_data_type"))
: "";
const bool with_scale = quantization_scale != 1.0f;
const bool with_shift = quantization_shift != 0.0f;
if (with_scale) {
attrs.set_output_scales(mask, {quantization_scale});
}
if (with_shift) {
auto dst = output_data_type == "fp32" ? DNNL_ARG_SRC : DNNL_ARG_DST;
attrs.set_zero_points(
dst, mask, {static_cast<int32_t>(quantization_shift)});
}
DataType out_dtype;
if (output_data_type == "bf16") {
out_dtype = DataType::BFLOAT16;
} else if (output_data_type == "int8") {
out_dtype = DataType::INT8;
} else if (output_data_type == "uint8") {
out_dtype = DataType::UINT8;
} else if (output_data_type == "fp32") {
out_dtype = DataType::FLOAT32;
} else {
out_dtype = x.dtype();
}
auto out_type = phi::funcs::ToOneDNNDataType(out_dtype);
funcs::ReorderOneDNNHandler reorder_handler( funcs::ReorderOneDNNHandler reorder_handler(
x_vec_dims, x.dtype(), x_type, out_dtype, out_type, dev_ctx.GetEngine()); x_vec_dims, x.dtype(), x_type, dev_ctx.GetEngine());
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>())); x.mem_desc(), funcs::to_void_cast(x.data<T>()));
auto fake_strides = funcs::FakeTransposeStrides(x_vec_dims, axis); auto fake_strides = funcs::FakeTransposeStrides(x_vec_dims, axis);
auto dst_md = dnnl::memory::desc(x_vec_dims, out_type, fake_strides); auto dst_md =
dnnl::memory::desc(x_vec_dims, x.mem_desc().data_type(), fake_strides);
auto reorder_dst_memory_p = auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace()); reorder_handler.AcquireDstMemory(out, dst_md, dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_dst_memory_p,
auto reorder_p = reorder_handler.AcquireReorder( reorder_src_memory_p);
reorder_dst_memory_p, reorder_src_memory_p, attrs);
auto& astream = OneDNNContext::tls().get_stream(); auto& astream = OneDNNContext::tls().get_stream();
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
// it is needed because oneDNN's permute_axes understands the axes order in a funcs::TransposeToPermuteAxes(axis)));
// different way than PaddlePaddle's transpose does
std::vector<int> permute_axis(axis.size());
for (size_t i = 0; i < axis.size(); ++i) {
permute_axis[axis[i]] = i;
}
SetOutMemDescWithLogicalLayoutFusesSupport(
dev_ctx,
out,
reorder_dst_memory_p->get_desc().permute_axes(
funcs::TransposeToPermuteAxes(axis)));
} }
} // namespace phi } // namespace phi
......
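The inline comment deleted above points out that oneDNN's permute_axes takes the axes in a different order than PaddlePaddle's transpose attribute, which is why the deleted loop builds the inverse permutation before calling permute_axes. A minimal standalone sketch of that inversion follows; it assumes funcs::TransposeToPermuteAxes (now called directly by the kernel) performs the same computation as the deleted loop — its actual implementation is not part of this diff — and the free-standing helper name ToPermuteAxes is only illustrative.

// Sketch only: mirrors the deleted inline loop, assumed to match
// funcs::TransposeToPermuteAxes. For axis = {1, 2, 0} it returns {2, 0, 1};
// for a self-inverse permutation such as {0, 2, 1, 3} the result equals the input.
#include <cstddef>
#include <vector>

std::vector<int> ToPermuteAxes(const std::vector<int>& axis) {
  std::vector<int> permute_axis(axis.size());
  for (size_t i = 0; i < axis.size(); ++i) {
    permute_axis[axis[i]] = static_cast<int>(i);  // invert the permutation
  }
  return permute_axis;
}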
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature FusedTransposeOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("fused_transpose",
{"X"},
{"axis",
"fused_squeeze2_axes",
"fused_unsqueeze2_axes",
"fused_reshape2_shape",
"scale",
"shift",
"output_data_type"},
{"Out"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(fused_transpose,
phi::FusedTransposeOpArgumentMapping);
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestTranspose2Reshape2OneDNNFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
channel = draw(st.sampled_from([1, 2, 4]))
axis = draw(st.sampled_from([[0, 1, 2, 3], [2, 1, 3, 0], [3, 2, 1, 0]]))
shape = draw(
st.sampled_from(
[[channel, 512, 64], [256, 128, channel], [channel, 1024, 32]]
)
)
transpose2_op = OpConfig(
type="transpose2",
inputs={
"X": ["transpose_x"],
},
outputs={
"Out": ["transpose_out"],
"XShape": ['transpose2_xshape'],
},
attrs={
"axis": axis,
"use_mkldnn": True,
},
)
reshape2_op = OpConfig(
type="reshape2",
inputs={"X": ["transpose_out"]},
outputs={"Out": ["reshape_out"]},
attrs={
"shape": shape,
},
)
model_net = [transpose2_op, reshape2_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
"transpose_x": TensorConfig(
data_gen=partial(generate_input, [channel, 16, 64, 32])
)
},
outputs=["reshape_out"],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=[
"operator_reshape2_onednn_fuse_pass",
],
)
yield config, ["fused_transpose"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False,
passes=[
"operator_reshape2_onednn_fuse_pass",
],
)
if __name__ == "__main__":
unittest.main()
...@@ -78,7 +78,7 @@ class TestTranspose2Unsqueeze2OneDNNFusePass(PassAutoScanTest): ...@@ -78,7 +78,7 @@ class TestTranspose2Unsqueeze2OneDNNFusePass(PassAutoScanTest):
"operator_unsqueeze2_onednn_fuse_pass", "operator_unsqueeze2_onednn_fuse_pass",
], ],
) )
yield config, ["transpose2"], (1e-5, 1e-5) yield config, ["fused_transpose"], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
...@@ -109,7 +109,7 @@ class TestQuantTranspose2DequantOneDNNFusePass(PassAutoScanTest): ...@@ -109,7 +109,7 @@ class TestQuantTranspose2DequantOneDNNFusePass(PassAutoScanTest):
use_mkldnn=True, use_mkldnn=True,
passes=['quant_transpose2_dequant_onednn_fuse_pass'], passes=['quant_transpose2_dequant_onednn_fuse_pass'],
) )
yield config, ['transpose2', 'transpose2'], (1e-5, 1e-5) yield config, ['fused_transpose', 'fused_transpose'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestSqueeze2Transpose2OneDNNFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
channel = draw(st.sampled_from([1, 2, 4, 8, 16]))
transpose_axis = draw(
st.sampled_from(
[[0, 1, 2], [0, 2, 1], [1, 0, 2], [2, 1, 0], [2, 1, 0]]
)
)
squeeze2_op = OpConfig(
type="squeeze2",
inputs={"X": ["squeeze_x"]},
outputs={
"Out": ["squeeze_out"],
"XShape": ["squeeze2_xshape"],
},
attrs={
"axes": [2],
"use_mkldnn": True,
},
)
transpose2_op = OpConfig(
type="transpose2",
inputs={
"X": ["squeeze_out"],
},
outputs={
"Out": ["trans_out"],
"XShape": ['transpose2_xshape'],
},
attrs={
"axis": transpose_axis,
"use_mkldnn": True,
},
)
model_net = [squeeze2_op, transpose2_op]
program_config = ProgramConfig(
ops=model_net,
weights={},
inputs={
"squeeze_x": TensorConfig(
data_gen=partial(generate_input, [channel, 16, 1, 32])
)
},
outputs=["trans_out"],
)
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True,
passes=[
"squeeze2_transpose2_onednn_fuse_pass",
],
)
yield config, ["fused_transpose"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(
quant=False,
passes=[
"squeeze2_transpose2_onednn_fuse_pass",
],
)
if __name__ == "__main__":
unittest.main()