未验证 提交 0695fb88 编写于 作者: Z zhupengyang 提交者: GitHub

delete useless cast, elementwise_mul (#52831)

上级 f4ae3737
......@@ -127,6 +127,7 @@ pass_library(gpu_cpu_map_matmul_to_mul_pass inference)
pass_library(dense_fc_to_sparse_pass inference)
pass_library(dense_multihead_matmul_to_sparse_pass inference)
pass_library(delete_cast_op_pass inference)
pass_library(delete_elementwise_mul_op_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
......
......@@ -505,6 +505,122 @@ int DeleteCastOpPass::ApplyCastIndexSamplePass(ir::Graph* graph) const {
return found_subgraph_count;
}
namespace patterns {
// Pattern nodes for a scatter op whose "X" and "Updates" inputs are each
// produced by a cast op and whose output feeds a third cast op.
struct CastScatterPattern : public PatternBase {
  CastScatterPattern(PDPattern* pattern, const std::string& name_scope);

  // Operator nodes, in data-flow order.
  PATTERN_DECL_NODE(cast0);
  PATTERN_DECL_NODE(cast1);
  PATTERN_DECL_NODE(scatter);
  PATTERN_DECL_NODE(cast2);

  // Variable nodes around cast0 (the scatter "X" branch).
  PATTERN_DECL_NODE(cast0_in);
  PATTERN_DECL_NODE(cast0_out);

  // Variable nodes around cast1 (the scatter "Updates" branch).
  PATTERN_DECL_NODE(cast1_in);
  PATTERN_DECL_NODE(cast1_out);

  // Variable nodes after scatter.
  PATTERN_DECL_NODE(scatter_out);
  PATTERN_DECL_NODE(cast2_out);
};
// Builds the pattern:
//   cast0 (fp16->fp32) produces scatter's "X",
//   cast1 (fp16->fp32) produces scatter's "Updates",
//   cast2 (fp32->fp16) consumes scatter's "Out".
// Every intermediate variable is required to have exactly one consumer.
CastScatterPattern::CastScatterPattern(PDPattern* pattern,
                                       const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // Returns a predicate that accepts an op node whose "in_dtype" and
  // "out_dtype" attributes equal the given dtype codes.
  auto converts = [](int src_dtype, int dst_dtype) {
    return [src_dtype, dst_dtype](Node* node) {
      auto* desc = node->Op();
      const int in_dtype = desc->GetAttrIfExists<int>("in_dtype");
      const int out_dtype = desc->GetAttrIfExists<int>("out_dtype");
      return in_dtype == src_dtype && out_dtype == dst_dtype;
    };
  };
  const int fp16 = static_cast<int>(proto::VarType::FP16);
  const int fp32 = static_cast<int>(proto::VarType::FP32);

  // Branch feeding scatter's "X".
  auto* x_src = pattern->NewNode(cast0_in_repr())
                    ->assert_is_op_input("cast", "X")
                    ->assert_has_n_outputs(1);
  auto* x_cast = pattern->NewNode(cast0_repr())
                     ->assert_is_op("cast")
                     ->assert_more(converts(fp16, fp32));
  auto* x_casted = pattern->NewNode(cast0_out_repr())
                       ->assert_is_op_output("cast", "Out")
                       ->assert_is_op_input("scatter", "X")
                       ->assert_has_n_outputs(1);

  // Branch feeding scatter's "Updates".
  auto* updates_src = pattern->NewNode(cast1_in_repr())
                          ->assert_is_op_input("cast", "X")
                          ->assert_has_n_outputs(1);
  auto* updates_cast = pattern->NewNode(cast1_repr())
                           ->assert_is_op("cast")
                           ->assert_more(converts(fp16, fp32));
  auto* updates_casted = pattern->NewNode(cast1_out_repr())
                             ->assert_is_op_output("cast", "Out")
                             ->assert_is_op_input("scatter", "Updates")
                             ->assert_has_n_outputs(1);

  // The scatter op and the cast that converts its result back.
  auto* scatter_op = pattern->NewNode(scatter_repr())->assert_is_op("scatter");
  auto* scattered = pattern->NewNode(scatter_out_repr())
                        ->assert_is_op_output("scatter", "Out")
                        ->assert_is_op_input("cast", "X")
                        ->assert_has_n_outputs(1);
  auto* back_cast = pattern->NewNode(cast2_repr())
                        ->assert_is_op("cast")
                        ->assert_more(converts(fp32, fp16));
  auto* result =
      pattern->NewNode(cast2_out_repr())->assert_is_op_output("cast", "Out");

  // Wire the nodes together.
  x_cast->LinksFrom({x_src}).LinksTo({x_casted});
  updates_cast->LinksFrom({updates_src}).LinksTo({updates_casted});
  scatter_op->LinksFrom({x_casted, updates_casted}).LinksTo({scattered});
  back_cast->LinksFrom({scattered}).LinksTo({result});
}
} // namespace patterns
// Rewrites every match of patterns::CastScatterPattern in `graph` so that the
// scatter op reads the inputs of the two leading casts and writes the output
// of the trailing cast; the three cast ops and the intermediate variables are
// removed. Returns the number of matches that were rewritten.
int DeleteCastOpPass::ApplyCastScatterPass(ir::Graph* graph) const {
  GraphPatternDetector detector;
  patterns::CastScatterPattern pattern(detector.mutable_pattern(), name_scope_);
  int found_subgraph_count = 0;

  auto rewrite = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* matched_graph) {
    VLOG(4) << "handle ApplyCastScatterPass fuse";

    // Operator nodes of the match.
    GET_IR_NODE_FROM_SUBGRAPH(scatter_op, scatter, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_cast, cast0, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(updates_cast, cast1, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(back_cast, cast2, pattern);
    // Variable nodes of the match.
    GET_IR_NODE_FROM_SUBGRAPH(x_src, cast0_in, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x_casted, cast0_out, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(updates_src, cast1_in, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(updates_casted, cast1_out, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scattered, scatter_out, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(result, cast2_out, pattern);

    // Point the scatter op desc at the variables outside the casts.
    auto* scatter_desc = scatter_op->Op();
    scatter_desc->RenameInput(x_casted->Name(), x_src->Name());
    scatter_desc->RenameInput(updates_casted->Name(), updates_src->Name());
    scatter_desc->RenameOutput(scattered->Name(), result->Name());

    // Mirror the renames in the graph edges.
    IR_NODE_LINK_TO(x_src, scatter_op);
    IR_NODE_LINK_TO(updates_src, scatter_op);
    IR_NODE_LINK_TO(scatter_op, result);

    // Drop the casts together with the variables that only they touched.
    std::unordered_set<const Node*> stale_nodes{
        x_cast, updates_cast, back_cast, x_casted, updates_casted, scattered};
    GraphSafeRemoveNodes(matched_graph, stale_nodes);
    found_subgraph_count++;
  };

  detector(graph, rewrite);
  return found_subgraph_count;
}
namespace patterns {
struct CastPattern : public PatternBase {
CastPattern(PDPattern* pattern, const std::string& name_scope);
......@@ -591,6 +707,15 @@ void DeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
<< " cast_index_sample_cast subgraph";
}
found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastScatterPass(graph->GetSubGraph(i));
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " cast_scatter_cast subgraph";
}
found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastPass(graph->GetSubGraph(i));
......
......@@ -111,7 +111,20 @@ class DeleteCastOpPass : public FusePassBase {
*/
int ApplyCastIndexSamplePass(ir::Graph* graph) const;
// Delete cast if its "in_dtype" is the same with "out_dtype"
/*
 Remove the casts wrapped around a scatter op: scatter then reads the inputs
 of the two leading casts (its "X" and "Updates") and writes the output of
 the trailing cast directly.

 Origin subgraph:
      cast(fp16->fp32)     cast(fp16->fp32)
              \                 /
                    scatter
                       |
                cast(fp32->fp16)

 Optimized subgraph:
                    scatter

 Returns the number of subgraphs that were rewritten.
*/
int ApplyCastScatterPass(ir::Graph* graph) const;
// Delete cast if its "in_dtype" is the same as "out_dtype".
// NOTE(review): the return value is accumulated by the caller as a count of
// deleted subgraphs; the definition is not visible here - confirm.
int ApplyCastPass(ir::Graph* graph) const;
const std::string name_scope_{"delete_cast_op_pass"};
......
......@@ -226,6 +226,35 @@ TEST(ApplyCastIndexSamplePass, basic) {
cast_num_in_graph));
}
// Builds cast -> scatter <- cast followed by a trailing cast, runs
// delete_cast_op_pass and expects that no cast op is left in the graph.
TEST(ApplyCastScatterPass, basic) {
  paddle::framework::ProgramDesc program;
  auto* block = program.MutableBlock(0);

  // NOTE(review): 4 and 5 are presumably the proto::VarType codes for FP16
  // and FP32, which is what the pass's pattern matches on - confirm.
  auto* x_src = Data(block, "cast0_in", {1});
  auto* x_casted = AddCast(block, x_src, 4, 5);
  auto* updates_src = Data(block, "cast1_in", {1});
  auto* updates_casted = AddCast(block, updates_src, 4, 5);

  auto* scattered = Data(block, "scatter_out", {1});
  OpDesc* scatter_op = block->AppendOp();
  scatter_op->SetType("scatter");
  scatter_op->SetInput("X", {x_casted->Name()});
  scatter_op->SetInput("Updates", {updates_casted->Name()});
  scatter_op->SetOutput("Out", {scattered->Name()});
  AddCast(block, scattered, 5, 4);

  std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
  graph->Set("__param_scope__", new Scope());

  auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
  pass->Apply(graph.get());

  const int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
  PADDLE_ENFORCE_EQ(cast_num_in_graph,
                    0,
                    platform::errors::PreconditionNotMet(
                        "graph should have 0 cast after delete_cast_op_pass, "
                        "but actually has %d.",
                        cast_num_in_graph));
}
TEST(ApplyCastPass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
// Pattern nodes for an elementwise_mul op that has a
// fill_constant_batch_size_like output as one of its inputs.
struct FillMulPattern : public PatternBase {
  FillMulPattern(PDPattern* pattern, const std::string& name_scope);

  // Operator nodes.
  PATTERN_DECL_NODE(fill);
  PATTERN_DECL_NODE(mul);

  // Variable nodes: the two multiplication inputs, then the product.
  PATTERN_DECL_NODE(fill_out);
  PATTERN_DECL_NODE(mul_in);
  PATTERN_DECL_NODE(mul_out);
};
// Builds the pattern: a fill_constant_batch_size_like op whose "value"
// attribute is 1 (within 1e-5) produces a variable with a single consumer,
// and that variable plus one otherwise unconstrained variable feed an
// elementwise_mul op.
FillMulPattern::FillMulPattern(PDPattern* pattern,
                               const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // Accepts an op node whose "value" attribute is numerically one.
  auto fills_with_one = [](Node* node) {
    float value = node->Op()->GetAttrIfExists<float>("value");
    return fabs(value - 1.f) < 1e-5;
  };

  auto* fill_op = pattern->NewNode(fill_repr())
                      ->assert_is_op("fill_constant_batch_size_like")
                      ->assert_more(fills_with_one);
  auto* ones = pattern->NewNode(fill_out_repr())
                   ->assert_is_op_output("fill_constant_batch_size_like", "Out")
                   ->assert_has_n_outputs(1);
  // The other multiplication input carries no constraints.
  auto* other_input = pattern->NewNode(mul_in_repr());
  auto* mul_op = pattern->NewNode(mul_repr())->assert_is_op("elementwise_mul");
  auto* product = pattern->NewNode(mul_out_repr())
                      ->assert_is_op_output("elementwise_mul", "Out");

  fill_op->LinksTo({ones});
  mul_op->LinksFrom({ones, other_input}).LinksTo({product});
}
} // namespace patterns
/*
Delete an "elementwise_mul" op when one of its inputs is the output of a
"fill_constant_batch_size_like" op whose fill value is 1. Consumers of the
product are rewired to read the other input directly.
*/
class DeleteElementwiseMulOpPass : public FusePassBase {
 private:
  // Scope name handed to Init() and to the pattern.
  const std::string name_scope_ = "delete_elementwise_mul_op_pass";

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};
// Finds every match of patterns::FillMulPattern in `graph`, redirects each
// consumer of the product to the non-fill input, removes the fill op, the mul
// op and their output variables, and records the number of matches handled.
void DeleteElementwiseMulOpPass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);

  GraphPatternDetector detector;
  patterns::FillMulPattern pattern(detector.mutable_pattern(), name_scope_);
  int found_subgraph_count = 0;

  auto rewrite = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* matched_graph) {
    VLOG(4) << "handle DeleteElementwiseMulOpPass fuse";
    GET_IR_NODE_FROM_SUBGRAPH(fill_op, fill, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul_op, mul, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(ones, fill_out, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(kept_input, mul_in, pattern);
    GET_IR_NODE_FROM_SUBGRAPH(product, mul_out, pattern);

    // Every op that read the product now reads the surviving input instead.
    const std::string product_name = product->Name();
    const std::string kept_name = kept_input->Name();
    for (auto* consumer : product->outputs) {
      consumer->Op()->RenameInput(product_name, kept_name);
      IR_NODE_LINK_TO(kept_input, consumer);
    }

    std::unordered_set<const Node*> stale_nodes{fill_op, mul_op, ones, product};
    GraphSafeRemoveNodes(matched_graph, stale_nodes);
    found_subgraph_count++;
  };

  detector(graph, rewrite);
  AddStatis(found_subgraph_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
// Register the pass under the name "delete_elementwise_mul_op_pass".
REGISTER_PASS(delete_elementwise_mul_op_pass,
paddle::framework::ir::DeleteElementwiseMulOpPass);
// Declare the op-version combination tied to this pass: version 0 (EQ) of
// "fill_constant_batch_size_like".
// NOTE(review): presumably the pass is treated as incompatible with models
// that use another version of that op - confirm against the op version
// registry.
REGISTER_PASS_CAPABILITY(delete_elementwise_mul_op_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"fill_constant_batch_size_like", 0));
......@@ -57,6 +57,7 @@ static const std::vector<std::string> xpu_support_subgraph_passes = {
"identity_scale_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
"delete_elementwise_mul_op_pass",
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
......
......@@ -524,6 +524,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"identity_scale_op_clean_pass",
"delete_op_device_pass",
"constant_folding_pass",
"delete_elementwise_mul_op_pass",
"generate_sequence_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass",
"multi_encoder_xpu_fuse_pass",
......
......@@ -27,9 +27,12 @@ void ScatterKernel(const Context &ctx,
const DenseTensor &updates,
bool overwrite,
DenseTensor *out) {
using XPUTypeT = typename XPUTypeTrait<T>::Type;
out->Resize(x.dims());
ctx.template Alloc<T>(out);
int ret = xpu::copy(ctx.x_context(), x.data<T>(), out->data<T>(), x.numel());
auto *x_data = reinterpret_cast<const XPUTypeT *>(x.data<T>());
auto *updates_data = reinterpret_cast<const XPUTypeT *>(updates.data<T>());
auto *out_data = reinterpret_cast<XPUTypeT *>(ctx.template Alloc<T>(out));
int ret = xpu::copy(ctx.x_context(), x_data, out_data, x.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(ret, "copy");
// Apply ScatterUpdate: Out[index] = Updates[:]
const auto &index_type = index.dtype();
......@@ -78,8 +81,6 @@ void ScatterKernel(const Context &ctx,
int dim0 = static_cast<int>(x.dims()[0]);
int dim1 =
static_cast<int>(phi::product(phi::slice_ddim(x_dims, 1, x_dims.size())));
T *out_data = out->data<T>();
const T *updates_data = updates.data<T>();
DenseTensor indices_cpu(index.type());
phi::Copy(ctx, index, phi::CPUPlace(), false, &indices_cpu);
......@@ -113,5 +114,11 @@ void ScatterKernel(const Context &ctx,
} // namespace phi
PD_REGISTER_KERNEL(
scatter, XPU, ALL_LAYOUT, phi::ScatterKernel, float, int, int64_t) {}
PD_REGISTER_KERNEL(scatter,
XPU,
ALL_LAYOUT,
phi::ScatterKernel,
float,
int,
int64_t,
phi::dtype::float16) {}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestDeleteElementwiseMulOpPass(PassAutoScanTest):
    """Auto-scan test for ``delete_elementwise_mul_op_pass``.

    The sampled program is
    ``fill_constant_batch_size_like(value=1) -> elementwise_mul -> relu``.
    """

    def sample_predictor_configs(self, program_config):
        """Yield one XPU inference config with ``["relu"]`` and tolerances.

        NOTE(review): ``["relu"]`` is presumably the list of op types expected
        to remain after the pass, and the tuple the (atol, rtol) pair - confirm
        against ``PassAutoScanTest``.
        """
        xpu_config = self.create_inference_config(use_xpu=True)
        yield xpu_config, ["relu"], (1e-5, 1e-5)

    def sample_program_config(self, draw):
        """Draw a 2-D shape and build the fill -> mul -> relu program."""
        dim_strategy = st.integers(min_value=1, max_value=20)
        x_shape = draw(st.lists(dim_strategy, min_size=2, max_size=2))

        # Produces a tensor of ones; its attributes are kept in this order.
        fill_op = OpConfig(
            "fill_constant_batch_size_like",
            inputs={"Input": ["fill_x"]},
            outputs={"Out": ["fill_out"]},
            shape=[-1, 1],
            input_dim_idx=0,
            output_dim_idx=0,
            dtype=5,
            value=1.0,
            str_value="1",
            force_cpu=False,
        )
        # Multiplies the ones by the real input.
        mul_op = OpConfig(
            "elementwise_mul",
            inputs={"X": ["fill_out"], "Y": ["mul_in"]},
            outputs={"Out": ["mul_out"]},
            axis=0,
        )
        # Consumer of the product.
        relu_op = OpConfig(
            "relu",
            inputs={"X": ["mul_out"]},
            outputs={"Out": ["relu_out"]},
        )

        feeds = {
            "fill_x": TensorConfig(shape=x_shape),
            "mul_in": TensorConfig(shape=x_shape),
        }
        return ProgramConfig(
            ops=[fill_op, mul_op, relu_op],
            weights={},
            inputs=feeds,
            outputs=relu_op.outputs["Out"],
        )

    def test(self):
        """Run the auto-scan harness on 25 sampled programs, no quantization."""
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["delete_elementwise_mul_op_pass"],
        )
# Run the test case when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册