Unverified commit 7aafeb45 authored by zhupengyang, committed by GitHub

delete cast if lookup_table_v2 support fp16; delete repeated ops (#52888)

Parent 64b4aaba
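
The first rewrite in this commit removes a cast(fp32->fp16) that immediately follows lookup_table_v2 by casting the embedding weight to fp16 once at pass time. A minimal NumPy sketch of the numerical equivalence the rewrite relies on (names and shapes are illustrative, not taken from the patch):

import numpy as np

table_fp32 = np.random.rand(100, 8).astype(np.float32)  # embedding weight W
ids = np.array([3, 7, 42])

# Original graph: look up fp32 rows, then cast the output to fp16.
out_before = table_fp32[ids].astype(np.float16)

# Rewritten graph: cast W to fp16 once, look up directly in fp16.
out_after = table_fp32.astype(np.float16)[ids]

assert np.array_equal(out_before, out_after)
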
......@@ -128,6 +128,7 @@ pass_library(dense_fc_to_sparse_pass inference)
pass_library(dense_multihead_matmul_to_sparse_pass inference)
pass_library(delete_cast_op_pass inference)
pass_library(delete_elementwise_mul_op_pass inference)
pass_library(delete_repeated_ops_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
......
......@@ -19,6 +19,8 @@
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/kernels/assign_kernel.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
class DenseTensor;
......@@ -623,6 +625,93 @@ int DeleteCastOpPass::ApplyCastScatterPass(ir::Graph* graph) const {
return found_subgraph_count;
}
namespace patterns {
struct CastLookupTablePattern : public PatternBase {
CastLookupTablePattern(PDPattern* pattern, const std::string& name_scope);
// declare operator node's name
PATTERN_DECL_NODE(lookup_table);
PATTERN_DECL_NODE(cast);
// declare variable node's name
PATTERN_DECL_NODE(lookup_table_w);
PATTERN_DECL_NODE(lookup_table_out);
PATTERN_DECL_NODE(cast_out);
};
CastLookupTablePattern::CastLookupTablePattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope, name_scope) {
auto* lookup_table_w = pattern->NewNode(lookup_table_w_repr())
->assert_is_op_input("lookup_table_v2", "W")
->assert_is_persistable_var();
auto* lookup_table =
pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table_v2");
auto* lookup_table_out = pattern->NewNode(lookup_table_out_repr())
->assert_is_op_output("lookup_table_v2", "Out")
->assert_is_op_input("cast", "X")
->assert_has_n_outputs(1);
auto* cast =
pattern->NewNode(cast_repr())
->assert_is_op("cast")
->assert_more([](Node* node) {
auto* op_desc = node->Op();
auto in_dtype = op_desc->GetAttrIfExists<int>("in_dtype");
auto out_dtype = op_desc->GetAttrIfExists<int>("out_dtype");
return in_dtype == static_cast<int>(proto::VarType::FP32) &&
out_dtype == static_cast<int>(proto::VarType::FP16);
});
auto* cast_out =
pattern->NewNode(cast_out_repr())->assert_is_op_output("cast", "Out");
lookup_table->LinksFrom({lookup_table_w}).LinksTo({lookup_table_out});
cast->LinksFrom({lookup_table_out}).LinksTo({cast_out});
}
} // namespace patterns
int DeleteCastOpPass::ApplyCastLookupTablePass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::CastLookupTablePattern pattern(gpd.mutable_pattern(), name_scope_);
int found_subgraph_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle ApplyCastLookupTablePass fuse";
GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast, cast, pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table_w, lookup_table_w, pattern);
GET_IR_NODE_FROM_SUBGRAPH(lookup_table_out, lookup_table_out, pattern);
GET_IR_NODE_FROM_SUBGRAPH(cast_out, cast_out, pattern);
auto* scope = param_scope();
auto* w_tensor =
scope->Var(lookup_table_w->Name())->GetMutable<phi::DenseTensor>();
lookup_table_w->Var()->SetDataType(proto::VarType::FP16);
// Cast the fp32 weight to fp16 in place: keep an fp32 copy, retype the
// original tensor, then cast the copy back into it.
if (w_tensor->dtype() != phi::DataType::FLOAT16) {
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
phi::DenseTensor w_fp32_tensor;
w_fp32_tensor.Resize(w_tensor->dims());
w_fp32_tensor.set_type(w_tensor->dtype());
phi::AssignKernel(*cpu_ctx, *w_tensor, &w_fp32_tensor);
w_tensor->set_type(phi::DataType::FLOAT16);
phi::CastKernel<float>(
*cpu_ctx, w_fp32_tensor, phi::DataType::FLOAT16, w_tensor);
}
for (auto* next_op : cast_out->outputs) {
next_op->Op()->RenameInput(cast_out->Name(), lookup_table_out->Name());
IR_NODE_LINK_TO(lookup_table_out, next_op);
}
std::unordered_set<const Node*> delete_nodes{cast, cast_out};
GraphSafeRemoveNodes(graph, delete_nodes);
found_subgraph_count++;
};
gpd(graph, handler);
return found_subgraph_count;
}
namespace patterns {
struct CastPattern : public PatternBase {
CastPattern(PDPattern* pattern, const std::string& name_scope);
......@@ -718,6 +807,15 @@ void DeleteCastOpPass::ApplyImpl(ir::Graph* graph) const {
<< " cast_scatter_cast subgraph";
}
found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastLookupTablePass(graph->GetSubGraph(i));
}
if (found_subgraph_count > 0) {
LOG(INFO) << "--- delete " << found_subgraph_count
<< " lookup_table_cast subgraph";
}
found_subgraph_count = 0;
for (size_t i = 0; i < graph->SubGraphsSize(); i++) {
found_subgraph_count += ApplyCastPass(graph->GetSubGraph(i));
......
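
In the ApplyCastLookupTablePass handler above, once the weight has been converted, every consumer of cast_out is repointed at lookup_table_out before the cast op and its output are removed. A rough Python sketch of that rewiring step on a toy op table (the dict representation is illustrative, not Paddle's IR):

ops = {
    "relu_0": {"inputs": ["cast_out"]},
    "matmul_0": {"inputs": ["cast_out", "y"]},
}

def rename_input(op, old_var, new_var):
    # Counterpart of next_op->Op()->RenameInput(cast_out, lookup_table_out).
    op["inputs"] = [new_var if v == old_var else v for v in op["inputs"]]

for op in ops.values():
    rename_input(op, "cast_out", "lookup_table_out")

assert all("cast_out" not in op["inputs"] for op in ops.values())
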
......@@ -124,6 +124,21 @@ class DeleteCastOpPass : public FusePassBase {
*/
int ApplyCastScatterPass(ir::Graph* graph) const;
/*
Origin subgraph:
ids w(fp32)
\ /
lookup_table
|
cast(fp32->fp16)
Optimized subgraph:
ids w(fp16)
\ /
lookup_table
*/
int ApplyCastLookupTablePass(ir::Graph* graph) const;
// Delete cast if its "in_dtype" is the same as "out_dtype"
int ApplyCastPass(ir::Graph* graph) const;
......
......@@ -20,6 +20,16 @@ namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* cpu_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(phi::CPUPlace()));
cpu_ctx->Alloc<float>(tensor);
}
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
......@@ -255,6 +265,36 @@ TEST(ApplyCastScatterPass, basic) {
cast_num_in_graph));
}
TEST(ApplyCastLookupTablePass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* lookup_table_w = Data(block, "lookup_table_w", {1}, true);
auto* lookup_table_out = Data(block, "scatter_out", {1});
OpDesc* lookup_table = block->AppendOp();
lookup_table->SetType("lookup_table_v2");
lookup_table->SetInput("W", {lookup_table_w->Name()});
lookup_table->SetOutput("Out", {lookup_table_out->Name()});
auto* cast_out = AddCast(block, lookup_table_out, 5, 4);
OpDesc* relu = block->AppendOp();
relu->SetType("relu");
relu->SetInput("X", {cast_out->Name()});
relu->SetOutput("Out", {"relu_out"});
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
AddVarToScope(scope, lookup_table_w->Name(), {1});
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
PADDLE_ENFORCE_EQ(GetOpNum(graph->GetSubGraph(0), "cast"),
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
TEST(ApplyCastPass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
namespace patterns {
struct VarWithRepeatedOpsPattern : public PatternBase {
VarWithRepeatedOpsPattern(PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type);
// declare variable node's name
PATTERN_DECL_NODE(in_var);
std::string op_type_;
};
VarWithRepeatedOpsPattern::VarWithRepeatedOpsPattern(
PDPattern* pattern,
const std::string& name_scope,
const std::string& op_type)
: PatternBase(pattern, name_scope, name_scope), op_type_(op_type) {
pattern->NewNode(in_var_repr())
->assert_is_var()
->assert_more([&](Node* node) {
auto out_nodes = node->outputs;
if (out_nodes.size() <= 1) return false;
int op_counts = 0;
for (auto* next_op : out_nodes) {
if (next_op->Name() == op_type_) {
op_counts++;
}
}
return op_counts > 1;
});
}
} // namespace patterns
/*
Delete repeated ops, for example:
Origin subgraph:
(input_variable)
/ | \ ...
shape shape shape ...
| | | ...
op0 op1 op2 ...
Optimized subgraph:
(input_variable)
|
shape
/ | \ ...
op0 op1 op2 ...
*/
class DeleteRepeatedOpsPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
int DeleteShapePass(ir::Graph* graph) const;
int DeleteSlicePass(ir::Graph* graph) const;
const std::string name_scope_{"delete_repeated_ops_pass"};
};
int DeleteRepeatedOpsPass::DeleteShapePass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "shape");
int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DeleteShapePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::vector<Node*> shapes;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "shape") continue;
// Shape ops whose output feeds a while/conditional_block are not merged.
bool shape_out_has_control_flow_ops = false;
for (auto* shape_out_op : next_op->outputs[0]->outputs) {
if (shape_out_op->Name() == "while" ||
shape_out_op->Name() == "conditional_block") {
shape_out_has_control_flow_ops = true;
break;
}
}
if (!shape_out_has_control_flow_ops) {
shapes.push_back(next_op);
}
}
if (shapes.size() <= 1) return;
auto* first_shape_out = shapes[0]->outputs[0];
auto first_shape_out_name = first_shape_out->Name();
std::unordered_set<const Node*> delete_nodes;
for (size_t i = 1; i < shapes.size(); i++) {
auto* cur_shape = shapes[i];
auto* cur_shape_out = cur_shape->outputs[0];
auto cur_shape_out_name = cur_shape_out->Name();
for (auto* shape_out_op : cur_shape_out->outputs) {
shape_out_op->Op()->Rename(cur_shape_out_name, first_shape_out_name);
IR_NODE_LINK_TO(first_shape_out, shape_out_op);
}
delete_nodes.insert(cur_shape);
delete_nodes.insert(cur_shape_out);
delete_counts++;
}
GraphSafeRemoveNodes(graph, delete_nodes);
};
gpd(graph, handler);
return delete_counts;
}
std::string GenSliceAttrKey(OpDesc* slice_op_desc) {
std::string attr_key;
auto starts = slice_op_desc->GetAttrIfExists<std::vector<int>>("starts");
auto ends = slice_op_desc->GetAttrIfExists<std::vector<int>>("ends");
auto axes = slice_op_desc->GetAttrIfExists<std::vector<int>>("axes");
attr_key += "starts_";
for (auto start : starts) {
attr_key += std::to_string(start) + "_";
}
attr_key += "ends_";
for (auto end : ends) {
attr_key += std::to_string(end) + "_";
}
attr_key += "axes_";
for (auto axis : axes) {
attr_key += std::to_string(axis) + "_";
}
return attr_key;
}
int DeleteRepeatedOpsPass::DeleteSlicePass(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::VarWithRepeatedOpsPattern pattern(
gpd.mutable_pattern(), name_scope_, "slice");
int delete_counts = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* graph) {
VLOG(4) << "handle DeleteSlicePass";
GET_IR_NODE_FROM_SUBGRAPH(in_var, in_var, pattern);
std::map<std::string, std::vector<Node*>> slice_ops;
for (auto* next_op : in_var->outputs) {
if (next_op->Name() != "slice") continue;
auto* slice = next_op;
// As in DeleteShapePass, slices feeding while/conditional_block are skipped.
bool slice_out_has_control_flow_ops = false;
for (auto* slice_out_op : slice->outputs[0]->outputs) {
if (slice_out_op->Name() == "while" ||
slice_out_op->Name() == "conditional_block") {
slice_out_has_control_flow_ops = true;
break;
}
}
if (slice_out_has_control_flow_ops) continue;
auto attr_key = GenSliceAttrKey(slice->Op());
slice_ops[attr_key].push_back(slice);
}
for (auto iter = slice_ops.begin(); iter != slice_ops.end();) {
if (iter->second.size() <= 1) {
iter = slice_ops.erase(iter);
} else {
iter++;
}
}
for (auto iter : slice_ops) {
auto slices = iter.second;
auto* first_slice_out = slices[0]->outputs[0];
auto first_slice_out_name = first_slice_out->Name();
std::unordered_set<const Node*> delete_nodes;
for (size_t i = 1; i < slices.size(); i++) {
auto* cur_slice = slices[i];
auto* cur_slice_out = cur_slice->outputs[0];
auto cur_slice_out_name = cur_slice_out->Name();
for (auto* slice_out_op : cur_slice_out->outputs) {
slice_out_op->Op()->Rename(cur_slice_out_name, first_slice_out_name);
IR_NODE_LINK_TO(first_slice_out, slice_out_op);
}
delete_nodes.insert(cur_slice);
delete_nodes.insert(cur_slice_out);
delete_counts++;
}
GraphSafeRemoveNodes(graph, delete_nodes);
}
};
gpd(graph, handler);
return delete_counts;
}
void DeleteRepeatedOpsPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::PreconditionNotMet("graph should not be null."));
Init(name_scope_, graph);
int delete_counts = DeleteShapePass(graph);
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated shape ops";
}
delete_counts = DeleteSlicePass(graph);
if (delete_counts > 0) {
LOG(INFO) << "--- delete " << delete_counts << " repeated slice ops";
}
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_repeated_ops_pass,
paddle::framework::ir::DeleteRepeatedOpsPass);
REGISTER_PASS_CAPABILITY(delete_repeated_ops_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination().EQ(
"shape", 0));
......@@ -512,6 +512,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"delete_dropout_op_pass",
"delete_concat_op_pass",
"identity_scale_op_clean_pass",
"delete_repeated_ops_pass",
"delete_op_device_pass",
"constant_folding_pass",
"delete_elementwise_mul_op_pass",
......
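
In the slice branch of delete_repeated_ops_pass.cc above, ops are only merged when their starts/ends/axes attributes match; GenSliceAttrKey turns those attributes into a grouping key. A minimal Python sketch of that group-by-signature idea (hypothetical op records, not Paddle's IR):

from collections import defaultdict

slice_ops = [
    {"name": "slice_0", "starts": [0], "ends": [1], "axes": [0]},
    {"name": "slice_1", "starts": [0], "ends": [1], "axes": [0]},
    {"name": "slice_2", "starts": [1], "ends": [2], "axes": [0]},
]

def attr_key(op):
    # Same idea as GenSliceAttrKey: concatenate starts/ends/axes into a string.
    return "starts_{starts}_ends_{ends}_axes_{axes}".format(**op)

groups = defaultdict(list)
for op in slice_ops:
    groups[attr_key(op)].append(op["name"])

# Only groups with more than one op are merge candidates: the first op is
# kept, consumers of the rest are redirected to its output, the rest deleted.
duplicates = {k: v for k, v in groups.items() if len(v) > 1}
print(duplicates)  # {'starts_[0]_ends_[1]_axes_[0]': ['slice_0', 'slice_1']}
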
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import hypothesis.strategies as st
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig
class TestDeleteRepeatedShapePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['shape', 'cast', 'cast', 'cast'], (1e-5, 1e-5)
def sample_program_config(self, draw):
x_shape = draw(
st.lists(
st.integers(min_value=1, max_value=20), min_size=2, max_size=4
)
)
shape_op0 = OpConfig(
"shape",
inputs={
"Input": ["shape_x"],
},
outputs={"Out": ["shape0_out"]},
)
cast_op0 = OpConfig(
"cast",
inputs={
"X": ["shape0_out"],
},
in_dtype=2,
out_dtype=5,
outputs={"Out": ["cast0_out"]},
)
shape_op1 = OpConfig(
"shape",
inputs={
"Input": ["shape_x"],
},
outputs={"Out": ["shape1_out"]},
)
cast_op1 = OpConfig(
"cast",
inputs={
"X": ["shape1_out"],
},
in_dtype=2,
out_dtype=5,
outputs={"Out": ["cast1_out"]},
)
shape_op2 = OpConfig(
"shape",
inputs={
"Input": ["shape_x"],
},
outputs={"Out": ["shape2_out"]},
)
cast_op2 = OpConfig(
"cast",
inputs={
"X": ["shape2_out"],
},
in_dtype=2,
out_dtype=5,
outputs={"Out": ["cast2_out"]},
)
ops = [shape_op0, cast_op0, shape_op1, cast_op1, shape_op2, cast_op2]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"shape_x": TensorConfig(shape=x_shape),
},
outputs=["cast0_out", "cast1_out", "cast2_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["delete_repeated_ops_pass"],
)
class TestDeleteRepeatedSlicePass(PassAutoScanTest):
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_xpu=True)
yield config, ['slice', 'cast', 'cast', 'cast'], (1e-5, 1e-5)
def sample_program_config(self, draw):
slice_x = draw(
st.lists(
st.integers(min_value=1, max_value=20), min_size=2, max_size=4
)
)
slice_op0 = OpConfig(
"slice",
inputs={
"Input": ["slice_x"],
},
starts=[0],
ends=[1],
axes=[0],
decrease_axis=[0],
outputs={"Out": ["slice0_out"]},
)
cast_op0 = OpConfig(
"cast",
inputs={
"X": ["slice0_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast0_out"]},
)
slice_op1 = OpConfig(
"slice",
inputs={
"Input": ["slice_x"],
},
starts=[0],
ends=[1],
axes=[0],
decrease_axis=[0],
outputs={"Out": ["slice1_out"]},
)
cast_op1 = OpConfig(
"cast",
inputs={
"X": ["slice1_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast1_out"]},
)
slice_op2 = OpConfig(
"slice",
inputs={
"Input": ["slice_x"],
},
starts=[0],
ends=[1],
axes=[0],
decrease_axis=[0],
outputs={"Out": ["slice2_out"]},
)
cast_op2 = OpConfig(
"cast",
inputs={
"X": ["slice2_out"],
},
in_dtype=5,
out_dtype=5,
outputs={"Out": ["cast2_out"]},
)
ops = [slice_op0, cast_op0, slice_op1, cast_op1, slice_op2, cast_op2]
program_config = ProgramConfig(
ops=ops,
weights={},
inputs={
"slice_x": TensorConfig(shape=slice_x),
},
outputs=["cast0_out", "cast1_out", "cast2_out"],
)
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=25,
passes=["delete_repeated_ops_pass"],
)
if __name__ == "__main__":
unittest.main()