Unverified commit 9e006987 authored by J jakpiase, committed by GitHub

Optimized oneDNN FC and added operator+unsqueeze2 and operator+reshape2 oneDNN fuse passes (#47391)

* tmp save

* minor change

* CI fix

* added FC optimizations

* latest update

* CI fix

* fixed bug with fusing fc
Parent 6916215e
......@@ -177,6 +177,8 @@ if(WITH_MKLDNN)
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
pass_library(cpu_quantize_pass inference DIR mkldnn)
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)
......
......@@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()(
return activation_out;
}
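// Matches a preceding operator whose "Out" output feeds an unsqueeze2 op:
//   (operator_type) -> preceding_op_out -> unsqueeze2 -> unsqueeze2_out.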
PDNode *patterns::OperatorUnsqueeze2::operator()(
const std::string &operator_type, const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr())
->assert_is_op(operator_type)
->assert_has_n_outputs(num_of_operator_outs);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_op_output(operator_type, "Out")
->assert_is_op_input("unsqueeze2");
auto *unsqueeze2_op =
pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2");
auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
->AsOutput()
->assert_is_op_output("unsqueeze2");
preceding_op->LinksTo({preceding_op_out});
unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out});
return unsqueeze2_out;
}
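// Same structure as OperatorUnsqueeze2, but the trailing op is reshape2:
//   (operator_type) -> preceding_op_out -> reshape2 -> reshape2_out.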
PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type,
const int num_of_operator_outs) {
auto *preceding_op = pattern->NewNode(preceding_op_repr())
->assert_is_op(operator_type)
->assert_has_n_outputs(num_of_operator_outs);
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
->AsIntermediate()
->assert_is_op_output(operator_type, "Out")
->assert_is_op_input("reshape2");
auto *reshape2_op =
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
auto *reshape2_out = pattern->NewNode(reshape2_out_repr())
->AsOutput()
->assert_is_op_output("reshape2");
preceding_op->LinksTo({preceding_op_out});
reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out});
return reshape2_out;
}
PDNode *patterns::SeqConvEltAddRelu::operator()(
paddle::framework::ir::PDNode *seqconv_input) {
// Create Operators
......
......@@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase {
PATTERN_DECL_NODE(activation_out);
};
struct OperatorUnsqueeze2 : public PatternBase {
OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
PDNode* operator()(const std::string& operator_type,
const int num_of_outputs);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(unsqueeze2_op);
PATTERN_DECL_NODE(unsqueeze2_out);
};
struct OperatorReshape2 : public PatternBase {
OperatorReshape2(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "operator_reshape2") {}
PDNode* operator()(const std::string& operator_type,
const int num_of_outputs);
PATTERN_DECL_NODE(preceding_op);
PATTERN_DECL_NODE(preceding_op_out);
PATTERN_DECL_NODE(reshape2_op);
PATTERN_DECL_NODE(reshape2_out);
};
// SEQCONV with Elementwise_Add ReLU
// op: seqconv + elementwise_add + relu
// named nodes:
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_reshape2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorReshape2OneDNNPass::ApplyImpl(Graph *graph) const {
// THIS FUSE WILL WORK ONLY WITH OPERATORS THAT OUTPUT PLAIN MEMORY, E.G.
// ABCD FOR 4D TENSORS! BE AWARE OF THAT!
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"fc", 1}, {"transpose2", 2}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseReshape2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorReshape2OneDNNPass::FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_reshape2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorReshape2 op_reshape2_pattern(
gpd.mutable_pattern(), op_type + "_reshape2_onednn_fuse_pass");
op_reshape2_pattern(op_type, num_of_outputs);
int found_operator_reshape2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_op, reshape2_op, op_reshape2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(reshape2_out, reshape2_out, op_reshape2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with reshape2.";
return;
}
if (operator_op->Op()->HasAttr("fused_unsqueeze2_axes")) {
VLOG(4) << "Cannot do " << op_type << " + reshape2 fuse, because "
<< op_type << " is already fused with unsqueeze2!";
return;
}
std::vector<int> reshape2_shape =
PADDLE_GET_CONST(std::vector<int>, reshape2_op->Op()->GetAttr("shape"));
int num_of_minus_ones = 0;
for (size_t i = 0; i < reshape2_shape.size(); ++i) {
if (reshape2_shape[i] == 0) {
VLOG(4) << "OneDNN op+reshape2 fuse pass does not support zero dims, "
"skipping";
return;
} else if (reshape2_shape[i] == -1) {
++num_of_minus_ones;
}
}
if (num_of_minus_ones > 1) {
VLOG(4) << "Number of -1 values inside of reshape2 shouldn't be greater "
"than one in op+reshape2 oneDNN fuse pass, skipping";
return;
}
auto const &names = reshape2_op->Op()->InputNames();
bool has_shape_tensor =
std::find(names.begin(), names.end(), "ShapeTensor") != names.end();
bool has_shape_tensor_list =
std::find(names.begin(), names.end(), "ShapeTensorList") != names.end();
if (has_shape_tensor &&
reshape2_op->Op()->Input("ShapeTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensor!";
return;
}
if (has_shape_tensor_list &&
reshape2_op->Op()->Input("ShapeTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and reshape2 because reshape2 dims are specified by "
"ShapeTensorList!";
return;
}
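// Record the reshape2 target shape on the preceding op, rewire that op's
// output straight to reshape2's output, and drop the now-redundant reshape2
// node and intermediate variable.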
operator_op->Op()->SetAttr("fused_reshape2_shape", reshape2_shape);
operator_op->Op()->SetOutput("Out", {reshape2_out->Name()});
IR_OP_VAR_LINK(operator_op, reshape2_out);
GraphSafeRemoveNodes(g, {reshape2_op, operator_out});
found_operator_reshape2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_reshape2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_reshape2_count > 0)
PrettyLogDetail("--- fused %d %s with reshape2",
found_operator_reshape2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_reshape2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorReshape2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_reshape2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("reshape2", 0)
.GE("fc", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorReshape2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorReshape2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseReshape2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/operator_unsqueeze2_onednn_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
namespace ir {
using string::PrettyLogDetail;
void FuseOperatorUnsqueeze2OneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::pair<std::string, int>> ops_and_outputs = {
{"transpose2", 2}, {"elementwise_mul", 1}};
for (const auto &op_and_outputs : ops_and_outputs)
FuseUnsqueeze2(graph, op_and_outputs.first, op_and_outputs.second);
}
void FuseOperatorUnsqueeze2OneDNNPass::FuseUnsqueeze2(
Graph *graph, const std::string &op_type, int num_of_outputs) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init(op_type + "_unsqueeze2_onednn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorUnsqueeze2 op_unsqueeze2_pattern(
gpd.mutable_pattern(), op_type + "_unsqueeze2_onednn_fuse_pass");
op_unsqueeze2_pattern(op_type, num_of_outputs);
int found_operator_unsqueeze2_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
Graph *g) {
GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
operator_out, preceding_op_out, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_op, unsqueeze2_op, op_unsqueeze2_pattern);
GET_IR_NODE_FROM_SUBGRAPH(
unsqueeze2_out, unsqueeze2_out, op_unsqueeze2_pattern);
if (!operator_op->Op()->HasAttr("use_mkldnn") ||
(operator_op->Op()->HasAttr("use_mkldnn") &&
!(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn"))))) {
VLOG(4) << "Only oneDNN version of " << op_type
<< "can be fused with unsqueeze2.";
return;
}
std::vector<int> unsqueeze2_axes = PADDLE_GET_CONST(
std::vector<int>, unsqueeze2_op->Op()->GetAttr("axes"));
auto const &names = unsqueeze2_op->Op()->InputNames();
bool has_axes_tensor =
std::find(names.begin(), names.end(), "AxesTensor") != names.end();
bool has_axes_tensor_list =
std::find(names.begin(), names.end(), "AxesTensorList") != names.end();
if (has_axes_tensor &&
unsqueeze2_op->Op()->Input("AxesTensor").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensor!";
return;
}
if (has_axes_tensor_list &&
unsqueeze2_op->Op()->Input("AxesTensorList").size() > 0) {
VLOG(4) << "Cannot fuse " << op_type
<< " and unsqueeze2 because unsqueeze2 dims are specified by "
"AxesTensorList!";
return;
}
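// Record the unsqueeze2 axes on the preceding op, rewire its output straight
// to unsqueeze2's output, and drop the now-redundant unsqueeze2 node and
// intermediate variable.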
operator_op->Op()->SetAttr("fused_unsqueeze2_axes", unsqueeze2_axes);
operator_op->Op()->SetOutput("Out", {unsqueeze2_out->Name()});
IR_OP_VAR_LINK(operator_op, unsqueeze2_out);
GraphSafeRemoveNodes(g, {unsqueeze2_op, operator_out});
found_operator_unsqueeze2_count++;
};
gpd(graph, handler);
AddStatis(found_operator_unsqueeze2_count);
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_operator_unsqueeze2_count > 0)
PrettyLogDetail("--- fused %d %s with unsqueeze2",
found_operator_unsqueeze2_count,
op_type);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(operator_unsqueeze2_onednn_fuse_pass,
paddle::framework::ir::FuseOperatorUnsqueeze2OneDNNPass);
REGISTER_PASS_CAPABILITY(operator_unsqueeze2_onednn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("unsqueeze2", 0)
.GE("transpose2", 0));
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
namespace paddle {
namespace framework {
namespace ir {
class FuseOperatorUnsqueeze2OneDNNPass : public FusePassBase {
public:
virtual ~FuseOperatorUnsqueeze2OneDNNPass() {}
protected:
void ApplyImpl(Graph *graph) const override;
void FuseUnsqueeze2(Graph *graph,
const std::string &op_type,
int num_of_outputs) const;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -334,6 +334,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", //
"operator_scale_onednn_fuse_pass", //
"operator_unsqueeze2_onednn_fuse_pass", //
"operator_reshape2_onednn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710
// "mkldnn_inplace_pass", // This pass should be activated after
......@@ -428,6 +430,8 @@ void CpuPassStrategy::EnableMkldnnInt8() {
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("operator_scale_onednn_fuse_pass");
passes_.push_back("operator_unsqueeze2_onednn_fuse_pass");
passes_.push_back("operator_reshape2_onednn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
......
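Both new passes are appended to the default oneDNN pass lists above, so they run automatically once MKLDNN is enabled on a CPU inference config. A minimal sketch of that from the user side, assuming the paddle_infer C++ API and a placeholder model directory (appending the passes explicitly is only needed when building a custom pass strategy):
#include "paddle_inference_api.h"  // assumed include path of the inference header
void BuildCpuConfig() {
  paddle_infer::Config config("./model_dir");  // placeholder model directory
  config.EnableMKLDNN();  // default oneDNN pass list now contains both fuse passes
  // Only needed for a custom pass strategy; shown here for illustration.
  config.pass_builder()->AppendPass("operator_unsqueeze2_onednn_fuse_pass");
  config.pass_builder()->AppendPass("operator_reshape2_onednn_fuse_pass");
}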
......@@ -196,12 +196,14 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
astream.wait();
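// The output descriptor is now routed through the fuse-aware helper so that a
// fused unsqueeze2/reshape2 is reflected in the elementwise output shape.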
if (handler.use_broadcasting_hack == false) {
z->set_mem_desc(dst_memory->get_desc());
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx, z, dst_memory->get_desc());
} else {
auto dims = dst_memory->get_desc().dims();
dims.insert(dims.begin(), x->dims()[0]);
dims[1] /= dims[0];
z->set_mem_desc(dst_memory->get_desc().reshape(dims));
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx, z, dst_memory->get_desc().reshape(dims));
}
}
};
......
......@@ -39,6 +39,13 @@ constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
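// Caches the oneDNN inner_product primitive and its memory objects so that
// repeated runs of the same FC (same input/weight names and input shape) can
// skip primitive and memory descriptor re-creation.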
struct InnerProductCache {
dnnl::inner_product_forward inner_product_p;
dnnl::memory src_mem;
dnnl::memory weights_mem;
dnnl::memory bias_mem;
dnnl::memory dst_mem;
};
template <typename T_in, typename T_w, typename T_out>
class FCMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T_in,
......@@ -357,6 +364,23 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
}));
}
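// Feeds the current input tensor into the cached source memory: reorders when
// the input's memory descriptor differs from the cached one, otherwise just
// swaps in the new data handle.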
void PrepareSrcMem(const std::shared_ptr<inner_product_forward>& fc_p,
const std::shared_ptr<dnnl::memory>& src_mem,
const LoDTensor* x,
const dnnl::engine& engine) const {
auto x_md = x->mem_desc().reshape(src_mem->get_desc().dims());
if (x_md != src_mem->get_desc()) {
dnnl::memory x_mem(x_md, engine, to_void_cast<T_in>(x->data<T_in>()));
auto reorder_p = dnnl::reorder(x_mem, *src_mem);
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p.execute(astream, x_mem, *src_mem);
astream.wait();
} else {
src_mem->set_data_handle(to_void_cast<T_in>(x->data<T_in>()));
}
}
template <typename T_out, typename T_w>
void RunKernel(const framework::ExecutionContext& ctx) const {
const auto& dev_ctx =
......@@ -368,29 +392,80 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
const auto* bias = ctx.Input<phi::DenseTensor>("Bias");
auto out = ctx.Output<LoDTensor>("Out");
auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
const float scale_in = ctx.Attr<float>("Scale_in");
const auto& scale_weights = ctx.Attr<std::vector<float>>("Scale_weights");
std::shared_ptr<dnnl::inner_product_forward> fc_p;
std::shared_ptr<dnnl::memory> src_memory_p;
std::shared_ptr<dnnl::memory> weights_memory_p;
std::shared_ptr<dnnl::memory> bias_memory_p;
std::shared_ptr<dnnl::memory> dst_memory_p;
std::string cache_key;
cache_key.reserve(64);
cache_key = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx,
platform::CreateKey(dev_ctx,
ctx.InputName("Input"),
ctx.InputName("W"),
phi::vectorize(x->dims())));
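// Look up a previously cached primitive for this FC; a cache hit lets us skip
// the handler and reuse the stored primitive and memories below.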
auto inner_product_cache =
std::static_pointer_cast<InnerProductCache>(dev_ctx.GetBlob(cache_key));
RecomputeOutputDims(ctx, x, weights, out);
FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
dev_ctx,
x,
weights,
bias,
out,
in_col_dims,
mkldnn_engine,
ctx.GetPlace());
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
auto weights_memory_p =
handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
auto dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
auto fc_p = handler.AcquireForwardPrimitive();
if (inner_product_cache) {
fc_p = std::make_shared<dnnl::inner_product_forward>(
inner_product_cache->inner_product_p);
src_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->src_mem);
PrepareSrcMem(fc_p, src_memory_p, x, mkldnn_engine);
weights_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->weights_mem);
dst_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->dst_mem);
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
auto* residual_param = ctx.Output<phi::DenseTensor>("ResidualData");
out->ShareDataWith(*residual_param);
}
auto out_ptr = out->mutable_data<T_out>(
ctx.GetPlace(), dst_memory_p->get_desc().get_size());
dst_memory_p->set_data_handle(out_ptr);
if (bias) {
bias_memory_p =
std::make_shared<dnnl::memory>(inner_product_cache->bias_mem);
}
} else {
auto in_col_dims = ctx.Attr<int>("in_num_col_dims");
FCMKLDNNHandler<T_in, T_w, T_out> handler(ctx,
dev_ctx,
x,
weights,
bias,
out,
in_col_dims,
mkldnn_engine,
ctx.GetPlace());
src_memory_p = handler.AcquireSrcMemoryWithReorder(x);
weights_memory_p =
handler.AcquireWeightsMemoryWithReorder(weights, scale_weights);
dst_memory_p = handler.AcquireCustomDstMemory(ctx, out);
if (bias) {
bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
}
fc_p = handler.AcquireForwardPrimitive();
}
auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
std::unordered_map<int, dnnl::memory> fc_args = {
......@@ -399,15 +474,27 @@ class FCMKLDNNKernel : public framework::OpKernel<T_in> {
{DNNL_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p =
handler.AcquireBiasMemoryWithReorder(bias, scale_in, scale_weights);
fc_args.insert({DNNL_ARG_BIAS, *bias_memory_p});
}
fc_p->execute(astream, fc_args);
astream.wait();
out->set_mem_desc(
if (!inner_product_cache) {
auto ip_cache = std::make_shared<InnerProductCache>();
ip_cache->inner_product_p = *fc_p;
ip_cache->src_mem = *src_memory_p;
ip_cache->weights_mem = *weights_memory_p;
ip_cache->dst_mem = *dst_memory_p;
if (bias) {
ip_cache->bias_mem = *bias_memory_p;
}
dev_ctx.SetBlob(cache_key, ip_cache);
}
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx,
out,
dst_memory_p->get_desc().reshape(phi::vectorize(out->dims())));
}
......
......@@ -81,8 +81,11 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
out->set_mem_desc(reorder_dst_memory_p->get_desc().permute_axes(
TransposeToPermuteAxis(transpose_axis)));
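// As in the elementwise kernel, the transpose output descriptor now goes
// through the fuse-aware helper so fused unsqueeze2/reshape2 shapes are
// applied.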
platform::SetOutMemDescWithLogicalLayoutFusesSupport(
ctx,
out,
reorder_dst_memory_p->get_desc().permute_axes(
TransposeToPermuteAxis(transpose_axis)));
}
private:
......
......@@ -109,6 +109,69 @@ static void AppendActivation(const framework::ExecutionContext& ctx,
}
}
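// Expands the output memory descriptor with size-1 dims at the axes recorded
// by the operator+unsqueeze2 fuse, e.g. dims [2, 3] with fused axes {0, 2}
// become [1, 2, 1, 3].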
static void SetOutMemDescWithUnsqueeze2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
const std::vector<int>& fused_unsqueeze2_axes =
ctx.Attr<std::vector<int>>("fused_unsqueeze2_axes");
const std::vector<int64_t>& op_tz = out_md.dims();
std::vector<int64_t> unsqueezed_op_tz(
op_tz.size() + fused_unsqueeze2_axes.size(), 0);
for (const auto& axis : fused_unsqueeze2_axes) {
int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
unsqueezed_op_tz[positive_axis] = 1;
}
int j = 0;
for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
if (unsqueezed_op_tz[i] == 0) {
unsqueezed_op_tz[i] = op_tz[j++];
}
}
out->set_mem_desc(out_md.reshape(unsqueezed_op_tz));
out->Resize(phi::make_ddim(unsqueezed_op_tz));
}
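// Applies the shape recorded by the operator+reshape2 fuse, resolving at most
// one -1 entry from the output's element count, e.g. numel 24 with shape
// [-1, 4] resolves to [6, 4].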
static void SetOutMemDescWithReshape2FuseSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
std::vector<int64_t> fused_reshape2_shape(
ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
const int out_shape_numel = out->numel();
const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
fused_reshape2_shape.end(),
1,
std::multiplies<int64_t>());
for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
if (fused_reshape2_shape[i] == -1) {
fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
break;
}
}
out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
out->Resize(phi::make_ddim(fused_reshape2_shape));
}
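// Dispatches on the fuse attribute present on the op: unsqueeze2 axes,
// reshape2 shape, or neither (plain descriptor assignment).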
static void SetOutMemDescWithLogicalLayoutFusesSupport(
const framework::ExecutionContext& ctx,
phi::DenseTensor* out,
const dnnl::memory::desc& out_md) {
if (ctx.HasAttr("fused_unsqueeze2_axes")) {
SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
} else if (ctx.HasAttr("fused_reshape2_shape")) {
SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
} else {
out->set_mem_desc(out_md);
}
}
template <typename T>
constexpr bool IsInt8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
......