未验证 提交 7e1dd338 编写于 作者: K kangguangli 提交者: GitHub

[IR] Add vector type support for program translator (#54035)

* add vector type support for program translator

* polish

* resolve conflicts

* add verify for combine/slice and unittests

* polish
上级 d73db135
...@@ -24,6 +24,8 @@ ...@@ -24,6 +24,8 @@
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/translator/program_translator.h" #include "paddle/fluid/translator/program_translator.h"
#include "paddle/fluid/translator/type_translator.h" #include "paddle/fluid/translator/type_translator.h"
#include "paddle/ir/builtin_op.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/ir/ir_context.h" #include "paddle/ir/ir_context.h"
#include "paddle/ir/value.h" #include "paddle/ir/value.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
...@@ -84,23 +86,101 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) { ...@@ -84,23 +86,101 @@ inline ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) {
return op_info; return op_info;
} }
inline ir::Operation* InsertSliceOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const VariableDefiningInfo& defining_info,
const std::string& arg_name) {
std::string slice_op_name(ir::SliceOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(slice_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{"index", ir::Int32_tAttribute::get(ctx, defining_info.idx_in_vector)},
};
ir::VectorType src_vec_type =
defining_info.value.type().dyn_cast<ir::VectorType>();
ir::Operation* operation =
ir::Operation::create({defining_info.value},
{src_vec_type[defining_info.idx_in_vector]},
op_attribute_map,
op_info);
program->InsertOp(operation);
ir::OpResult target_op_result = operation->GetResultByIndex(0);
(*param_map)[arg_name] = VariableDefiningInfo(target_op_result);
return operation;
}
inline ir::Operation* InsertCombineOperationForTarget(
ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const std::vector<std::string>& args) {
std::string combine_op_name(ir::CombineOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(combine_op_name);
std::vector<ir::OpResult> src_values;
std::vector<ir::Type> types_in_vec;
for (auto arg_name : args) {
auto defining_info = param_map->at(arg_name);
src_values.push_back(defining_info.value);
types_in_vec.push_back(defining_info.value.type());
}
ir::Type target_vec_type = ir::VectorType::get(ctx, types_in_vec);
ir::Operation* operation =
ir::Operation::create(src_values, {target_vec_type}, {}, op_info);
program->InsertOp(operation);
return operation;
}
inline std::vector<ir::OpResult> GenerateOperationInput( inline std::vector<ir::OpResult> GenerateOperationInput(
TranslationContext* param_map, const OpDesc& op_desc) { ir::IrContext* ctx,
TranslationContext* param_map,
ir::Program* program,
const OpDesc& op_desc) {
std::vector<ir::OpResult> op_inputs = {}; std::vector<ir::OpResult> op_inputs = {};
// scan all inputs to see if any of them is generated as a vector<Tensor>
// so need an additional `SliceOp` to take it out.
for (const auto& n : op_desc.Inputs()) { for (const auto& n : op_desc.Inputs()) {
auto& name = n.first; auto& name = n.first;
VLOG(10) << "[input retriving]"
<< "[" << op_desc.Type() << "]" << name;
auto& args = n.second; auto& args = n.second;
for (const auto& arg_name : args) { for (const auto& arg_name : args) {
PADDLE_ENFORCE_NE( PADDLE_ENFORCE_NE(
param_map->count(arg_name), param_map->count(arg_name),
0, 0,
platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"arg %s as input should be exists before prasing %d", "arg %s.%s as input should be exists before prasing %d",
name,
arg_name, arg_name,
op_desc.Type())); op_desc.Type()));
op_inputs.push_back((*param_map)[arg_name]); auto defining_info = (*param_map)[arg_name];
if (defining_info.generated_by_vector) {
InsertSliceOperationForTarget(
ctx, param_map, program, defining_info, arg_name);
}
}
}
for (const auto& n : op_desc.Inputs()) {
auto& name = n.first;
VLOG(10) << "[input retriving]"
<< "[" << op_desc.Type() << "]" << name;
auto& args = n.second;
// if src type is Tensor or a Vector<Tensor> with size <= 1
if (args.size() <= 1) {
for (const auto& arg_name : args) {
auto defining_info = (*param_map)[arg_name];
op_inputs.push_back(defining_info.value);
}
// if src type is Vector<Tesnor> , need an additional `CombineOp` to
// assemble them.
} else {
auto* combine_op =
InsertCombineOperationForTarget(ctx, param_map, program, args);
op_inputs.push_back(combine_op->GetResultByIndex(0));
} }
} }
return op_inputs; return op_inputs;
...@@ -119,16 +199,39 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput( ...@@ -119,16 +199,39 @@ inline std::tuple<OpOutputTypeList, OpOutputMapping> GenerateOperationOutput(
VLOG(10) << "[output translating]" VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name; << "[" << op_desc.Type() << "]" << name;
auto& args = n.second; auto& args = n.second;
for (const auto& arg_name : args) {
VarDesc* var = block->FindVarRecursive(arg_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name << " " << arg_name << " "
<< var->GetType();
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
arg_to_idx[arg_name] = op_output_types.size(); size_t cur_output_idx = op_output_types.size();
op_output_types.push_back(translated_var_type);
// if src type is Tensor or a Vector<Tensor> with size <= 1
if (args.size() <= 1) {
for (const auto& arg_name : args) {
VarDesc* var = block->FindVarRecursive(arg_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name << " " << arg_name
<< " " << var->GetType();
ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
arg_to_idx[arg_name] = cur_output_idx;
op_output_types.push_back(translated_var_type);
}
// if src type is Vector<Tesnor>
} else {
std::vector<ir::Type> types;
for (const auto& arg_name : args) {
VarDesc* var = block->FindVarRecursive(arg_name);
VLOG(10) << "[output translating]"
<< "[" << op_desc.Type() << "]" << name << " " << arg_name
<< " " << var->GetType();
ir::Type translated_var_type =
type_translator[var->GetType()](ctx, *var);
types.push_back(translated_var_type);
arg_to_idx[arg_name] = cur_output_idx;
}
ir::Type vec_type = ir::VectorType::get(ctx, types);
op_output_types.push_back(vec_type);
} }
} }
return {op_output_types, arg_to_idx}; return {op_output_types, arg_to_idx};
...@@ -143,12 +246,17 @@ inline void RecordOpResultMapping(TranslationContext* param_map, ...@@ -143,12 +246,17 @@ inline void RecordOpResultMapping(TranslationContext* param_map,
VLOG(10) << "[output recording]" VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << name; << "[" << op_desc.Type() << "]" << name;
auto& args = n.second; auto& args = n.second;
size_t idx_in_vector = 0;
for (const auto& arg_name : args) { for (const auto& arg_name : args) {
auto idx = arg_to_idx.at(arg_name); auto idx = arg_to_idx.at(arg_name);
VLOG(10) << "[output recording]" VLOG(10) << "[output recording]"
<< "[" << op_desc.Type() << "]" << arg_name << " " << idx; << "[" << op_desc.Type() << "]" << arg_name << " " << idx;
(*param_map)[arg_name] = operation->GetResultByIndex(idx); ir::OpResult value = operation->GetResultByIndex(idx);
bool generated_by_vector = value.type().isa<ir::VectorType>();
(*param_map)[arg_name] = VariableDefiningInfo(
value, generated_by_vector, generated_by_vector ? idx_in_vector : -1);
idx_in_vector++;
} }
} }
} }
...@@ -157,7 +265,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx, ...@@ -157,7 +265,7 @@ ir::Operation* GeneralOpHandler(ir::IrContext* ctx,
TranslationContext* param_map, TranslationContext* param_map,
ir::Program* program, ir::Program* program,
const OpDesc& op_desc) { const OpDesc& op_desc) {
auto op_inputs = GenerateOperationInput(param_map, op_desc); auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);
OpOutputMapping arg_to_idx; OpOutputMapping arg_to_idx;
OpOutputTypeList op_output_types = {}; OpOutputTypeList op_output_types = {};
...@@ -193,7 +301,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx, ...@@ -193,7 +301,7 @@ ir::Operation* FetchOpHandler(ir::IrContext* ctx,
TranslationContext* param_map, TranslationContext* param_map,
ir::Program* program, ir::Program* program,
const OpDesc& op_desc) { const OpDesc& op_desc) {
auto op_inputs = GenerateOperationInput(param_map, op_desc); auto op_inputs = GenerateOperationInput(ctx, param_map, program, op_desc);
OpOutputTypeList op_output_types = {}; OpOutputTypeList op_output_types = {};
auto op_info = LoopkUpOpInfo(ctx, op_desc); auto op_info = LoopkUpOpInfo(ctx, op_desc);
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h" #include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/translator/program_translator.h"
#include "paddle/ir/ir_context.h" #include "paddle/ir/ir_context.h"
#include "paddle/ir/operation.h" #include "paddle/ir/operation.h"
#include "paddle/ir/program.h" #include "paddle/ir/program.h"
...@@ -28,8 +29,6 @@ ...@@ -28,8 +29,6 @@
namespace paddle { namespace paddle {
namespace translator { namespace translator {
using TranslationContext = std::unordered_map<std::string, ir::OpResult>;
class OpTranslator { class OpTranslator {
public: public:
using ResultIdx = size_t; using ResultIdx = size_t;
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/translator/op_translator.h" #include "paddle/fluid/translator/op_translator.h"
#include "paddle/fluid/translator/type_translator.h"
#include "paddle/ir/attribute.h" #include "paddle/ir/attribute.h"
#include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_op.h"
#include "paddle/ir/builtin_type.h" #include "paddle/ir/builtin_type.h"
...@@ -38,6 +39,11 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program, ...@@ -38,6 +39,11 @@ ProgramTranslator::ProgramTranslator(const ProgramDesc* legacy_program,
ctx = ir::IrContext::Instance(); ctx = ir::IrContext::Instance();
} }
const std::unordered_set<std::string> ProgramTranslator::no_cast_var_names = {
"feed",
"fetch",
};
void ProgramTranslator::Translate() { void ProgramTranslator::Translate() {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
legacy_program->Size(), legacy_program->Size(),
...@@ -59,19 +65,24 @@ void ProgramTranslator::Translate() { ...@@ -59,19 +65,24 @@ void ProgramTranslator::Translate() {
void ProgramTranslator::ExtractParameterFromSingleBlock( void ProgramTranslator::ExtractParameterFromSingleBlock(
const BlockDesc& block) { const BlockDesc& block) {
auto& type_translator = TypeTranslator::instance();
for (auto& var : block.AllVars()) { for (auto& var : block.AllVars()) {
if (!var->Persistable()) continue; if (!var->Persistable()) continue;
if (param_map.count(var->Name()) != 0) continue; if (param_map.count(var->Name()) != 0) continue;
if (no_cast_var_names.count(var->Name()) != 0) continue;
std::string get_parameter_op_name(ir::GetParameterOp::name()); std::string get_parameter_op_name(ir::GetParameterOp::name());
ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name); ir::OpInfo op_info = ctx->GetRegisteredOpInfo(get_parameter_op_name);
std::unordered_map<std::string, ir::Attribute> op_attribute_map = { std::unordered_map<std::string, ir::Attribute> op_attribute_map = {
{var->Name(), ir::StrAttribute::get(ctx, var->Name())}, {var->Name(), ir::StrAttribute::get(ctx, var->Name())},
}; };
ir::Type translated_var_type = type_translator[var->GetType()](ctx, *var);
ir::Operation* operation = ir::Operation::create( ir::Operation* operation = ir::Operation::create(
{}, {ir::Float32Type::get(ctx)}, op_attribute_map, op_info); {}, {translated_var_type}, op_attribute_map, op_info);
program->InsertOp(operation); program->InsertOp(operation);
param_map[var->Name()] = operation->GetResultByIndex(0); param_map[var->Name()] =
VariableDefiningInfo(operation->GetResultByIndex(0));
VLOG(10) << "[op translated][get parameter]" << operation; VLOG(10) << "[op translated][get parameter]" << operation;
program->SetParameter(var->Name(), nullptr); program->SetParameter(var->Name(), nullptr);
......
...@@ -27,7 +27,25 @@ ...@@ -27,7 +27,25 @@
namespace paddle { namespace paddle {
namespace translator { namespace translator {
using TranslationContext = std::unordered_map<std::string, ir::OpResult>; struct VariableDefiningInfo {
VariableDefiningInfo(ir::OpResult value,
bool generated_by_vector = false,
int idx_in_vector = -1)
: value(value),
generated_by_vector(generated_by_vector),
idx_in_vector(idx_in_vector) {}
VariableDefiningInfo() {}
ir::OpResult value;
bool generated_by_vector =
false; // true if target variabe is generated by Vector<Tensor>
int idx_in_vector =
-1; // positive if target variabe is generated by Vector<Tensor>
};
using TranslationContext =
std::unordered_map<std::string, VariableDefiningInfo>;
class ProgramTranslator { class ProgramTranslator {
using ProgramDesc = ::paddle::framework::ProgramDesc; using ProgramDesc = ::paddle::framework::ProgramDesc;
...@@ -45,6 +63,14 @@ class ProgramTranslator { ...@@ -45,6 +63,14 @@ class ProgramTranslator {
TranslationContext param_map; TranslationContext param_map;
ir::IrContext* ctx; ir::IrContext* ctx;
/// In the legacy program desc, there are two special named varibales:
/// 1. "feed", the input variable of feed op
/// 2. "fetch", the output variable of fetch op
/// However, new feed has no input and new fetch has no output
/// So we don't handle these two vairables when
/// `ExtractParameterFromSingleBlock`
static const std::unordered_set<std::string> no_cast_var_names;
void ExtractParameterFromSingleBlock(const BlockDesc& block); void ExtractParameterFromSingleBlock(const BlockDesc& block);
void InsertOperationToSingleBlock(const BlockDesc& block); void InsertOperationToSingleBlock(const BlockDesc& block);
}; };
......
...@@ -22,10 +22,7 @@ ...@@ -22,10 +22,7 @@
namespace paddle { namespace paddle {
using LegacyProgramDesc = ::paddle::framework::ProgramDesc; std::unique_ptr<::ir::Program> TranslateLegacyProgramToProgram(
using Program = ::ir::Program; const ::paddle::framework::ProgramDesc& legacy_program);
std::unique_ptr<Program> TranslateLegacyProgramToProgram(
const LegacyProgramDesc& legacy_program);
} // namespace paddle } // namespace paddle
...@@ -22,6 +22,10 @@ ...@@ -22,6 +22,10 @@
namespace paddle { namespace paddle {
namespace translator { namespace translator {
using OpDesc = paddle::framework::OpDesc;
using BlockDesc = paddle::framework::BlockDesc;
using VarDesc = paddle::framework::VarDesc;
using VarType = paddle::framework::proto::VarType;
using DenseTensorType = paddle::dialect::DenseTensorType; using DenseTensorType = paddle::dialect::DenseTensorType;
using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage; using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
......
...@@ -27,13 +27,13 @@ ...@@ -27,13 +27,13 @@
namespace paddle { namespace paddle {
namespace translator { namespace translator {
using OpDesc = paddle::framework::OpDesc; using TypeTranslateFn =
using BlockDesc = paddle::framework::BlockDesc; std::function<ir::Type(ir::IrContext*, const framework::VarDesc&)>;
using VarDesc = paddle::framework::VarDesc;
using VarType = paddle::framework::proto::VarType;
using TypeTranslateFn = std::function<ir::Type(ir::IrContext*, const VarDesc&)>;
class TypeTranslator { class TypeTranslator {
public:
using VarType = paddle::framework::proto::VarType;
private: private:
TypeTranslator(); // Disallow instantiation outside of the class. TypeTranslator(); // Disallow instantiation outside of the class.
std::unordered_map<VarType::Type, TypeTranslateFn> handlers; std::unordered_map<VarType::Type, TypeTranslateFn> handlers;
......
...@@ -44,7 +44,10 @@ void BuiltinDialect::initialize() { ...@@ -44,7 +44,10 @@ void BuiltinDialect::initialize() {
ir::Int64_tAttribute, ir::Int64_tAttribute,
ir::ArrayAttribute>(); ir::ArrayAttribute>();
RegisterOps<ir::GetParameterOp, ir::SetParameterOp>(); RegisterOps<ir::GetParameterOp,
ir::SetParameterOp,
ir::CombineOp,
ir::SliceOp>();
} }
} // namespace ir } // namespace ir
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
// limitations under the License. // limitations under the License.
#include "paddle/ir/builtin_op.h" #include "paddle/ir/builtin_op.h"
#include "paddle/ir/builtin_attribute.h" #include "paddle/ir/builtin_attribute.h"
#include "paddle/ir/builtin_type.h"
#include "paddle/phi/core/enforce.h"
namespace ir { namespace ir {
const char *GetParameterOp::attributes_name[attributes_num] = { const char *GetParameterOp::attributes_name[attributes_num] = {
...@@ -58,4 +61,103 @@ void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs, ...@@ -58,4 +61,103 @@ void SetParameterOp::verify(const std::vector<ir::OpResult> &inputs,
} }
} }
void CombineOp::verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// outputs[0].type == Vector<Type>
PADDLE_ENFORCE(outputs[0].isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
"The type %s of outputs[0] must be equal to VectorType.",
outputs[0]));
ir::VectorType output_type = outputs[0].dyn_cast<ir::VectorType>();
// inputs.size() == outputs[0].size()
PADDLE_ENFORCE_EQ(
output_type.size(),
inputs.size(),
phi::errors::PreconditionNotMet(
"The size %d of outputs[0] must be equal to size %d of inputs.",
output_type.size(),
inputs.size()));
// forall i in inputs.size(): inputs[i].type == outputs[0][i].type
for (size_t i = 0; i < inputs.size(); i++) {
PADDLE_ENFORCE_EQ(
output_type[i],
inputs[i].type(),
phi::errors::PreconditionNotMet("The type %s of outputs[0][%d] must be "
"equal to type %s of inputs[%d].",
output_type[i],
i,
inputs[i].type(),
i));
}
}
const char *SliceOp::attributes_name[attributes_num] = {"index"};
void SliceOp::verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes) {
// inputs.size() == 1
PADDLE_ENFORCE_EQ(
inputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 1.", inputs.size()));
// inputs[0].type == Vector<Type>
PADDLE_ENFORCE(inputs[0].type().isa<ir::VectorType>(),
phi::errors::PreconditionNotMet(
"The type %s of inputs[0] must be equal to VectorType.",
inputs[0].type()));
ir::VectorType input_type = inputs[0].type().dyn_cast<ir::VectorType>();
// outputs.size() == 1
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", outputs.size()));
// attributes contains index: Int32
PADDLE_ENFORCE_NE(
attributes.count("index"),
0,
phi::errors::PreconditionNotMet("The attributes must contains index."));
const ir::Attribute &attr = attributes.at("index");
PADDLE_ENFORCE(
attr.isa<ir::Int32_tAttribute>(),
phi::errors::PreconditionNotMet("The attribute index must be INT32."));
auto index = attr.dyn_cast<ir::Int32_tAttribute>().data();
// index >= 0 and < inputs[0].size()
PADDLE_ENFORCE_GE(
index,
0,
phi::errors::PreconditionNotMet(
"The index %d must be greater or equal than 0.", index));
PADDLE_ENFORCE_LT(
index,
input_type.size(),
phi::errors::PreconditionNotMet(
"The index %d must be less or equal than size %d of inputs[0].",
index,
input_type.size()));
// inputs[index].type == outputs[0].type
PADDLE_ENFORCE_EQ(
input_type[index],
outputs[0],
phi::errors::PreconditionNotMet(
"The type %s of inputs[%d] must be equal to type %s of outputs[0].",
input_type[index],
index,
outputs[0]));
}
} // namespace ir } // namespace ir
...@@ -33,7 +33,7 @@ class GetParameterOp : public ir::Op<GetParameterOp> { ...@@ -33,7 +33,7 @@ class GetParameterOp : public ir::Op<GetParameterOp> {
}; };
/// ///
/// \brief GetParameterOp: SetParameterOp(OpOperand, {StrAttribute, /// \brief SetParameterOp: SetParameterOp(OpOperand, {StrAttribute,
/// StrAttribute}) /// StrAttribute})
/// ///
class SetParameterOp : public ir::Op<SetParameterOp> { class SetParameterOp : public ir::Op<SetParameterOp> {
...@@ -47,4 +47,38 @@ class SetParameterOp : public ir::Op<SetParameterOp> { ...@@ -47,4 +47,38 @@ class SetParameterOp : public ir::Op<SetParameterOp> {
const ir::AttributeMap &attributes); const ir::AttributeMap &attributes);
}; };
///
/// \brief CombineOp: CombineOp(OpOperand)
///
class CombineOp : public ir::Op<CombineOp> {
public:
using Op::Op;
static const char *name() { return "builtin.combine"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
};
///
/// \brief SliceOp: SliceOp(OpOperand)
///
class SliceOp : public ir::Op<SliceOp> {
public:
using Op::Op;
static const char *name() { return "builtin.slice"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void verify(const std::vector<ir::OpResult> &inputs,
const std::vector<ir::Type> &outputs,
const ir::AttributeMap &attributes);
};
} // namespace ir } // namespace ir
...@@ -28,6 +28,21 @@ namespace ir { ...@@ -28,6 +28,21 @@ namespace ir {
namespace { namespace {
constexpr char newline[] = "\n"; constexpr char newline[] = "\n";
template <typename ForwardIterator, typename UnaryFunctor, typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
} // namespace } // namespace
class Printer { class Printer {
...@@ -47,6 +62,15 @@ class Printer { ...@@ -47,6 +62,15 @@ class Printer {
os << "i32"; os << "i32";
} else if (type.isa<ir::Int64Type>()) { } else if (type.isa<ir::Int64Type>()) {
os << "i64"; os << "i64";
} else if (type.isa<ir::VectorType>()) {
os << "vec<";
auto inner_types = type.dyn_cast<ir::VectorType>().data();
PrintInterleave(
inner_types.begin(),
inner_types.end(),
[this](ir::Type v) { this->PrintType(v); },
[this]() { this->os << ","; });
os << ">";
} else { } else {
auto& dialect = type.dialect(); auto& dialect = type.dialect();
dialect.PrintType(type, os); dialect.PrintType(type, os);
...@@ -77,22 +101,6 @@ class ProgramPrinter : public Printer { ...@@ -77,22 +101,6 @@ class ProgramPrinter : public Printer {
} }
} }
template <typename ForwardIterator,
typename UnaryFunctor,
typename NullFunctor>
void PrintInterleave(ForwardIterator begin,
ForwardIterator end,
UnaryFunctor print_func,
NullFunctor between_func) {
if (begin == end) return;
print_func(*begin);
begin++;
for (; begin != end; begin++) {
between_func();
print_func(*begin);
}
}
void PrintValue(ir::Value v) { void PrintValue(ir::Value v) {
const void* key = static_cast<const void*>(v.impl()); const void* key = static_cast<const void*>(v.impl());
auto ret = aliases.find(key); auto ret = aliases.find(key);
......
...@@ -16,6 +16,11 @@ ...@@ -16,6 +16,11 @@
#include "paddle/ir/dialect.h" #include "paddle/ir/dialect.h"
namespace ir { namespace ir {
IrContext *Type::ir_context() const { return dialect().ir_context(); } IrContext* Type::ir_context() const { return dialect().ir_context(); }
std::ostream& operator<<(std::ostream& os, Type type) {
type.print(os);
return os;
}
} // namespace ir } // namespace ir
...@@ -89,6 +89,8 @@ class Type { ...@@ -89,6 +89,8 @@ class Type {
const Storage *storage_{nullptr}; const Storage *storage_{nullptr};
}; };
std::ostream &operator<<(std::ostream &os, Type type);
} // namespace ir } // namespace ir
namespace std { namespace std {
......
...@@ -211,3 +211,62 @@ TEST(program_test, program) { ...@@ -211,3 +211,62 @@ TEST(program_test, program) {
EXPECT_EQ(ops.size() == 4, true); EXPECT_EQ(ops.size() == 4, true);
EXPECT_EQ(program.parameters_num() == 3, true); EXPECT_EQ(program.parameters_num() == 3, true);
} }
TEST(program_test, slice_combine_test) {
// (1) Init environment.
ir::IrContext *ctx = ir::IrContext::Instance();
ctx->GetOrRegisterDialect<ir::BuiltinDialect>();
// (2) Create an empty program object
ir::Program program;
// ir::Program *program = new ir::Program();
EXPECT_EQ(program.ops().size() == 0, true);
// (3) Create a float32 DenseTensor Parameter and save into Program
ir::Type fp32_dtype = ir::Float32Type::get(ctx);
// (4) Def a = GetParameterOp("a")
std::string op1_name = ir::GetParameterOp::name();
ir::OpInfo op1_info = ctx->GetRegisteredOpInfo(op1_name);
std::unordered_map<std::string, ir::Attribute> op1_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "a")}};
ir::Operation *op1 =
ir::Operation::create({}, {fp32_dtype}, op1_attribute, op1_info);
program.InsertOp(op1);
// (5) Def b = GetParameterOp("b")
std::string op2_name = std::string(ir::GetParameterOp::name());
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo(op2_name);
std::unordered_map<std::string, ir::Attribute> op2_attribute{
{"parameter_name", ir::StrAttribute::get(ctx, "b")}};
ir::Operation *op2 =
ir::Operation::create({}, {fp32_dtype}, op2_attribute, op2_info);
program.InsertOp(op2);
// (6) Def combine_op = CombineOp("a", "b")
std::string combine_op_name = std::string(ir::CombineOp::name());
ir::OpInfo combine_op_info = ctx->GetRegisteredOpInfo(combine_op_name);
ir::Type output_type =
ir::VectorType::get(ctx, std::vector<ir::Type>({fp32_dtype, fp32_dtype}));
ir::Operation *combine_op = ir::Operation::create(
{op1->GetResultByIndex(0), op2->GetResultByIndex(0)},
{output_type},
{},
combine_op_info);
program.InsertOp(combine_op);
// (7) Def slice_op = SliceOp(combine_op, 0)
std::string slice_op_name = std::string(ir::SliceOp::name());
ir::OpInfo slice_op_info = ctx->GetRegisteredOpInfo(slice_op_name);
ir::Attribute index_attr = ir::Int32_tAttribute::get(ctx, 0);
ir::Operation *slice_op =
ir::Operation::create({combine_op->GetResultByIndex(0)},
{fp32_dtype},
{{"index", index_attr}},
slice_op_info);
program.InsertOp(slice_op);
// (8) Traverse Program
std::list<ir::Operation *> ops = program.ops();
EXPECT_EQ(ops.size() == 4, true);
}
...@@ -59,6 +59,7 @@ TEST(PaddleDialectTest, Translator) { ...@@ -59,6 +59,7 @@ TEST(PaddleDialectTest, Translator) {
// auto program = paddle::TranslateLegacyProgramToProgram(p); // auto program = paddle::TranslateLegacyProgramToProgram(p);
// std::list<ir::Operation *> ops = program->ops(); // std::list<ir::Operation *> ops = program->ops();
// EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num()); // ops.size() = op size in BlockDesc + get_parameter_op + combine op
// VLOG(0) << *program << std::endl; // EXPECT_EQ(ops.size(), p.Block(0).OpSize() + program->parameters_num() +
// 20); std::cout << *program << std::endl;
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册