From 96ced1a1dbacc02fee80291bce39253ca44db293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=98=8E=E5=86=AC?= <78149749+winter-wang@users.noreply.github.com> Date: Sat, 9 Apr 2022 13:57:18 +0800 Subject: [PATCH] [infrt] opt support input valid places by commondline. (#41544) --- paddle/infrt/dialect/infrt/common/types.h | 11 +-- .../dialect/phi/pass/phi_op_convert_pass.cc | 78 ++++++++++++++----- .../dialect/phi/pass/phi_op_convert_pass.h | 2 - paddle/infrt/tests/dialect/phi/phi_pass.mlir | 2 +- .../infrt/tests/dialect/phi/resnet50.mlir.in | 2 +- paddle/phi/infermeta/unary.cc | 1 + tools/check_file_diff_approvals.sh | 2 +- tools/infrt/skipped_phi_api.json | 2 +- 8 files changed, 67 insertions(+), 33 deletions(-) diff --git a/paddle/infrt/dialect/infrt/common/types.h b/paddle/infrt/dialect/infrt/common/types.h index 2ebe2b8ccdb..5bd1f40262b 100644 --- a/paddle/infrt/dialect/infrt/common/types.h +++ b/paddle/infrt/dialect/infrt/common/types.h @@ -39,15 +39,12 @@ enum class PrecisionType : uint8_t { }; struct Place { - TargetType target; - PrecisionType precision; - LayoutType layout; + TargetType target = TargetType::UNK; + PrecisionType precision = PrecisionType::UNK; + LayoutType layout = LayoutType::UNK; Place(TargetType tar, PrecisionType pre, LayoutType lay) : target(tar), precision(pre), layout(lay) {} - Place() - : target(TargetType::UNK), - precision(PrecisionType::UNK), - layout(LayoutType::UNK) {} + Place() = default; }; llvm::Optional GetTargetType(llvm::StringRef key); diff --git a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc index e3fdd5ae5bb..e9b426a5088 100644 --- a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc +++ b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.cc @@ -36,12 +36,35 @@ #include "paddle/phi/ops/compat/signatures.h" namespace { + +infrt::Place ParsePlaceFromStr(const std::string &key) { + size_t first_index = key.find_first_of('-'); + size_t second_index = key.find_last_of('-'); + if (first_index != second_index) { + llvm::Optional tar = + infrt::GetTargetType(key.substr(0, first_index)); + llvm::Optional pre = infrt::GetPrecisionType( + key.substr(first_index + 1, second_index - first_index - 1)); + llvm::Optional lay = + infrt::GetLayoutType(key.substr(second_index + 1)); + if (tar && pre && lay) { + return infrt::Place(tar.getValue(), pre.getValue(), lay.getValue()); + } + } + LOG(FATAL) << "Can't parse infrt::Place from string:" << key; + return infrt::Place(); +} + class PhiOpConvertPass : public mlir::PassWrapper { public: ::llvm::StringRef getName() const override { return "PhiOpConvertPass"; } void runOnFunction() override; - PhiOpConvertPass(); + + /// Initialize the valid_places_ by the valid_places_options_ while + /// valid_places_options_ has values. + mlir::LogicalResult initialize(mlir::MLIRContext *context) override; + PhiOpConvertPass() {} explicit PhiOpConvertPass(const std::vector &valid_places) : valid_places_(valid_places) {} @@ -56,14 +79,35 @@ class PhiOpConvertPass void convertStage(); void dispatchStage(); - // Force a specified data format for all layout sensitive operations. - Option valid_places_options_{ + ListOption valid_places_options_{ *this, "valid-targets", - llvm::cl::desc("Set the valid target, [CPU-FP32-NCHW]")}; + llvm::cl::desc( + "Set the valids target, such as: CPU-FP32-NCHW,GPU-FP32-NCHW"), + llvm::cl::MiscFlags::CommaSeparated}; std::vector valid_places_; }; + +/// Initialize the canonicalizer by building the set of patterns used during +/// execution. +mlir::LogicalResult PhiOpConvertPass::initialize(mlir::MLIRContext *context) { + if (valid_places_options_.hasValue()) { + VLOG(4) << "Start parse valid_places from commond line:"; + if (!valid_places_.empty()) { + LOG(WARNING) << "Find valid place from commandline, current value will " + "be overwrittern."; + valid_places_.clear(); + } + for (auto &val : *valid_places_options_) { + VLOG(4) << "place string:" << val; + valid_places_.emplace_back(ParsePlaceFromStr(val)); + } + VLOG(4) << "End parse valid_places from commond line:"; + } + return mlir::success(); +} + // Implementation of the PhiOpConvertPass. void PhiOpConvertPass::runOnFunction() { convertStage(); @@ -191,7 +235,16 @@ void PhiOpConvertPass::dispatchStage() { .output(); phi_context[infrt::TargetType::CPU] = context_value; } break; - case infrt::TargetType::GPU: + case infrt::TargetType::GPU: { + auto context_value = + builder + .create( + kernel_op.getLoc(), + infrt::phi::ContextType::get(kernel_op.getContext(), + infrt::TargetType::GPU)) + .output(); + phi_context[infrt::TargetType::GPU] = context_value; + } break; case infrt::TargetType::UNK: default: LOG(FATAL) << "Unsupported TargetType"; @@ -237,17 +290,6 @@ void PhiOpConvertPass::dispatchStage() { } } -PhiOpConvertPass::PhiOpConvertPass() { - if (!valid_places_options_.hasValue()) { - valid_places_.emplace_back(infrt::TargetType::CPU, - infrt::PrecisionType::FLOAT32, - infrt::LayoutType::NCHW); - return; - } - - LOG(FATAL) << "To be done for specifying places in command line"; -} - void PhiOpConvertPass::getDependentDialects( mlir::DialectRegistry ®istry) const { registry.insert(); @@ -265,7 +307,3 @@ std::unique_ptr infrt::CreatePhiOpCvtPass( std::vector valid_places) { return std::make_unique(valid_places); } - -std::unique_ptr infrt::CreatePhiOpCvtPass() { - return std::make_unique(); -} diff --git a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h index c426bbf1151..a0e74426a40 100644 --- a/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h +++ b/paddle/infrt/dialect/phi/pass/phi_op_convert_pass.h @@ -23,6 +23,4 @@ namespace infrt { */ std::unique_ptr CreatePhiOpCvtPass(std::vector valid_places); -std::unique_ptr CreatePhiOpCvtPass(); - } // namespace infrt diff --git a/paddle/infrt/tests/dialect/phi/phi_pass.mlir b/paddle/infrt/tests/dialect/phi/phi_pass.mlir index 784ead5b2a0..0d9e312ce0b 100644 --- a/paddle/infrt/tests/dialect/phi/phi_pass.mlir +++ b/paddle/infrt/tests/dialect/phi/phi_pass.mlir @@ -1,4 +1,4 @@ -// RUN: infrtopt -phi-op-convert -infrt-op-fuse %s +// RUN: infrtopt -phi-op-convert=valid-targets=CPU-FP32-NCHW -infrt-op-fuse %s // CHECK-LABEL: @ops func @ops(%a:!infrt.lod_tensor, %b:!infrt.lod_tensor) { diff --git a/paddle/infrt/tests/dialect/phi/resnet50.mlir.in b/paddle/infrt/tests/dialect/phi/resnet50.mlir.in index 2803ebb41cf..3591a62e88e 100644 --- a/paddle/infrt/tests/dialect/phi/resnet50.mlir.in +++ b/paddle/infrt/tests/dialect/phi/resnet50.mlir.in @@ -444,7 +444,7 @@ module { %387 = "pd.flatten_contiguous_range"(%386) {start_axis = 1 : si32, stop_axis = 3 : si32} : (!infrt.dense_tensor) -> !infrt.dense_tensor %388 = "pd.matmul_v2"(%387, %245) {trans_x = false, trans_y = false} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor %389 = "pd.elementwise_add"(%388, %30) {axis = 1 : si32} : (!infrt.dense_tensor, !infrt.dense_tensor) -> !infrt.dense_tensor - infrt.return %270 : !infrt.dense_tensor + infrt.return %389 : !infrt.dense_tensor } func @main() { diff --git a/paddle/phi/infermeta/unary.cc b/paddle/phi/infermeta/unary.cc index c6e2cb76191..a47fc698777 100644 --- a/paddle/phi/infermeta/unary.cc +++ b/paddle/phi/infermeta/unary.cc @@ -2931,4 +2931,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) { } // namespace phi PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta); +PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta); PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta); diff --git a/tools/check_file_diff_approvals.sh b/tools/check_file_diff_approvals.sh index e0598112c82..49b84da01b9 100644 --- a/tools/check_file_diff_approvals.sh +++ b/tools/check_file_diff_approvals.sh @@ -201,7 +201,7 @@ fi # infrt needs to temporarily use LOG(FATAL) during the debugging period, and will replace it with standard error format in the future. NO_INFRT_FILES=`git diff --name-only upstream/develop | grep -v "tools/\|paddle/infrt/" || true` HAS_LOG_FATAL=`git diff -U0 upstream/$BRANCH $NO_INFRT_FILES |grep "^+" |grep -o -m 1 "LOG(FATAL)" || true` -if [ ${HAS_LOG_FATAL} ] && [ "${GIT_PR_ID}" != "" ]; then +if [ ${NO_INFRT_FILES} ] && [ ${HAS_LOG_FATAL} ] && [ "${GIT_PR_ID}" != "" ]; then echo_line="LOG(FATAL) is not recommended, because it will throw exception without standard stack information, so please use PADDLE_THROW macro here. If you have to use LOG(FATAL) here, please request chenwhql (Recommend), luotao1 or lanxianghit review and approve.\n" check_approval 1 6836917 47554610 22561442 fi diff --git a/tools/infrt/skipped_phi_api.json b/tools/infrt/skipped_phi_api.json index 64fc4c618ae..8e2dd0f65d7 100644 --- a/tools/infrt/skipped_phi_api.json +++ b/tools/infrt/skipped_phi_api.json @@ -1,4 +1,4 @@ { -"phi_apis":["conj", "dropout", "expand_as", "flatten", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth"], +"phi_apis":["conj", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth"], "phi_kernels":["equal_all"] } -- GitLab