未验证 提交 ecff21e7 编写于 作者: Y Yuanle Liu 提交者: GitHub

[Inference] auto mixed precision inference support white list (#56535)

* auto mixed precision inference support white list

* update

* update

* update

* move down identity_op_clean_pass

* fix code style
上级 5f9d6d68
......@@ -165,7 +165,9 @@ void DoInsertCastOp(Graph* graph,
// Reports whether `op_type` may run at `precision` on `backend`.
//
// An op type found in `white_list` is accepted immediately. Any other op type
// is accepted only when it is absent from `black_list` and
// KernelSupportPrecision() returns true for the same op/backend/precision.
//
// NOTE(review): two `black_list` parameter lines appear below; the first one
// (which closes the parameter list) looks like the pre-change side of the
// diff this text was extracted from -- confirm before compiling.
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list) {
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list) {
// The white list is checked first, so it takes priority over the black list.
if (white_list.count(op_type)) return true;
return black_list.count(op_type) == 0 &&
KernelSupportPrecision(op_type, backend, precision);
}
......@@ -230,11 +232,16 @@ void AutoMixedPrecisionPass::Init(Graph* graph) const {
if (skip_pass_) return;
black_list_ = Get<std::unordered_set<std::string>>("mixed_black_list");
white_list_ = Get<std::unordered_set<std::string>>("mixed_white_list");
SetDefaultBlacklist();
VLOG(4) << "black_list has ";
for (const auto& name : black_list_) {
VLOG(4) << " - " << name;
}
VLOG(4) << "white_list has ";
for (const auto& name : white_list_) {
VLOG(4) << " - " << name;
}
if (Has("enable_low_precision_io")) {
enable_low_precision_io_ = Get<bool>("enable_low_precision_io");
......@@ -403,8 +410,11 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
op_node->Op()->GetAttrIfExists<bool>("enable_low_precision_io");
support_low_precision = enable_fp16 && !enable_int8 && low_precision_io;
} else {
support_low_precision = OpSupportPrecision(
GetOpOriginalType(op_type), backend_, low_precision_, black_list_);
support_low_precision = OpSupportPrecision(GetOpOriginalType(op_type),
backend_,
low_precision_,
black_list_,
white_list_);
std::unordered_set<std::string> check_dtype_op_blacklist(
{"arg_max", "arg_min"});
......@@ -422,8 +432,8 @@ void AutoMixedPrecisionPass::GetOpPrecision() const {
out_dtype == -1);
}
// If scale op's "scale" and "bias" attr value exceed the range of fp16
// and bf16, it cannot run at low precision.
// If scale op's "scale" and "bias" attr value exceed the range of
// fp16 and bf16, it cannot run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "scale") {
auto scale = op_node->Op()->GetAttrIfExists<float>("scale");
auto bias = op_node->Op()->GetAttrIfExists<float>("bias");
......@@ -500,9 +510,9 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
<< " is output of " << op_type;
}
// the select_input op's input var should not convert to low precision.
// when op's output var is select_input op's input var, the op should
// not run at low precision.
// the select_input op's input var should not convert to low
// precision. when op's output var is select_input op's input var, the
// op should not run at low precision.
if (GetOpOriginalType(op_node->Op()->Type()) == "select_input") {
for (auto* in_var_node : op_node->inputs) {
CHECK_EQ(in_var_node->IsVar(), true);
......@@ -517,6 +527,7 @@ void AutoMixedPrecisionPass::UpdateOpPrecision() const {
// output var, then op_2 should not run at low precision.
if (GetOpOriginalType(op_type) != "feed" &&
GetOpOriginalType(op_type) != "tensorrt_engine" &&
white_list_.count(GetOpOriginalType(op_type)) == 0 &&
!KernelSupportPrecision(
GetOpOriginalType(op_type), backend_, phi::DataType::FLOAT32)) {
for (auto* out_var_node : op_node->outputs) {
......
......@@ -75,6 +75,7 @@ class AutoMixedPrecisionPass : public FusePassBase {
mutable phi::Backend backend_{phi::Backend::UNDEFINED};
mutable std::unordered_set<std::string> black_list_;
mutable std::unordered_set<std::string> white_list_;
// subgraph id -> pointer to subgraph
mutable std::vector<Graph*> subgraphes_;
......@@ -93,7 +94,8 @@ class AutoMixedPrecisionPass : public FusePassBase {
// Returns whether `op_type` may run at `precision` on `backend`, taking the
// given black list and white list of op types into account.
//
// NOTE(review): both a four-parameter and a five-parameter form of the
// declaration are interleaved below; the line ending in `black_list);` looks
// like the pre-change side of the diff -- confirm.
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
void DoInsertCastOp(Graph* graph,
Node* var_node,
......
......@@ -89,11 +89,52 @@ FindUselessOpPattern::FindUselessOpPattern(PDPattern* pattern,
useless_op->LinksFrom({useless_op_in}).LinksTo({useless_op_out});
}
} // namespace patterns
// Pattern describing two chained cast ops.
//
//   matched:   pre_op -> pre_op_out -> cast_op_1 -> cast_op_1_out
//                     -> cast_op_2 -> cast_op_2_out
//   rewritten: pre_op -> cast_op_2_out
struct FindTwoCastOpPattern : public PatternBase {
// Registers the pattern nodes on `pattern` under `name_scope`.
FindTwoCastOpPattern(PDPattern* pattern, const std::string& name_scope);
// NOTE(review): the next two lines do not belong to this struct; they look
// like removed-side residue of the diff (the former position of
// IdentityOpCleanPass::ApplyImpl) -- confirm.
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
Init(name_scope_, graph);
// Declare the names of the pattern nodes: three variables and two cast ops.
PATTERN_DECL_NODE(pre_op_out);
PATTERN_DECL_NODE(cast_op_1);
PATTERN_DECL_NODE(cast_op_1_out);
PATTERN_DECL_NODE(cast_op_2);
PATTERN_DECL_NODE(cast_op_2_out);
};
// Builds the pattern
//   pre_op_out -> cast_op_1 -> cast_op_1_out -> cast_op_2 -> cast_op_2_out.
//
// pattern:    pattern object the nodes are registered on.
// name_scope: scope used to name the pattern nodes.
FindTwoCastOpPattern::FindTwoCastOpPattern(PDPattern* pattern,
                                           const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  // The variable feeding the first cast must be a non-persistable variable
  // with exactly one consumer, and none of its producers may be a
  // conditional_block, while or feed op.
  auto* pre_op_out = pattern->NewNode(pre_op_out_repr())
                         ->assert_is_var()
                         ->assert_var_not_persistable()
                         ->assert_has_n_outputs(1)
                         ->assert_more([](Node* x) {
                           for (auto* op : x->inputs) {
                             CHECK_EQ(op->IsOp(), true);
                             const auto& op_type = op->Op()->Type();
                             if (op_type == "conditional_block" ||
                                 op_type == "while" || op_type == "feed") {
                               return false;
                             }
                           }
                           return true;
                         });

  auto* cast_op_1 = pattern->NewNode(cast_op_1_repr())->assert_is_op("cast");

  // The intermediate variable is deleted when the two casts are folded away,
  // so it must be read by the second cast only. Without this restriction any
  // other reader of it would be left with a dangling input.
  auto* cast_op_1_out = pattern->NewNode(cast_op_1_out_repr())
                            ->assert_is_var()
                            ->assert_has_n_outputs(1);

  auto* cast_op_2 = pattern->NewNode(cast_op_2_repr())->assert_is_op("cast");
  auto* cast_op_2_out =
      pattern->NewNode(cast_op_2_out_repr())->assert_is_var();

  cast_op_1->LinksFrom({pre_op_out}).LinksTo({cast_op_1_out});
  cast_op_2->LinksFrom({cast_op_1_out}).LinksTo({cast_op_2_out});
}
} // namespace patterns
int IdentityOpCleanPass::CleanUselessOp(ir::Graph* graph) const {
GraphPatternDetector gpd;
patterns::FindUselessOpPattern pattern(gpd.mutable_pattern(), name_scope_);
......@@ -119,6 +160,48 @@ void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
};
gpd(graph, handler);
return found_count;
}
// Folds away a pair of chained cast ops whose final output has the same data
// type as the variable entering the first cast. The producers of that variable
// are rewired to write the final output directly, and the bypassed variable,
// both casts and the intermediate variable are removed from the graph.
//
// Returns the number of cast pairs that were removed.
int IdentityOpCleanPass::CleanTwoCastOp(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::FindTwoCastOpPattern pattern(gpd.mutable_pattern(), name_scope_);

  int found_count = 0;
  GraphPatternDetector::handle_t handler =
      [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
        GET_IR_NODE_FROM_SUBGRAPH(pre_op_out, pre_op_out, pattern);
        GET_IR_NODE_FROM_SUBGRAPH(cast_op_1, cast_op_1, pattern);
        GET_IR_NODE_FROM_SUBGRAPH(cast_op_1_out, cast_op_1_out, pattern);
        GET_IR_NODE_FROM_SUBGRAPH(cast_op_2, cast_op_2, pattern);
        GET_IR_NODE_FROM_SUBGRAPH(cast_op_2_out, cast_op_2_out, pattern);
        CHECK_EQ(pre_op_out->IsVar(), true);
        CHECK_EQ(cast_op_1_out->IsVar(), true);
        CHECK_EQ(cast_op_2_out->IsVar(), true);
        CHECK_EQ(cast_op_1->IsOp(), true);
        CHECK_EQ(cast_op_2->IsOp(), true);

        // Only a chain that ends in the data type it started from is removed.
        const auto src_dtype = pre_op_out->Var()->GetDataType();
        const auto dst_dtype = cast_op_2_out->Var()->GetDataType();
        if (src_dtype != dst_dtype) {
          return;
        }

        // Make every producer of the bypassed variable write the final output.
        const std::string bypassed_name = pre_op_out->Var()->Name();
        const std::string final_name = cast_op_2_out->Var()->Name();
        for (auto* producer : pre_op_out->inputs) {
          CHECK_EQ(producer->IsOp(), true);
          producer->Op()->RenameOutput(bypassed_name, final_name);
          IR_NODE_LINK_TO(producer, cast_op_2_out);
        }

        GraphSafeRemoveNodes(
            g, {pre_op_out, cast_op_1, cast_op_1_out, cast_op_2});
        ++found_count;
      };

  gpd(graph, handler);
  return found_count;
}
// Entry point of the pass: removes useless ops first, then redundant cast
// pairs, and records the total number of removals.
void IdentityOpCleanPass::ApplyImpl(ir::Graph* graph) const {
  Init(name_scope_, graph);

  // The two clean-ups are separate statements on purpose. Both mutate the same
  // graph, and the order in which the operands of a single `a + b` expression
  // are evaluated is unspecified, so summing the two calls in one expression
  // would let the compiler pick the order.
  int found_count = CleanUselessOp(graph);
  found_count += CleanTwoCastOp(graph);

  AddStatis(found_count);
}
......
......@@ -27,6 +27,10 @@ class IdentityOpCleanPass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;
private:
int CleanUselessOp(ir::Graph* graph) const;
int CleanTwoCastOp(ir::Graph* graph) const;
const std::string name_scope_{"identity_op_clean_pass"};
};
......
......@@ -419,6 +419,9 @@ struct Argument {
DECL_ARGUMENT_FIELD(mixed_black_list,
MixedBlackList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(mixed_white_list,
MixedWhiteList,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(enable_gpu_mixed, EnableGPUMixed, bool);
DECL_ARGUMENT_FIELD(mixed_precision_mode, MixedPrecisionMode, int);
DECL_ARGUMENT_FIELD(enable_low_precision_io, EnableLowPrecisionIO, bool);
......
......@@ -100,6 +100,9 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set(
"mixed_black_list",
new std::unordered_set<std::string>(argument->mixed_black_list()));
pass->Set(
"mixed_white_list",
new std::unordered_set<std::string>(argument->mixed_white_list()));
pass->Set("enable_gpu_mixed", new bool(argument->enable_gpu_mixed()));
pass->Set("enable_custom_device_mixed",
new bool(argument->enable_custom_device_mixed()));
......
......@@ -50,7 +50,8 @@ void OutputProcess(framework::ir::Graph *graph,
const std::unordered_set<framework::ir::Node *> &trt_outputs,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string> &blacklist) {
const std::unordered_set<std::string> &blacklist,
const std::unordered_set<std::string> &whitelist) {
framework::BlockDesc *block_desc{nullptr};
int suffix = 0;
std::unordered_map<framework::ir::Node *, framework::ir::Node *>
......@@ -86,7 +87,8 @@ void OutputProcess(framework::ir::Graph *graph,
phi::TransToPhiKernelName(next_op->Op()->Type()),
backend,
precision,
blacklist)) {
blacklist,
whitelist)) {
InsertCastOp(graph,
var_node,
next_op,
......@@ -363,6 +365,8 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
static_cast<phi::DataType>(Get<int>("model_precision"));
auto mixed_black_list =
Get<std::unordered_set<std::string>>("mixed_black_list");
auto mixed_white_list =
Get<std::unordered_set<std::string>>("mixed_white_list");
std::set<std::string> output_names;
std::set<std::string> output_names_with_id;
......@@ -414,8 +418,12 @@ std::string TensorRtSubgraphPass::CreateTensorRTOp(
static_cast<int>(x->Var()->GetDataType());
}
OutputProcess(
graph, trt_outputs, phi::Backend::GPU, model_precision, mixed_black_list);
OutputProcess(graph,
trt_outputs,
phi::Backend::GPU,
model_precision,
mixed_black_list,
mixed_white_list);
std::unordered_map<std::string, std::string> output_name_map;
std::unordered_map<std::string, framework::ir::Node *> graph_var_map;
......
......@@ -14,7 +14,7 @@ cc_library(
convert_to_mixed_precision
SRCS convert_to_mixed_precision.cc
DEPS analysis_pass ir_graph_build_pass auto_mixed_precision_pass
constant_folding_pass)
constant_folding_pass identity_op_clean_pass)
cc_library(
ir_params_sync_among_devices_pass
SRCS ir_params_sync_among_devices_pass.cc
......
......@@ -18,6 +18,7 @@
#include "paddle/fluid/framework/ir/auto_mixed_precision_pass.h"
#include "paddle/fluid/framework/ir/constant_folding_pass.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/identity_op_clean_pass.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/phi/common/backend.h"
......@@ -33,7 +34,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list)
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list)
: model_file_(model_file),
params_file_(params_file),
mixed_model_file_(mixed_model_file),
......@@ -41,7 +43,8 @@ ConvertToMixedPrecisionPass::ConvertToMixedPrecisionPass(
mixed_precision_(mixed_precision),
backend_(backend),
keep_io_types_(keep_io_types),
black_list_(black_list) {
black_list_(black_list),
white_list_(white_list) {
switch (backend_) {
case phi::Backend::GPU:
PADDLE_ENFORCE(mixed_precision_ == phi::DataType::FLOAT16 ||
......@@ -88,19 +91,27 @@ void ConvertToMixedPrecisionPass::Run() {
framework::ir::ConstantFoldingPass constant_folding_pass;
constant_folding_pass.Apply(main_graph_.get());
framework::ir::AutoMixedPrecisionPass pass;
pass.Set("mixed_precision_mode", new int{static_cast<int>(mixed_precision_)});
framework::ir::AutoMixedPrecisionPass auto_mixed_precision_pass;
auto_mixed_precision_pass.Set("mixed_precision_mode",
new int{static_cast<int>(mixed_precision_)});
if (backend_ == phi::Backend::GPU) {
pass.Set("enable_gpu_mixed", new bool{true});
auto_mixed_precision_pass.Set("enable_gpu_mixed", new bool{true});
} else if (backend_ == phi::Backend::XPU) {
pass.Set("enable_xpu_mixed", new bool{true});
auto_mixed_precision_pass.Set("enable_xpu_mixed", new bool{true});
} else if (backend_ == phi::Backend::CUSTOM) {
pass.Set("enable_custom_device_mixed", new bool{true});
auto_mixed_precision_pass.Set("enable_custom_device_mixed", new bool{true});
}
pass.Set("mixed_black_list",
new std::unordered_set<std::string>{black_list_});
pass.Set("enable_low_precision_io", new bool{!keep_io_types_});
pass.Apply(main_graph_.get());
auto_mixed_precision_pass.Set(
"mixed_black_list", new std::unordered_set<std::string>{black_list_});
auto_mixed_precision_pass.Set(
"mixed_white_list", new std::unordered_set<std::string>{white_list_});
auto_mixed_precision_pass.Set("enable_low_precision_io",
new bool{!keep_io_types_});
auto_mixed_precision_pass.Apply(main_graph_.get());
framework::ir::IdentityOpCleanPass identity_op_clean_pass;
identity_op_clean_pass.Apply(main_graph_.get());
SaveMixedModel();
}
......@@ -184,9 +195,10 @@ void ConvertToMixedPrecisionPass::SaveMixedModel() {
// Analysis-layer wrapper: forwards all arguments unchanged to
// framework::ir::OpSupportPrecision and returns its result.
//
// NOTE(review): the line ending in `black_list) {` and the four-argument call
// line `op_type, backend, precision, black_list);` look like the pre-change
// side of the diff this text was extracted from -- confirm before compiling.
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list) {
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list) {
return framework::ir::OpSupportPrecision(
op_type, backend, precision, black_list);
op_type, backend, precision, black_list, white_list);
}
void InsertCastOp(
......@@ -216,7 +228,8 @@ void ConvertToMixedPrecision(
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list) {
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list) {
ConvertToMixedPrecisionPass pass(model_file,
params_file,
mixed_model_file,
......@@ -224,7 +237,8 @@ void ConvertToMixedPrecision(
mixed_precision,
backend,
keep_io_types,
black_list);
black_list,
white_list);
pass.Run();
}
......
......@@ -38,7 +38,8 @@ class ConvertToMixedPrecisionPass {
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
void Run();
......@@ -55,6 +56,7 @@ class ConvertToMixedPrecisionPass {
phi::Backend backend_;
bool keep_io_types_;
std::unordered_set<std::string> black_list_;
std::unordered_set<std::string> white_list_;
framework::Scope scope_;
std::unique_ptr<framework::ir::Graph> main_graph_{nullptr};
......@@ -63,7 +65,8 @@ class ConvertToMixedPrecisionPass {
// Returns whether `op_type` may run at `precision` on `backend`, given the
// black list and white list of op types.
//
// NOTE(review): both a four-parameter and a five-parameter form of the
// declaration are interleaved below; the line ending in `black_list);` looks
// like the pre-change side of the diff -- confirm.
bool OpSupportPrecision(const std::string& op_type,
phi::Backend backend,
phi::DataType precision,
const std::unordered_set<std::string>& black_list);
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
void InsertCastOp(
framework::ir::Graph* graph,
......@@ -82,7 +85,8 @@ void ConvertToMixedPrecision(const std::string& model_file,
phi::DataType mixed_precision,
phi::Backend backend,
bool keep_io_types,
const std::unordered_set<std::string>& black_list);
const std::unordered_set<std::string>& black_list,
const std::unordered_set<std::string>& white_list);
} // namespace analysis
} // namespace inference
......
......@@ -448,6 +448,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
// Mixed precision related.
CP_MEMBER(mixed_black_list_);
CP_MEMBER(mixed_white_list_);
CP_MEMBER(enable_gpu_mixed_);
CP_MEMBER(mixed_precision_mode_);
CP_MEMBER(enable_low_precision_io_);
......@@ -1154,6 +1155,7 @@ std::string AnalysisConfig::SerializeInfoCache() {
for (auto attr : pattern) ss << attr;
ss << ";";
for (auto &op : mixed_black_list_) ss << op.c_str();
for (auto &op : mixed_white_list_) ss << op.c_str();
return ss.str();
}
......@@ -1535,6 +1537,11 @@ void AnalysisConfig::Exp_DisableMixedPrecisionOps(
mixed_black_list_ = black_list;
}
// Replaces the stored mixed-precision white list with a copy of `white_list`.
void AnalysisConfig::Exp_EnableMixedPrecisionOps(
    const std::unordered_set<std::string> &white_list) {
  // Copy first, then swap the copy in, so the stored set ends up holding
  // exactly the elements of `white_list`.
  std::unordered_set<std::string> requested_ops(white_list);
  mixed_white_list_.swap(requested_ops);
}
void AnalysisConfig::Exp_EnableCINNCompiler() {
#ifdef PADDLE_WITH_CINN
use_cinn_compiler_ = true;
......
......@@ -1616,6 +1616,7 @@ void AnalysisPredictor::PrepareArgument() {
// mixed precision.
argument_->SetModelPrecision(static_cast<int>(model_precision_));
argument_->SetMixedBlackList(config_.mixed_black_list_);
argument_->SetMixedWhiteList(config_.mixed_white_list_);
argument_->SetEnableGPUMixed(config_.enable_gpu_mixed_);
argument_->SetMixedPrecisionMode(static_cast<int>(
paddle::ConvertPrecision(config_.mixed_precision_mode_)));
......@@ -3097,7 +3098,8 @@ void ConvertToMixedPrecision(const std::string &model_file,
PrecisionType mixed_precision,
paddle_infer::PlaceType backend,
bool keep_io_types,
std::unordered_set<std::string> black_list) {
std::unordered_set<std::string> black_list,
std::unordered_set<std::string> white_list) {
auto phi_backend = paddle::ConvertBackend(backend);
auto phi_precision = paddle::ConvertPrecision(mixed_precision);
paddle::inference::analysis::ConvertToMixedPrecision(model_file,
......@@ -3107,7 +3109,8 @@ void ConvertToMixedPrecision(const std::string &model_file,
phi_precision,
phi_backend,
keep_io_types,
black_list);
black_list,
white_list);
}
} // namespace paddle_infer
......
......@@ -1147,6 +1147,14 @@ struct PD_INFER_DECL AnalysisConfig {
void Exp_DisableMixedPrecisionOps(
const std::unordered_set<std::string>& black_list);
///
/// \brief Set the white list of operators for mixed precision inference.
/// This interface is in the experimental stage and may change in the future.
/// Note that the white list must be the same as the white list used for
/// model conversion.
///
/// \param white_list Operator type names to put on the white list.
///
void Exp_EnableMixedPrecisionOps(
const std::unordered_set<std::string>& white_list);
void SetApplyOptim(bool value) { apply_optim_ = value; }
void SetSkipLoadParams(bool value) { skip_load_params_ = value; }
......@@ -1179,6 +1187,7 @@ struct PD_INFER_DECL AnalysisConfig {
// Mixed precision related.
Precision mixed_precision_mode_{Precision::kFloat32};
std::unordered_set<std::string> mixed_black_list_;
std::unordered_set<std::string> mixed_white_list_;
bool enable_low_precision_io_{false};
// GPU related.
......
......@@ -245,7 +245,8 @@ PD_INFER_DECL void ConvertToMixedPrecision(
PrecisionType mixed_precision,
PlaceType backend,
bool keep_io_types = true,
std::unordered_set<std::string> black_list = {});
std::unordered_set<std::string> black_list = {},
std::unordered_set<std::string> white_list = {});
namespace services {
///
......
......@@ -268,10 +268,10 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"conv_elementwise_add_fuse_pass", //
#endif //
"transpose_flatten_concat_fuse_pass", //
"identity_op_clean_pass", //
"conv2d_fusion_layout_transfer_pass", //
"transfer_layout_elim_pass",
"auto_mixed_precision_pass", //
"identity_op_clean_pass", // should be after auto_mixed_precision_pass.
"inplace_op_var_pass", // should be the last pass.
});
......
......@@ -544,7 +544,8 @@ void BindInferenceApi(py::module *m) {
py::arg("mixed_precision"),
py::arg("backend"),
py::arg("keep_io_types") = true,
py::arg("black_list") = std::unordered_set<std::string>());
py::arg("black_list") = std::unordered_set<std::string>(),
py::arg("white_list") = std::unordered_set<std::string>());
}
namespace {
......@@ -777,6 +778,8 @@ void BindAnalysisConfig(py::module *m) {
.def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass)
.def("exp_disable_mixed_precision_ops",
&AnalysisConfig::Exp_DisableMixedPrecisionOps)
.def("exp_enable_mixed_precision_ops",
&AnalysisConfig::Exp_EnableMixedPrecisionOps)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
......
......@@ -78,7 +78,8 @@ def convert_to_mixed_precision(
mixed_precision: PrecisionType,
backend: PlaceType,
keep_io_types: bool = True,
black_list: Set = set(),
black_list: Set[str] = set(),
**kwargs,
):
'''
Convert a fp32 model to mixed precision model.
......@@ -92,6 +93,8 @@ def convert_to_mixed_precision(
backend: The backend, e.g. PlaceType.GPU.
keep_io_types: Whether the model input and output dtype remains unchanged.
black_list: Operators that do not convert precision.
kwargs: Supported keys including 'white_list'.
- white_list: Operators that do convert precision.
'''
mixed_model_dirname = os.path.dirname(mixed_model_file)
# Support mixed_params_file is empty, because some models don't have params, but convert_to_mixed_precision will call
......@@ -104,6 +107,7 @@ def convert_to_mixed_precision(
)
if not os.path.exists(mixed_params_dirname):
os.makedirs(mixed_params_dirname)
white_list = kwargs.get('white_list', set())
convert_to_mixed_precision_bind(
model_file,
params_file,
......@@ -113,6 +117,7 @@ def convert_to_mixed_precision(
backend,
keep_io_types,
black_list,
white_list,
)
......
......@@ -155,5 +155,60 @@ class TestIdentityScaleCleanPass_V2(PassAutoScanTest):
self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
class TestIdentityCastCleanPass(PassAutoScanTest):
    # Runs identity_op_clean_pass on the program
    #   relu -> cast(5->5) -> relu -> cast(5->4) -> cast(4->5)
    # NOTE(review): dtype codes 5 and 4 are presumably FP32 and FP16 -- confirm.

    def sample_predictor_configs(self, program_config):
        # NOTE(review): the yielded list is presumably the op types expected
        # after the pass and the tuple the comparison tolerances -- confirm
        # against PassAutoScanTest.
        gpu_config = self.create_inference_config(use_gpu=True)
        remaining_ops = ['relu', 'relu']
        tolerance = (1e-2, 1e-2)
        yield gpu_config, remaining_ops, tolerance

    def sample_program_config(self, draw):
        # Draw the input shape in the order n, c, h, w.
        spatial_dim = st.integers(min_value=1, max_value=20)
        n = draw(st.integers(min_value=1, max_value=4))
        c = draw(spatial_dim)
        h = draw(spatial_dim)
        w = draw(spatial_dim)
        input_shape = [n, c, h, w]

        def make_relu(src, dst):
            # relu op reading variable `src` and writing variable `dst`.
            return OpConfig(
                "relu",
                inputs={"X": [src]},
                outputs={"Out": [dst]},
            )

        def make_cast(src, dst, in_dtype, out_dtype):
            # cast op reading variable `src` and writing variable `dst`.
            return OpConfig(
                "cast",
                inputs={"X": [src]},
                outputs={"Out": [dst]},
                in_dtype=in_dtype,
                out_dtype=out_dtype,
            )

        relu_op_1 = make_relu("relu_op_1_in", "relu_op_1_out")
        # Cast with identical in/out dtype.
        cast_op_1 = make_cast("relu_op_1_out", "cast_op_1_out", 5, 5)
        relu_op_2 = make_relu("cast_op_1_out", "relu_op_2_out")
        # Two chained casts that end in the dtype they started from.
        cast_op_2 = make_cast("relu_op_2_out", "cast_op_2_out", 5, 4)
        cast_op_3 = make_cast("cast_op_2_out", "cast_op_3_out", 4, 5)

        return ProgramConfig(
            ops=[relu_op_1, cast_op_1, relu_op_2, cast_op_2, cast_op_3],
            weights={},
            inputs={"relu_op_1_in": TensorConfig(shape=input_shape)},
            outputs=["cast_op_3_out"],
        )

    def test(self):
        self.run_and_statis(max_examples=25, passes=["identity_op_clean_pass"])
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册