未验证 提交 c7623d72 编写于 作者: J jakpiase 提交者: GitHub

Added shuffle_channel BF16/FP32 FWD oneDNN kernel (#39756)

* added shuffle_channel bf16/fp32 fwd kernel

* added missing files

* CI fix

* changed from pten to phi

* tmp save

* added reviewers suggestions

* fix for test
上级 97dec7ca
...@@ -128,6 +128,7 @@ if(WITH_MKLDNN) ...@@ -128,6 +128,7 @@ if(WITH_MKLDNN)
pass_library(fc_mkldnn_pass inference DIR mkldnn) pass_library(fc_mkldnn_pass inference DIR mkldnn)
pass_library(interpolate_mkldnn_pass inference DIR mkldnn) pass_library(interpolate_mkldnn_pass inference DIR mkldnn)
pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(softplus_activation_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(shuffle_channel_mkldnn_detect_pass inference DIR mkldnn)
pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(fc_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn) pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(cpu_quantize_placement_pass base DIR mkldnn) pass_library(cpu_quantize_placement_pass base DIR mkldnn)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/mkldnn/shuffle_channel_mkldnn_detect_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
// Shorthand: bind one named node from the detected subgraph into a local.
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
// Binds every node of the reshape2 -> transpose2 -> reshape2 pattern.
#define GET_NODES                  \
  GET_IR_NODE(reshape1_op);        \
  GET_IR_NODE(reshape1_out);       \
  GET_IR_NODE(transpose_op);       \
  GET_IR_NODE(transpose_out);      \
  GET_IR_NODE(reshape2_op);        \
  GET_IR_NODE(reshape2_out);
// Registers op-compat constraints for the two op types that make up the
// detected pattern; subgraphs whose ops do not satisfy these rules are
// rejected by IsCompat() in ApplyImpl.
ShuffleChannelMKLDNNDetectPass::ShuffleChannelMKLDNNDetectPass() {
  // reshape2: shape may come from the attribute or from the optional
  // Shape/ShapeTensor inputs; only the attribute form is fused.
  AddOpCompat(OpCompat("reshape2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddInput("Shape")
      .IsOptional()
      .IsTensor()
      .End()
      .AddInput("ShapeTensor")
      .IsOptional()
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("shape")
      .IsType<std::vector<int>>()
      .End();

  // transpose2: permutation comes from the "axis" attribute.
  AddOpCompat(OpCompat("transpose2"))
      .AddInput("X")
      .IsTensor()
      .End()
      .AddOutput("XShape")
      .IsTensor()
      .End()
      .AddOutput("Out")
      .IsTensor()
      .End()
      .AddAttr("axis")
      .IsType<std::vector<int>>()
      .End();
}
// Scans the graph for the (reshape2 -> transpose2 -> reshape2) subgraph that
// implements channel shuffling and, when the shapes/axes prove the pattern is
// a genuine shuffle, replaces it with a single oneDNN shuffle_channel op.
void ShuffleChannelMKLDNNDetectPass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "shufflechannel_pattern";
  FusePassBase::Init(pattern_name, graph);

  GraphPatternDetector gpd;
  // Pattern entry point: the tensor feeding the first reshape2.
  auto* x = gpd.mutable_pattern()
                ->NewNode("x")
                ->assert_is_op_input("reshape2", "X")
                ->AsInput();
  patterns::ShuffleChannelPattern pattern(gpd.mutable_pattern(), pattern_name);
  pattern(x);

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    // Reject subgraphs whose ops fail the compat rules registered in the
    // constructor.
    if (!IsCompat(subgraph, g)) {
      LOG(WARNING) << "The Pass in op compat failed.";
      return;
    }
    PADDLE_ENFORCE_GT(
        subgraph.count(x), 0,
        platform::errors::NotFound("Detector did not find input X."));
    auto* input_node = subgraph.at(x);
    auto reshape1_desc = reshape1_op->Op();
    auto reshape2_desc = reshape2_op->Op();
    auto trans_desc = transpose_op->Op();
    std::string input_name = input_node->Name();
    std::string output_name = reshape2_out->Name();

    auto reshape1_shape =
        BOOST_GET_CONST(std::vector<int>, reshape1_desc->GetAttr("shape"));
    auto reshape2_shape =
        BOOST_GET_CONST(std::vector<int>, reshape2_desc->GetAttr("shape"));
    auto trans_axis =
        BOOST_GET_CONST(std::vector<int>, trans_desc->GetAttr("axis"));
    auto* block1 = reshape1_desc->Block();
    auto* block2 = reshape2_desc->Block();
    if (block1 && block2) {
      // Resolve the statically-known input shapes of both reshapes from the
      // program block so the 0/-1 placeholders below can be materialized.
      auto x_var_name = reshape1_desc->Input("X")[0];
      auto* x_var_desc = block1->FindVar(x_var_name);
      auto x_shape1 = x_var_desc->GetShape();
      x_var_name = reshape2_desc->Input("X")[0];
      x_var_desc = block2->FindVar(x_var_name);
      auto x_shape2 = x_var_desc->GetShape();

      // now shuffle_channel is 4D(NCHW) only.
      // A shuffle is NCHW -> [N, g, C/g, H, W] -> transpose -> NCHW, so the
      // intermediate rank must be 5 and both endpoint ranks must be 4.
      if (x_shape1.size() != 4 || reshape1_shape.size() != 5 ||
          reshape2_shape.size() != 4 || trans_axis.size() != 5) {
        return;
      }

      // process 0 and -1 in reshape.
      // 0 in a reshape target means "copy this dimension from the input".
      constexpr int64_t copy_dim_val = 0;
      for (size_t i = 0; i < reshape1_shape.size(); i++) {
        if (reshape1_shape[i] == copy_dim_val) {
          reshape1_shape[i] = x_shape1[i];
        }
      }
      for (size_t i = 0; i < reshape2_shape.size(); i++) {
        if (reshape2_shape[i] == copy_dim_val) {
          reshape2_shape[i] = x_shape2[i];
        }
      }

      // -1 means "infer this dimension". It can only be computed when every
      // input dimension is statically known (all positive).
      constexpr int64_t unk_dim_idx = -1;
      bool all_positive = std::all_of(x_shape1.cbegin(), x_shape1.cend(),
                                      [](int64_t i) { return i > 0; });
      for (size_t i = 0; i < reshape1_shape.size(); ++i) {
        // if -1 is not in batch dim, try to calculate number
        if ((reshape1_shape[i] == unk_dim_idx) && (i != 0)) {
          // there is no sufficient info
          if (!all_positive) return;
          // The denominator starts from -1 so the single -1 entry cancels
          // out, leaving the product of the known target dims.
          reshape1_shape[i] =
              std::accumulate(x_shape1.begin(), x_shape1.end(),
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>()) /
              std::accumulate(reshape1_shape.begin(), reshape1_shape.end(),
                              static_cast<int64_t>(-1),
                              std::multiplies<int64_t>());
          break;
        }
      }

      all_positive = std::all_of(x_shape2.cbegin(), x_shape2.cend(),
                                 [](int64_t i) { return i > 0; });
      for (size_t i = 0; i < reshape2_shape.size(); ++i) {
        // if -1 is not in batch dim, try to calculate number
        if ((reshape2_shape[i] == unk_dim_idx) && (i != 0)) {
          // there is no sufficient info
          if (!all_positive) return;
          // Same -1 cancellation trick as above.
          reshape2_shape[i] =
              std::accumulate(x_shape2.begin(), x_shape2.end(),
                              static_cast<int64_t>(1),
                              std::multiplies<int64_t>()) /
              std::accumulate(reshape2_shape.begin(), reshape2_shape.end(),
                              static_cast<int64_t>(-1),
                              std::multiplies<int64_t>());
          break;
        }
      }

      // shuffle_channel doesn't change shape: output must equal the original
      // input shape (batch dim may legitimately remain -1).
      if ((reshape2_shape[0] != -1) && (x_shape1[0] != reshape2_shape[0])) {
        return;
      }
      for (size_t i = 1; i < x_shape1.size(); i++) {
        if (x_shape1[i] != reshape2_shape[i]) {
          return;
        }
      }
      // H and W of the output must match H and W of the 5-D intermediate.
      if ((reshape2_shape[3] != reshape1_shape[4]) ||
          (reshape2_shape[2] != reshape1_shape[3])) {
        return;
      }
    } else {
      return;  // conservative judgement
    }

    // reshape1 target is [N, group, C/group, H, W]; recover `group` from the
    // ratio of the full channel count to the per-group channel count.
    int i_c = reshape1_shape[2];
    int o_c = reshape2_shape[1];
    int group = o_c / i_c;
    // should split on channel dim
    if (reshape2_shape[1] != reshape1_shape[2] * reshape1_shape[1]) return;
    // trans on channel dim
    if (trans_axis[0] != 0 || trans_axis[3] != 3 || trans_axis[4] != 4) return;
    // NOTE(review): with `&&` the match is rejected only when BOTH axes are
    // wrong; `||` looks intended (require axis == {0,2,1,3,4}) — confirm
    // against the non-MKLDNN shuffle_channel_detect_pass.
    if (group != 1 && i_c != 1) {
      if (trans_axis[1] != 2 && trans_axis[2] != 1) {
        return;
      }
    }

    // Build the fused op: same input/output vars, group recovered above.
    framework::OpDesc new_op_desc;
    new_op_desc.SetType("shuffle_channel");
    new_op_desc.SetInput("X", {input_name});
    new_op_desc.SetOutput("Out", {output_name});
    new_op_desc.SetAttr("group", group);
    new_op_desc.SetAttr("use_mkldnn", true);
    new_op_desc.Flush();

    // Create a new node for the fused op.
    auto* new_op = graph->CreateOpNode(&new_op_desc);

    IR_NODE_LINK_TO(input_node, new_op);
    IR_NODE_LINK_TO(new_op, reshape2_out);

    // Delete the unneeded nodes.
    GraphSafeRemoveNodes(graph, {reshape1_op, reshape1_out, transpose_op,
                                 transpose_out, reshape2_op});
    LOG_FIRST_N(WARNING, 1)
        << "There is fluid.layers.shuffle_channel API already, maybe you can "
           "use it instead of (reshape + transpose + reshape)";
  };
  gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register the pass under its lookup name and pin the op versions it was
// written against (a reshape2/transpose2 version bump invalidates the pass).
REGISTER_PASS(shuffle_channel_mkldnn_detect_pass,
              paddle::framework::ir::ShuffleChannelMKLDNNDetectPass);
REGISTER_PASS_CAPABILITY(shuffle_channel_mkldnn_detect_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("reshape2", 0)
            .EQ("transpose2", 0));
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
// Detects the (reshape2 -> transpose2 -> reshape2) channel-shuffle subgraph
// and fuses it into a single shuffle_channel op with use_mkldnn=true.
class ShuffleChannelMKLDNNDetectPass : public FusePassBase {
 public:
  // Registers op-compat constraints for reshape2/transpose2.
  ShuffleChannelMKLDNNDetectPass();
  virtual ~ShuffleChannelMKLDNNDetectPass() {}

 protected:
  // Runs the detection/fusion over the whole graph.
  void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -298,6 +298,7 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -298,6 +298,7 @@ void CpuPassStrategy::EnableMKLDNN() {
// "fc_act_mkldnn_fuse_pass", // "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass", // "batch_norm_act_fuse_pass", //
"softplus_activation_mkldnn_fuse_pass", // "softplus_activation_mkldnn_fuse_pass", //
"shuffle_channel_mkldnn_detect_pass", //
"elt_act_mkldnn_fuse_pass", // "elt_act_mkldnn_fuse_pass", //
// TODO(intel): Please fix the bug on windows. // TODO(intel): Please fix the bug on windows.
// https://github.com/PaddlePaddle/Paddle/issues/29710 // https://github.com/PaddlePaddle/Paddle/issues/29710
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using framework::Tensor;
using platform::MKLDNNGetDataType;
// Handler that builds a oneDNN shuffle_forward primitive descriptor for the
// shuffle_channel op. Uses the non-caching base, so descriptors are created
// fresh for every kernel invocation.
template <typename T>
class ShuffleChannelMKLDNNHandler
    : public platform::MKLDNNHandlerNoCachingT<T, dnnl::shuffle_forward> {
 public:
  // `group` is already the oneDNN-style group size (C / paddle_group); the
  // calling kernel performs that conversion.
  ShuffleChannelMKLDNNHandler(const Tensor* x, const int group,
                              const dnnl::engine engine,
                              platform::Place cpu_place)
      : platform::MKLDNNHandlerNoCachingT<T, dnnl::shuffle_forward>(engine,
                                                                    cpu_place) {
    // Shuffling happens along the channel axis of an NCHW tensor.
    static constexpr int channel_axis = 1;
    const auto x_mem_desc = dnnl::memory::desc(
        phi::vectorize(x->dims()), MKLDNNGetDataType<T>(), x->format());
    this->AcquireForwardPrimitiveDescriptor(dnnl::prop_kind::forward_training,
                                            x_mem_desc, channel_axis, group);
  }
};
// Forward-only shuffle_channel kernel backed by oneDNN's shuffle primitive.
// Registered below for fp32 and bf16 on CPUPlace.
template <typename T>
class ShuffleChannelMKLDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto& dev_ctx =
        ctx.template device_context<platform::MKLDNNDeviceContext>();
    const auto& engine = dev_ctx.GetEngine();

    const auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");

    // oneDNN handles group using C/g instead of g
    const int group = x->dims()[1] / ctx.Attr<int>("group");

    ShuffleChannelMKLDNNHandler<T> handler(x, group, engine, ctx.GetPlace());

    auto src_mem_p = handler.AcquireSrcMemory(x);
    auto dst_mem_p = handler.AcquireDstMemory(out);
    auto shuffle_prim = handler.AcquireForwardPrimitive();

    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    shuffle_prim->execute(
        astream, {{DNNL_ARG_SRC, *src_mem_p}, {DNNL_ARG_DST, *dst_mem_p}});
    astream.wait();

    // Output keeps the oneDNN layout and the input's memory format.
    out->set_layout(framework::DataLayout::kMKLDNN);
    out->set_format(x->format());
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;

// Register the oneDNN shuffle_channel kernel for fp32 and bf16 on CPU.
REGISTER_OP_KERNEL(shuffle_channel, MKLDNN, paddle::platform::CPUPlace,
                   ops::ShuffleChannelMKLDNNKernel<float>,
                   ops::ShuffleChannelMKLDNNKernel<paddle::platform::bfloat16>);
...@@ -35,9 +35,17 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { ...@@ -35,9 +35,17 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "X"), framework::OperatorWithKernel::IndicateVarDataType(ctx, "X");
ctx.device_context());
#ifdef PADDLE_WITH_MKLDNN
if (this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
...@@ -56,6 +64,10 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -56,6 +64,10 @@ class ShuffleChannelOpMaker : public framework::OpProtoAndCheckerMaker {
PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument( PADDLE_ENFORCE_GE(group, 1, platform::errors::InvalidArgument(
"group should be larger than 0.")); "group should be larger than 0."));
}); });
AddAttr<bool>("use_mkldnn",
"(bool, default false) Only used in mkldnn kernel")
.SetDefault(false)
.AsExtra();
AddComment(R"DOC( AddComment(R"DOC(
Shuffle Channel operator Shuffle Channel operator
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import MkldnnAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
from hypothesis import given
import hypothesis.strategies as st
class TestMKLDNNShuffleChannelOp(MkldnnAutoScanTest):
    """Auto-scan test driving shuffle_channel through the oneDNN predictor
    with hypothesis-sampled groups and input shapes."""

    def is_program_valid(self, program_config: ProgramConfig) -> bool:
        # Every sampled (group, in_shape) combination is runnable.
        return True

    def sample_program_configs(self, *args, **kwargs):
        def make_input(*args, **kwargs):
            # Fresh fp32 data of the sampled shape.
            return np.random.random(kwargs['in_shape']).astype(np.float32)

        op_cfg = OpConfig(
            type="shuffle_channel",
            inputs={"X": ["input_data"]},
            outputs={"Out": ["output_data"]},
            attrs={"group": kwargs['group']})

        yield ProgramConfig(
            ops=[op_cfg],
            weights={},
            inputs={
                "input_data": TensorConfig(
                    data_gen=partial(make_input, *args, **kwargs)),
            },
            outputs=["output_data"])

    def sample_predictor_configs(self, program_config):
        # (atol, rtol) tolerances for comparing against the reference run.
        yield self.create_inference_config(use_mkldnn=True), (1e-5, 1e-5)

    @given(
        group=st.sampled_from([1, 2, 8, 32, 128]),
        in_shape=st.sampled_from([[5, 512, 2, 3], [2, 256, 5, 4]]))
    def test(self, *args, **kwargs):
        self.run_test(quant=False, *args, **kwargs)
if __name__ == "__main__":
    # Run the hypothesis-driven suite directly.
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
@OpTestTool.skip_if_not_cpu_bf16()
class TestShuffleChannelOneDNNOp(OpTest):
    """FP32 forward test for the oneDNN shuffle_channel kernel.

    The reference output is computed with numpy as
    reshape(N, g, C/g, H, W) -> transpose(0, 2, 1, 3, 4) -> reshape(N, C, H, W).
    """

    def setUp(self):
        self.op_type = "shuffle_channel"
        self.set_dtype()
        self.set_group()
        self.inputs = {'X': np.random.random((5, 64, 2, 3)).astype(self.dtype)}
        self.attrs = {'use_mkldnn': True, 'group': self.group}

        x = self.inputs['X']
        _, c, h, w = x.shape
        grouped = x.reshape(-1, self.group, c // self.group, h, w)
        shuffled = grouped.transpose(0, 2, 1, 3, 4)
        self.outputs = {'Out': shuffled.reshape(-1, c, h, w)}

    def set_dtype(self):
        # Overridden by the BF16 subclass.
        self.dtype = np.float32

    def set_group(self):
        # Overridden by the single-group subclass.
        self.group = 4

    def test_check_output(self):
        self.check_output_with_place(core.CPUPlace())
class TestShuffleChannelSingleGroupOneDNNOp(TestShuffleChannelOneDNNOp):
    # group == 1 degenerates to an identity shuffle; covers the edge case.
    def set_group(self):
        self.group = 1
class TestShuffleChannelBF16OneDNNOp(TestShuffleChannelOneDNNOp):
    # uint16 is presumably the container dtype OpTest uses for bfloat16
    # tensors (the class is gated by skip_if_not_cpu_bf16 above) — confirm
    # against OpTest's bf16 support.
    def set_dtype(self):
        self.dtype = np.uint16
if __name__ == "__main__":
    # OpTest runs under the static graph mode enabled here.
    paddle.enable_static()
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册