未验证 提交 fb891776 编写于 作者: L liuruyan 提交者: GitHub

add symbolTable & symbolicDimProduct & symbolicDimMgr. (#56351)

* add symbolicDimProduct & symbolicDimMgr without method shape_constraint related

* split ddim in phi, add a target ddim, used by pd_type

* add pd_type.cc to ir_shape CMakeLists
上级 5d43f5e4
add_subdirectory(ir)
file(GLOB_RECURSE SHAPE_SRCS "*.cc")
ir_library(
ir_shape
SRCS
${SHAPE_SRCS}
${PADDLE_SOURCE_DIR}/paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.cc
DEPS
ddim
ir_core)
file(GLOB SHAPE_SRCS "*.cc")
ir_library(ir_shape SRCS ${SHAPE_SRCS} DEPS ir_core)
......@@ -54,7 +54,7 @@ void SymbolicDim::Build(
argument.AddAttribute("knownNonSizeZero", attr_knownNonSizeZero);
}
std::string SymbolicDim::getSymName() {
const std::string SymbolicDim::getSymName() {
return attribute<ir::StrAttribute>("sym_name").AsString();
}
int64_t SymbolicDim::getValue() {
......@@ -103,6 +103,35 @@ void SymbolicDim::updateKnownNonSizeZero(bool attrValue) {
ir::BoolAttribute::get(ir::IrContext::Instance(), attrValue));
}
bool SymbolicDim::isDynamic() {
return getValue() == -100000;
} // TODO(zhangbo): getValue() == ShapedType::kDynamic;
bool SymbolicDim::merge(SymbolicDim other) {
if (!isDynamic() && !other.isDynamic() && getValue() != other.getValue())
return false;
if (isDynamic() && !other.isDynamic()) updateValue(other.getValue());
bool knownNonNegativeFlag =
getKnownNonNegative() || other.getKnownNonNegative();
bool knownNegativeOneFlag =
getKnownNegativeOne() || other.getKnownNegativeOne();
bool knownNonSizeOneFlag = getKnownNonSizeOne() ||
other.getKnownNonSizeOne() || knownNegativeOneFlag;
bool knownNonSizeZeroFlag = getKnownNonSizeZero() ||
other.getKnownNonSizeZero() ||
knownNegativeOneFlag;
if (knownNonNegativeFlag && knownNegativeOneFlag) return false;
updateKnownNonSizeZero(knownNonSizeZeroFlag);
updateKnownNonSizeOne(knownNonSizeOneFlag);
updateKnownNegativeOne(knownNegativeOneFlag);
updateKnownNonNegative(knownNonNegativeFlag);
return true;
}
} // namespace dialect
} // namespace ir
......
......@@ -37,7 +37,7 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
bool knownNegativeOne = false,
bool knownNonSizeOne = false,
bool knownNonSizeZero = false);
std::string getSymName();
const std::string getSymName();
int64_t getValue();
bool getKnownNonNegative();
bool getKnownNegativeOne();
......@@ -46,11 +46,14 @@ class IR_API SymbolicDim : public Op<SymbolicDim> {
void updateSymName(std::string attrValue);
void updateValue(int64_t attrValue);
void updateKnownNonNegative(bool attrValue);
void updateKnownNegativeOne(bool attrValue);
void updateKnownNonSizeOne(bool attrValue);
void updateKnownNonSizeZero(bool attrValue);
bool isDynamic();
bool merge(SymbolicDim other);
void Verify() {}
};
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/ir/dialect/shape/utils/shape_utils.h"
#include <string>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
namespace ir {
bool compareSymbolicDimNames(const std::string& lhs, const std::string& rhs) {
if (lhs.size() < 1 || lhs[0] != 'S' && lhs[0] != 'C') return lhs < rhs;
if (rhs.size() < 1 || rhs[0] != 'S' && rhs[0] != 'C') return lhs < rhs;
int64_t lhsIdx = 0, rhsIdx = 0;
try {
lhsIdx = stol(lhs.substr(1));
rhsIdx = stol(rhs.substr(1));
} catch (const std::exception& e) {
IR_THROW("Invalid symbolic name");
}
return (lhs[0] < rhs[0]) || (lhs[0] == rhs[0] && lhsIdx < rhsIdx);
}
ir::Operation* SymbolTable::lookup(const std::string& name) const {
auto it = symbolTableMap_.find(name);
return it != symbolTableMap_.end() ? it->second : nullptr;
}
const std::string SymbolTable::insert(ir::Operation* symbol) {
std::string name;
if (symbol->HasAttribute("sym_name")) {
name = symbol->dyn_cast<SymbolicDim>().getSymName();
}
// TODO(liujinnan): add constraint_func name branch.
symbolTableMap_.insert({name, symbol});
return name;
}
const std::string SymbolicDimMgr::getNextName() {
std::string name;
do {
name = "S" + std::to_string(nextSymbolicIdx_++);
} while (!symbolNameSet_.insert(name).second);
return name;
}
SymbolicDimMgr::SymbolicDimMgr(ir::ModuleOp m) : m_(m), symbolTable_(m_) {}
SymbolicDim SymbolicDimMgr::newSymbolicDim(const std::string& name) {
::ir::Builder builder = ::ir::Builder(m_.ir_context(), m_.block());
ir::dialect::SymbolicDim symbol = builder.Build<ir::dialect::SymbolicDim>(
name.empty() ? getNextName() : name);
symbolDimUnionSet_[symbol] = symbol;
symbolTable_.insert(symbol);
return symbol;
}
SymbolicDim SymbolicDimMgr::newConstantSymbolicDim(int64_t val) {
auto it = constantSymbolicDimMap_.find(val);
if (it == constantSymbolicDimMap_.end()) {
auto name = "C" + std::to_string(val);
it = constantSymbolicDimMap_
.insert(std::make_pair(val, newSymbolicDim(name)))
.first;
it->second.updateValue(val);
}
return getRootSymbolicDim(it->second);
}
std::vector<SymbolicDim> SymbolicDimMgr::createSymbolicDimsForRankedValue(
ir::Value value) {
std::vector<SymbolicDim> symbols;
auto dims = value.type().dyn_cast<paddle::dialect::DenseTensorType>().dims();
for (int idx = 0; idx < dims.size(); ++idx) {
symbols.push_back(
dims[idx] == -100000 // TODO(zhangbo): value = ShapedType::kDynamic
? newSymbolicDim()
: newConstantSymbolicDim(dims[idx]));
}
return symbols;
}
SymbolicDim SymbolicDimMgr::getRootSymbolicDim(SymbolicDim symbol) {
SymbolicDim current = symbol;
std::vector<SymbolicDim> path;
while (symbolDimUnionSet_[current] != current) {
path.push_back(current);
current = symbolDimUnionSet_[current];
}
for (SymbolicDim sym : path) symbolDimUnionSet_[sym] = current;
return current;
}
bool SymbolicDimMgr::isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) {
SymbolicDim lhsRoot = getRootSymbolicDim(lhs);
SymbolicDim rhsRoot = getRootSymbolicDim(rhs);
return lhsRoot == rhsRoot;
}
bool SymbolicDimMgr::mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs) {
SymbolicDim lhsRoot = getRootSymbolicDim(lhs);
SymbolicDim rhsRoot = getRootSymbolicDim(rhs);
if (lhsRoot != rhsRoot) {
if (compareSymbolicDimNames(lhsRoot.getSymName(), rhsRoot.getSymName())) {
if (!lhsRoot.merge(rhsRoot)) return false;
symbolDimUnionSet_[rhsRoot] = lhsRoot;
} else {
if (!rhsRoot.merge(lhsRoot)) return false;
symbolDimUnionSet_[lhsRoot] = rhsRoot;
}
}
return true;
}
} // namespace ir
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/utils.h"
#include "paddle/ir/dialect/shape/ir/shape_op.h"
namespace ir {
using ir::dialect::SymbolicDim;
struct SymbolicDimProduct {
std::vector<SymbolicDim> symbols;
int64_t factor = 1;
bool empty() { return factor == 1 && symbols.empty(); }
};
inline bool operator==(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return lhs.factor == rhs.factor && lhs.symbols == rhs.symbols;
}
inline bool operator!=(const SymbolicDimProduct& lhs,
const SymbolicDimProduct& rhs) {
return !(lhs == rhs);
}
class SymbolTable {
public:
explicit SymbolTable(ir::Operation* symbolTableOp)
: symbolTableOp_(symbolTableOp) {}
ir::Operation* lookup(const std::string& name) const;
const std::string insert(Operation* symbol);
ir::Operation* getOp() const { return symbolTableOp_; }
private:
ir::Operation* symbolTableOp_;
std::unordered_map<std::string, ir::Operation*> symbolTableMap_;
};
struct SymDimHasher {
size_t operator()(const ir::dialect::SymbolicDim& symbol) const noexcept {
return std::hash<ir::Operation*>{}(symbol.operation());
}
};
struct SymProductHasher {
size_t operator()(const ir::SymbolicDimProduct& symProd) const noexcept {
size_t hash = std::hash<size_t>{}(symProd.symbols.size());
for (auto& symbol : symProd.symbols) {
hash = hash_combine(hash, SymDimHasher{}(symbol)); // NOLINT
}
hash = hash_combine(hash, std::hash<int64_t>{}(symProd.factor));
return hash;
}
};
class SymbolicDimMgr {
public:
explicit SymbolicDimMgr(ir::ModuleOp m);
SymbolicDim newSymbolicDim(const std::string& name = {});
SymbolicDim newConstantSymbolicDim(int64_t val);
std::vector<SymbolicDim> createSymbolicDimsForRankedValue(Value value);
SymbolicDim getRootSymbolicDim(SymbolicDim symbol);
bool isSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
SymbolTable& symbolTable() { return symbolTable_; }
bool mapSymbolicDimEqual(SymbolicDim lhs, SymbolicDim rhs);
private:
const std::string getNextName();
private:
ir::ModuleOp m_;
SymbolTable symbolTable_;
int64_t nextSymbolicIdx_ = 0;
std::unordered_set<std::string> symbolNameSet_;
std::unordered_map<SymbolicDim, SymbolicDim, SymDimHasher> symbolDimUnionSet_;
std::unordered_map<int64_t, SymbolicDim> constantSymbolicDimMap_;
// productEqualityMap_[A][B] == true : Product[A] == Product[B]
using SymbolicDimProductMap = std::unordered_map<
SymbolicDimProduct,
std::unordered_map<SymbolicDimProduct, bool, SymProductHasher>,
SymProductHasher>;
SymbolicDimProductMap productEqualityMap_;
};
} // namespace ir
......@@ -39,3 +39,5 @@ collect_srcs(
kernel_factory.cc
tensor_utils.cc
utils/type_info.cc)
cc_library(ddim SRCS ddim.cc)
......@@ -14,39 +14,132 @@
#include <gtest/gtest.h>
#include <map>
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_type.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/dialect/shape/ir/shape_dialect.h"
#include "paddle/ir/dialect/shape/ir/shape_op.h"
#include "paddle/ir/dialect/shape/utils/shape_utils.h"
TEST(assist_struct_test, symbolic_dim) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::dialect::SymbolicDim sym_dim = builder.Build<ir::dialect::SymbolicDim>(
ir::dialect::SymbolicDim symDim = builder.Build<ir::dialect::SymbolicDim>(
"S0", 10, false, false, false, false);
EXPECT_EQ(sym_dim.getValue(), 10);
EXPECT_EQ(sym_dim.getSymName(), "S0");
EXPECT_FALSE(sym_dim.getKnownNegativeOne());
EXPECT_FALSE(sym_dim.getKnownNonSizeOne());
EXPECT_FALSE(sym_dim.getKnownNonSizeZero());
EXPECT_FALSE(sym_dim.getKnownNonNegative());
sym_dim.updateValue(20);
sym_dim.updateSymName("S1");
sym_dim.updateKnownNegativeOne(true);
sym_dim.updateKnownNonSizeOne(true);
sym_dim.updateKnownNonSizeZero(true);
sym_dim.updateKnownNonNegative(true);
EXPECT_EQ(sym_dim.getValue(), 20);
EXPECT_EQ(sym_dim.getSymName(), "S1");
EXPECT_TRUE(sym_dim.getKnownNegativeOne());
EXPECT_TRUE(sym_dim.getKnownNonSizeOne());
EXPECT_TRUE(sym_dim.getKnownNonSizeZero());
EXPECT_TRUE(sym_dim.getKnownNonNegative());
ir::dialect::SymbolicDim symDim_ = builder.Build<ir::dialect::SymbolicDim>(
"S1", 10, false, false, false, false);
EXPECT_EQ(symDim.getValue(), 10);
EXPECT_EQ(symDim.getSymName(), "S0");
EXPECT_FALSE(symDim.getKnownNegativeOne());
EXPECT_FALSE(symDim.getKnownNonSizeOne());
EXPECT_FALSE(symDim.getKnownNonSizeZero());
EXPECT_FALSE(symDim.getKnownNonNegative());
EXPECT_FALSE(symDim.isDynamic());
EXPECT_TRUE(symDim.merge(symDim_));
symDim.updateValue(20);
symDim.updateSymName("S2");
symDim.updateKnownNegativeOne(true);
symDim.updateKnownNonSizeOne(true);
symDim.updateKnownNonSizeZero(true);
symDim.updateKnownNonNegative(true);
EXPECT_FALSE(symDim.merge(symDim_));
EXPECT_EQ(symDim.getValue(), 20);
EXPECT_EQ(symDim.getSymName(), "S2");
EXPECT_TRUE(symDim.getKnownNegativeOne());
EXPECT_TRUE(symDim.getKnownNonSizeOne());
EXPECT_TRUE(symDim.getKnownNonSizeZero());
EXPECT_TRUE(symDim.getKnownNonNegative());
}
TEST(assist_struct_test, symbolic_dim_product) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::dialect::SymbolicDim symDim = builder.Build<ir::dialect::SymbolicDim>(
"S0", -100000, false, false, false, false);
ir::SymbolicDimProduct symDimProduct;
ir::SymbolicDimProduct symDimProduct_;
symDimProduct.symbols.push_back(symDim);
symDimProduct.factor *= 10;
EXPECT_EQ(symDimProduct.factor, 10);
EXPECT_NE(symDimProduct, symDimProduct_);
EXPECT_FALSE(symDimProduct.empty());
}
TEST(assist_struct_test, symbolic_dim_table) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::Builder builder = ir::Builder(ctx, program.block());
ir::dialect::SymbolicDim symDim = builder.Build<ir::dialect::SymbolicDim>(
"S0", 10, false, false, false, false);
ir::SymbolTable symbolTable(program.module_op());
EXPECT_EQ(symbolTable.insert(symDim), "S0");
EXPECT_EQ(symbolTable.lookup("S0")->dyn_cast<ir::dialect::SymbolicDim>(),
symDim);
EXPECT_EQ(symbolTable.lookup("S1"), nullptr);
EXPECT_EQ(symbolTable.getOp(), program.module_op());
}
TEST(assist_struct_test, symbolic_dim_mgr) {
ir::IrContext *ctx = ir::IrContext::Instance();
ir::Program program(ctx);
ctx->GetOrRegisterDialect<ir::dialect::ShapeDialect>();
ir::SymbolicDimMgr symDimMgr(program.module_op());
ir::dialect::SymbolicDim symDimS0 = symDimMgr.newSymbolicDim();
ir::dialect::SymbolicDim symDimS1 = symDimMgr.newSymbolicDim();
ir::dialect::SymbolicDim symDimC10 = symDimMgr.newConstantSymbolicDim(10);
symDimMgr.mapSymbolicDimEqual(symDimS0, symDimS1);
ir::Attribute attr_value = ir::StrAttribute::get(ctx, "op_attr");
ir::AttributeMap attr_map;
attr_map.insert(std::pair<std::string, ir::Attribute>("op", attr_value));
std::vector<ir::OpResult> op_inputs = {};
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
phi::DDim dims = {-100000, 2};
phi::DataLayout data_layout = phi::DataLayout::NCHW;
phi::LoD lod = {{0, 1, 2}};
size_t offset = 0;
std::vector<ir::Type> op_output_types = {
paddle::dialect::DenseTensorType::get(
ctx, fp32_dtype, dims, data_layout, lod, offset)};
ir::Operation *op =
ir::Operation::Create(op_inputs, attr_map, op_output_types, ir::OpInfo());
ir::Value res = op->result(0);
std::vector<ir::dialect::SymbolicDim> symDimVec =
symDimMgr.createSymbolicDimsForRankedValue(res);
EXPECT_EQ(symDimS0.getSymName(), "S0");
EXPECT_EQ(symDimS1.getSymName(), "S1");
EXPECT_EQ(symDimS1.getValue(), -100000);
EXPECT_EQ(symDimC10.getSymName(), "C10");
EXPECT_EQ(symDimC10.getValue(), 10);
EXPECT_EQ(symDimVec[0].getSymName(), "S2");
EXPECT_EQ(symDimVec[1].getSymName(), "C2");
EXPECT_EQ(symDimMgr.symbolTable()
.lookup("S0")
->dyn_cast<ir::dialect::SymbolicDim>(),
symDimS0);
EXPECT_EQ(symDimMgr.symbolTable()
.lookup("C10")
->dyn_cast<ir::dialect::SymbolicDim>(),
symDimC10);
EXPECT_EQ(symDimMgr.getRootSymbolicDim(symDimS1), symDimS0);
EXPECT_TRUE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimS1));
EXPECT_FALSE(symDimMgr.isSymbolicDimEqual(symDimS0, symDimC10));
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册