Unverified commit 8b622d58 authored by zhupengyang, committed by GitHub

[XPU] add delete_cast_op_pass (#52305)

Parent 3e2d0195
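In short, this commit adds a new XPU inference pass, delete_cast_op_pass, that removes redundant cast ops: it hoists casts out of while subgraphs around write_to_array/read_from_array (optionally through lod_reset), drops cast pairs around index_sample, and deletes casts whose in_dtype equals out_dtype. Below is a minimal usage sketch modeled on the unit tests in this commit; the helper function name is hypothetical, and it assumes the pass has been linked in (e.g. via USE_PASS(delete_cast_op_pass)) and that a parameter scope is attached to the graph, as the tests do.

#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

// Hypothetical helper: build an ir::Graph from a ProgramDesc, attach a
// parameter scope (FusePassBase expects "__param_scope__"), and run the pass.
void ApplyDeleteCastOpPass(const paddle::framework::ProgramDesc& program) {
  namespace fw = paddle::framework;
  std::unique_ptr<fw::ir::Graph> graph(new fw::ir::Graph(program));
  // The graph takes ownership of the scope, mirroring the tests below.
  graph->Set("__param_scope__", new fw::Scope());
  auto pass = fw::ir::PassRegistry::Instance().Get("delete_cast_op_pass");
  pass->Apply(graph.get());  // redundant casts are removed in place
}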
@@ -240,6 +240,7 @@ if(WITH_XPU)
pass_library(fused_multi_transformer_xpu_quant_pass inference DIR xpu DEPS
${XPU_PASS_DEPS})
pass_library(stack_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
pass_library(delete_cast_op_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
endif()
cc_library(
@@ -518,4 +519,8 @@ if(WITH_XPU)
test_stack_fuse_pass
SRCS xpu/stack_fuse_pass_test.cc
DEPS stack_fuse_pass)
cc_test(
test_delete_cast_op_pass
SRCS xpu/delete_cast_op_pass_test.cc
DEPS delete_cast_op_pass)
endif()
This diff has been collapsed (likely the new delete_cast_op_pass.cc implementation; it is not shown here).
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace phi {
class DenseTensor;
} // namespace phi
namespace paddle {
namespace framework {
class Scope;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace framework {
namespace ir {
class DeleteCastOpPass : public FusePassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
private:
/*
Origin subgraph:
     main_graph:                  while subgraph:
   write_to_array                 cast(fp16->fp32)
         |                               |
   (write_var:fp32)               write_to_array
                                         |
                                  (write_var:fp32)
                                         |
                                  read_from_array
                                         |
                                  cast(fp32->fp16)

Optimized subgraph:
     main_graph:                  while subgraph:
        cast                      write_to_array
         |                               |
   write_to_array                 (write_var:fp16)
         |                               |
   (write_var:fp16)               read_from_array
*/
int ApplyCastWriteReadPass(ir::Graph* graph) const;
/*
Origin subgraph:
     main_graph:                  while subgraph:
   write_to_array                 cast(fp16->fp32)
         |                               |
   (write_var:fp32)                  lod_reset
         |                               |
       while                      write_to_array
         |                               |
   (write_var:fp32)               (write_var:fp32)
         |                               |
  beam_search_decode              read_from_array
         |                               |
   (out_score:fp32)               cast(fp32->fp16)

Optimized subgraph:
     main_graph:                  while subgraph:
        cast                         lod_reset
         |                               |
   write_to_array                 write_to_array
         |                               |
   (write_var:fp16)               (write_var:fp16)
         |                               |
       while                      read_from_array
         |
   (write_var:fp16)
         |
  beam_search_decode
         |
   cast(fp16->fp32)
         |
   (out_score:fp32)
*/
int ApplyCastLodResetWriteReadPass(ir::Graph* graph) const;
/*
Origin subgraph:
   cast(fp16->fp32)
         |
    index_sample
         |
   cast(fp32->fp16)

Optimized subgraph:
    index_sample
*/
int ApplyCastIndexSamplePass(ir::Graph* graph) const;
// Delete a cast if its "in_dtype" is the same as its "out_dtype"
int ApplyCastPass(ir::Graph* graph) const;
const std::string name_scope_{"delete_cast_op_pass"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
VarDesc* Data(paddle::framework::BlockDesc* block,
std::string name,
std::vector<int64_t> shape = {},
bool is_persistable = false,
proto::VarType::Type data_type = proto::VarType::FP32) {
auto* var = block->Var(name);
var->SetType(proto::VarType::LOD_TENSOR);
var->SetDataType(data_type);
var->SetShape(shape);
var->SetPersistable(is_persistable);
return var;
}
VarDesc* AddWriteToArray(BlockDesc* block,
std::vector<VarDesc*> x,
VarDesc* i,
VarDesc* out = nullptr) {
if (out == nullptr) {
out = Data(block, x[0]->Name() + "_out");
}
OpDesc* op = block->AppendOp();
op->SetType("write_to_array");
std::vector<std::string> x_names;
for (auto k : x) {
x_names.push_back(k->Name());
}
op->SetInput("X", x_names);
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* AddReadFromArray(BlockDesc* block, VarDesc* x, VarDesc* i) {
auto* out = Data(block, x->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("read_from_array");
op->SetInput("X", {x->Name()});
op->SetInput("I", {i->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
VarDesc* AddCast(BlockDesc* block,
VarDesc* input,
int in_dtype = 5,
int out_dtype = 5) {
VarDesc* out = Data(block, input->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("cast");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
op->SetAttr("in_dtype", in_dtype);
op->SetAttr("out_dtype", out_dtype);
return out;
}
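// Note added for clarity: the integer dtype arguments passed to AddCast follow
// paddle::framework::proto::VarType::Type, where FP16 = 4 and FP32 = 5.
// So AddCast(block, x, 4, 5) builds an fp16->fp32 cast and
// AddCast(block, x, 5, 4) the matching fp32->fp16 cast.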
VarDesc* AddLodReset(BlockDesc* block, VarDesc* input) {
VarDesc* out = Data(block, input->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("lod_reset");
op->SetInput("X", {input->Name()});
op->SetOutput("Out", {out->Name()});
return out;
}
std::vector<VarDesc*> AddBeamSearchDecode(BlockDesc* block,
VarDesc* ids,
VarDesc* scores) {
VarDesc* out_ids = Data(block, ids->Name() + "_out");
VarDesc* out_scores = Data(block, scores->Name() + "_out");
OpDesc* op = block->AppendOp();
op->SetType("beam_search_decode");
op->SetInput("Ids", {ids->Name()});
op->SetInput("Scores", {scores->Name()});
op->SetOutput("SentenceIds", {out_ids->Name()});
op->SetOutput("SentenceScores", {out_scores->Name()});
return {out_ids, out_scores};
}
int GetOpNum(Graph* graph, std::string op_type = "") {
int num_nodes = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op() &&
(node->Op()->Type() == op_type || op_type.empty())) {
num_nodes++;
}
}
return num_nodes;
}
TEST(ApplyCastWriteReadPass, basic) {
paddle::framework::ProgramDesc program;
auto* block0 = program.MutableBlock(0);
auto* block1 = program.AppendBlock(*block0);
auto* write_0_x = Data(block0, "write_0_x", {1});
auto* write_0_i = Data(block0, "write_0_i", {1});
auto* write_0_out = AddWriteToArray(block0, {write_0_x}, write_0_i);
OpDesc* while_loop = block0->AppendOp();
while_loop->SetType("while");
while_loop->SetInput("X", {write_0_out->Name()});
while_loop->SetOutput("Out", {write_0_out->Name()});
auto* cast_1_0_in = Data(block1, "cast_1_0", {1});
auto* cast_1_0_out = AddCast(block1, cast_1_0_in, 4, 5);
auto* write_1_i = Data(block1, "write_1_i", {1});
auto* write_1_out = Data(block1, write_0_out->Name(), {1});
AddWriteToArray(block1, {cast_1_0_out}, write_1_i, write_1_out);
auto* read_1_i = Data(block1, "read_1_i", {1});
auto* read_1_out = AddReadFromArray(block1, write_1_out, read_1_i);
AddCast(block1, read_1_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph1 = GetOpNum(graph->GetSubGraph(1), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph1,
0,
platform::errors::PreconditionNotMet(
"graph1 should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph1));
int cast_num_in_graph0 = GetOpNum(graph.get(), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph0,
1,
platform::errors::PreconditionNotMet(
"graph0 should have 1 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph0));
}
TEST(ApplyCastLodResetWriteReadPass, basic) {
paddle::framework::ProgramDesc program;
auto* block0 = program.MutableBlock(0);
auto* block1 = program.AppendBlock(*block0);
auto* write_0_x = Data(block0, "write_0_x", {1});
auto* write_0_i = Data(block0, "write_0_i", {1});
auto* write_0_out = AddWriteToArray(block0, {write_0_x}, write_0_i);
OpDesc* while_loop = block0->AppendOp();
while_loop->SetType("while");
while_loop->SetInput("X", {write_0_out->Name()});
while_loop->SetOutput("Out", {write_0_out->Name()});
auto* ids = Data(block0, "ids", {1});
AddBeamSearchDecode(block0, ids, write_0_out);
auto* cast_1_0_in = Data(block1, "cast_1_0", {1});
auto* cast_1_0_out = AddCast(block1, cast_1_0_in, 4, 5);
auto* lod_reset_out = AddLodReset(block1, cast_1_0_out);
auto* write_1_i = Data(block1, "write_1_i", {1});
auto* write_1_out = Data(block1, write_0_out->Name(), {1});
AddWriteToArray(block1, {lod_reset_out}, write_1_i, write_1_out);
auto* read_1_i = Data(block1, "read_1_i", {1});
auto* read_1_out = AddReadFromArray(block1, write_1_out, read_1_i);
AddCast(block1, read_1_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph1 = GetOpNum(graph->GetSubGraph(1), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph1,
0,
platform::errors::PreconditionNotMet(
"graph1 should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph1));
int cast_num_in_graph0 = GetOpNum(graph.get(), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph0,
2,
platform::errors::PreconditionNotMet(
"graph0 should have 2 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph0));
}
TEST(ApplyCastIndexSamplePass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* cast0_in = Data(block, "cast0_in", {1});
auto* cast0_out = AddCast(block, cast0_in, 4, 5);
auto* index_sample_out = Data(block, "index_sample_out", {1});
OpDesc* index_sample = block->AppendOp();
index_sample->SetType("index_sample");
index_sample->SetInput("X", {cast0_out->Name()});
index_sample->SetOutput("Out", {index_sample_out->Name()});
AddCast(block, index_sample_out, 5, 4);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
TEST(ApplyCastPass, basic) {
paddle::framework::ProgramDesc program;
auto* block = program.MutableBlock(0);
auto* cast0_in = Data(block, "cast0_in", {1});
AddCast(block, cast0_in, 3, 3);
std::unique_ptr<ir::Graph> graph(new ir::Graph(program));
auto scope = new Scope();
graph->Set("__param_scope__", scope);
auto pass = PassRegistry::Instance().Get("delete_cast_op_pass");
pass->Apply(graph.get());
int cast_num_in_graph = GetOpNum(graph->GetSubGraph(0), "cast");
PADDLE_ENFORCE_EQ(cast_num_in_graph,
0,
platform::errors::PreconditionNotMet(
"graph should have 0 cast after delete_cast_op_pass, "
"but actually has %d.",
cast_num_in_graph));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_cast_op_pass);
@@ -528,6 +528,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"multi_encoder_xpu_fuse_pass",
"multi_encoder_xpu_slice_fuse_pass",
"one_beam_size_fuse_pass",
"delete_cast_op_pass",
"stack_fuse_pass",
"fused_multi_transformer_xpu_quant_pass",
"fc_xpu_fuse_pass",
......
@@ -249,6 +249,7 @@ REGISTER_OPERATOR(lod_reset_grad,
REGISTER_OP_CPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::CPUPlace, paddle::platform::float16>,
ops::LoDResetKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetKernel<paddle::platform::CPUPlace, int>,
@@ -257,6 +258,8 @@ REGISTER_OP_CPU_KERNEL(
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
lod_reset,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, float>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, double>,
ops::LoDResetKernel<paddle::platform::XPUDeviceContext, int>,
@@ -265,6 +268,8 @@ REGISTER_OP_XPU_KERNEL(
REGISTER_OP_CPU_KERNEL(
lod_reset_grad,
ops::LoDResetGradKernel<paddle::platform::CPUPlace,
paddle::platform::float16>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, float>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, double>,
ops::LoDResetGradKernel<paddle::platform::CPUPlace, int>,
......
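A note on the lod_reset kernel registrations above: adding float16 kernels on CPU and XPU appears to be what lets the optimized while subgraph from ApplyCastLodResetWriteReadPass run, since after the pass lod_reset and write_to_array in that subgraph operate directly on fp16 tensors.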