未验证 提交 96ced1a1 编写于 作者: 王明冬 提交者: GitHub

[infrt] opt support input valid places by commondline. (#41544)

上级 b937cdc5
......@@ -39,15 +39,12 @@ enum class PrecisionType : uint8_t {
};
struct Place {
TargetType target;
PrecisionType precision;
LayoutType layout;
TargetType target = TargetType::UNK;
PrecisionType precision = PrecisionType::UNK;
LayoutType layout = LayoutType::UNK;
Place(TargetType tar, PrecisionType pre, LayoutType lay)
: target(tar), precision(pre), layout(lay) {}
Place()
: target(TargetType::UNK),
precision(PrecisionType::UNK),
layout(LayoutType::UNK) {}
Place() = default;
};
llvm::Optional<TargetType> GetTargetType(llvm::StringRef key);
......
......@@ -36,12 +36,35 @@
#include "paddle/phi/ops/compat/signatures.h"
namespace {
infrt::Place ParsePlaceFromStr(const std::string &key) {
size_t first_index = key.find_first_of('-');
size_t second_index = key.find_last_of('-');
if (first_index != second_index) {
llvm::Optional<infrt::TargetType> tar =
infrt::GetTargetType(key.substr(0, first_index));
llvm::Optional<infrt::PrecisionType> pre = infrt::GetPrecisionType(
key.substr(first_index + 1, second_index - first_index - 1));
llvm::Optional<infrt::LayoutType> lay =
infrt::GetLayoutType(key.substr(second_index + 1));
if (tar && pre && lay) {
return infrt::Place(tar.getValue(), pre.getValue(), lay.getValue());
}
}
LOG(FATAL) << "Can't parse infrt::Place from string:" << key;
return infrt::Place();
}
class PhiOpConvertPass
: public mlir::PassWrapper<PhiOpConvertPass, mlir::FunctionPass> {
public:
::llvm::StringRef getName() const override { return "PhiOpConvertPass"; }
void runOnFunction() override;
PhiOpConvertPass();
/// Initialize the valid_places_ by the valid_places_options_ while
/// valid_places_options_ has values.
mlir::LogicalResult initialize(mlir::MLIRContext *context) override;
PhiOpConvertPass() {}
explicit PhiOpConvertPass(const std::vector<infrt::Place> &valid_places)
: valid_places_(valid_places) {}
......@@ -56,14 +79,35 @@ class PhiOpConvertPass
void convertStage();
void dispatchStage();
// Force a specified data format for all layout sensitive operations.
Option<std::string> valid_places_options_{
ListOption<std::string> valid_places_options_{
*this,
"valid-targets",
llvm::cl::desc("Set the valid target, [CPU-FP32-NCHW]")};
llvm::cl::desc(
"Set the valids target, such as: CPU-FP32-NCHW,GPU-FP32-NCHW"),
llvm::cl::MiscFlags::CommaSeparated};
std::vector<infrt::Place> valid_places_;
};
/// Initialize the canonicalizer by building the set of patterns used during
/// execution.
mlir::LogicalResult PhiOpConvertPass::initialize(mlir::MLIRContext *context) {
if (valid_places_options_.hasValue()) {
VLOG(4) << "Start parse valid_places from commond line:";
if (!valid_places_.empty()) {
LOG(WARNING) << "Find valid place from commandline, current value will "
"be overwrittern.";
valid_places_.clear();
}
for (auto &val : *valid_places_options_) {
VLOG(4) << "place string:" << val;
valid_places_.emplace_back(ParsePlaceFromStr(val));
}
VLOG(4) << "End parse valid_places from commond line:";
}
return mlir::success();
}
// Implementation of the PhiOpConvertPass.
void PhiOpConvertPass::runOnFunction() {
convertStage();
......@@ -191,7 +235,16 @@ void PhiOpConvertPass::dispatchStage() {
.output();
phi_context[infrt::TargetType::CPU] = context_value;
} break;
case infrt::TargetType::GPU:
case infrt::TargetType::GPU: {
auto context_value =
builder
.create<infrt::phi::CreateGPUContextOp>(
kernel_op.getLoc(),
infrt::phi::ContextType::get(kernel_op.getContext(),
infrt::TargetType::GPU))
.output();
phi_context[infrt::TargetType::GPU] = context_value;
} break;
case infrt::TargetType::UNK:
default:
LOG(FATAL) << "Unsupported TargetType";
......@@ -237,17 +290,6 @@ void PhiOpConvertPass::dispatchStage() {
}
}
PhiOpConvertPass::PhiOpConvertPass() {
if (!valid_places_options_.hasValue()) {
valid_places_.emplace_back(infrt::TargetType::CPU,
infrt::PrecisionType::FLOAT32,
infrt::LayoutType::NCHW);
return;
}
LOG(FATAL) << "To be done for specifying places in command line";
}
void PhiOpConvertPass::getDependentDialects(
mlir::DialectRegistry &registry) const {
registry.insert<infrt::InfrtDialect>();
......@@ -265,7 +307,3 @@ std::unique_ptr<mlir::Pass> infrt::CreatePhiOpCvtPass(
std::vector<Place> valid_places) {
return std::make_unique<PhiOpConvertPass>(valid_places);
}
std::unique_ptr<mlir::Pass> infrt::CreatePhiOpCvtPass() {
return std::make_unique<PhiOpConvertPass>();
}
......@@ -23,6 +23,4 @@ namespace infrt {
*/
std::unique_ptr<mlir::Pass> CreatePhiOpCvtPass(std::vector<Place> valid_places);
std::unique_ptr<mlir::Pass> CreatePhiOpCvtPass();
} // namespace infrt
// RUN: infrtopt -phi-op-convert -infrt-op-fuse %s
// RUN: infrtopt -phi-op-convert=valid-targets=CPU-FP32-NCHW -infrt-op-fuse %s
// CHECK-LABEL: @ops
func @ops(%a:!infrt.lod_tensor<?xf32,0>, %b:!infrt.lod_tensor<?xf32,0>) {
......
......@@ -444,7 +444,7 @@ module {
%387 = "pd.flatten_contiguous_range"(%386) {start_axis = 1 : si32, stop_axis = 3 : si32} : (!infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%388 = "pd.matmul_v2"(%387, %245) {trans_x = false, trans_y = false} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
%389 = "pd.elementwise_add"(%388, %30) {axis = 1 : si32} : (!infrt.dense_tensor<CPU, FP32, NCHW>, !infrt.dense_tensor<CPU, FP32, NCHW>) -> !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %270 : !infrt.dense_tensor<CPU, FP32, NCHW>
infrt.return %389 : !infrt.dense_tensor<CPU, FP32, NCHW>
}
func @main() {
......
......@@ -2931,4 +2931,5 @@ void WhereIndexInferMeta(const MetaTensor& condition, MetaTensor* out) {
} // namespace phi
PD_REGISTER_INFER_META_FN(copy_to, phi::CopyToInferMeta);
PD_REGISTER_INFER_META_FN(flatten, phi::FlattenInferMeta);
PD_REGISTER_INFER_META_FN(split, phi::SplitInferMeta);
......@@ -201,7 +201,7 @@ fi
# infrt needs to temporarily use LOG(FATAL) during the debugging period, and will replace it with standard error format in the future.
NO_INFRT_FILES=`git diff --name-only upstream/develop | grep -v "tools/\|paddle/infrt/" || true`
HAS_LOG_FATAL=`git diff -U0 upstream/$BRANCH $NO_INFRT_FILES |grep "^+" |grep -o -m 1 "LOG(FATAL)" || true`
if [ ${HAS_LOG_FATAL} ] && [ "${GIT_PR_ID}" != "" ]; then
if [ ${NO_INFRT_FILES} ] && [ ${HAS_LOG_FATAL} ] && [ "${GIT_PR_ID}" != "" ]; then
echo_line="LOG(FATAL) is not recommended, because it will throw exception without standard stack information, so please use PADDLE_THROW macro here. If you have to use LOG(FATAL) here, please request chenwhql (Recommend), luotao1 or lanxianghit review and approve.\n"
check_approval 1 6836917 47554610 22561442
fi
......
{
"phi_apis":["conj", "dropout", "expand_as", "flatten", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth"],
"phi_apis":["conj", "dropout", "expand_as", "nll_loss", "psroi_pool", "roi_align", "roi_pool", "label_smooth"],
"phi_kernels":["equal_all"]
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册