Unverified commit 4b6d2f5f, authored by H hong, committed by GitHub

[NewIR]new ir support builtin slice op (#55381)

* new ir support builtin slice op

* fix phi kernel adaptor bug
Parent 0dad9458
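For orientation, a minimal sketch of what builtin.slice adds, assuming (as in the diff below) that combine/slice are resolved purely by variable-name aliasing in the executor; `VarList` and `SliceByIndex` are illustrative stand-ins, not Paddle APIs:

```cpp
#include <cassert>
#include <string>
#include <vector>

// builtin.combine packs several values into one vector-typed value; the
// executor models the result as a list of inner variable names.
using VarList = std::vector<std::string>;

// builtin.slice(index) selects one element of a combined value. In the new
// executor this resolves to a name lookup, so no tensor data is copied.
std::string SliceByIndex(const VarList& combined, int index) {
  assert(index >= 0 && static_cast<size_t>(index) < combined.size());
  return combined[static_cast<size_t>(index)];
}

int main() {
  VarList combined = {"inner_var_0", "inner_var_1", "inner_var_2"};
  // The output of builtin.slice with index = 1 aliases inner_var_1.
  assert(SliceByIndex(combined, 1) == "inner_var_1");
  return 0;
}
```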
......@@ -957,7 +957,7 @@ void BuildOpFuncList(
if (op_name == "builtin.combine" || op_name == "pd.feed" ||
op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter") {
op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
VLOG(6) << "skip process " << op_name;
continue;
}
......@@ -977,6 +977,7 @@ void BuildOpFuncList(
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it),
value_2_name_map,
scope,
......@@ -1003,6 +1004,7 @@ void BuildOpFuncList(
const phi::TensorBase*,
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
paddle::small_vector<phi::TensorBase*>,
true>((*it),
value_2_name_map,
scope,
......
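The extra template argument above splits the former single ListType into separate input and output list types. A hedged sketch of that signature shape, using simplified stand-in types rather than the real phi/paddle templates:

```cpp
#include <vector>

struct TensorBase {};  // stand-in for phi::TensorBase

// Inputs are collected as const pointers, outputs as mutable pointers, so the
// context builder now takes separate InListType / OutListType parameters.
template <typename InType, typename OutType,
          typename InListType, typename OutListType>
void BuildContextLists(const std::vector<TensorBase*>& vars,
                       InListType* inputs, OutListType* outputs) {
  for (TensorBase* v : vars) {
    inputs->emplace_back(InType(v));    // e.g. const phi::TensorBase*
    outputs->emplace_back(OutType(v));  // e.g. phi::TensorBase*
  }
}

int main() {
  TensorBase a, b;
  std::vector<TensorBase*> vars = {&a, &b};
  std::vector<const TensorBase*> inputs;
  std::vector<TensorBase*> outputs;
  BuildContextLists<const TensorBase*, TensorBase*,
                    std::vector<const TensorBase*>, std::vector<TensorBase*>>(
      vars, &inputs, &outputs);
  return inputs.size() == 2 && outputs.size() == 2 ? 0 : 1;
}
```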
......@@ -20,11 +20,11 @@ namespace paddle {
namespace framework {
template <>
struct PhiVectorType<const phi::DenseTensor*> {
const char* type_name = "PhiTensorRefArray";
struct PhiVectorType<const framework::Variable*> {
const char* type_name = "VariableRefArray";
};
using TensorRefArray = PhiVector<const phi::DenseTensor*>;
using VariableRefArray = PhiVector<const framework::Variable*>;
} // namespace framework
} // namespace paddle
......@@ -41,6 +41,6 @@ template class TypeInfoTraits<phi::TensorBase, paddle::framework::FeedList>;
template class TypeInfoTraits<phi::TensorBase, egr::VariableCompatTensor>;
template class TypeInfoTraits<phi::TensorBase, paddle::prim::DescTensor>;
template class TypeInfoTraits<phi::TensorBase,
paddle::framework::TensorRefArray>;
paddle::framework::VariableRefArray>;
} // namespace phi
......@@ -212,7 +212,7 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
std::vector<float>,
std::vector<std::string>,
RawTensor,
TensorRefArray>;
VariableRefArray>;
template <typename T>
struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
......
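A rough sketch of what the switch from TensorRefArray to VariableRefArray buys: each element keeps its owning Variable, so a name can later be recovered for builtin.slice. The types here are plain standard-library stand-ins, not the real PhiVector specialization registered above:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct DenseTensor { std::vector<float> data; };  // stand-in for phi::DenseTensor
struct Variable { DenseTensor tensor; };          // stand-in for framework::Variable

using VariableRefArray = std::vector<const Variable*>;

int main() {
  Variable a, b;
  std::unordered_map<const Variable*, std::string> variable_name_map = {
      {&a, "inner_var_0"}, {&b, "inner_var_1"}};

  // The combined value holds Variable pointers, not raw tensor pointers,
  // so builtin.slice can map an element back to its variable name.
  VariableRefArray refs = {&a, &b};
  assert(variable_name_map.at(refs[1]) == "inner_var_1");
  return 0;
}
```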
......@@ -87,6 +87,7 @@ class PhiKernelAdaptor {
phi::MetaTensor,
phi::MetaTensor,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
paddle::small_vector<phi::MetaTensor, phi::kInputSmallVectorSize>,
false>((*it), name_map, scope_, nullptr, op_yaml_info_parser, &ctx);
infer_meta_impl->infer_meta_(&ctx);
......@@ -106,6 +107,7 @@ class PhiKernelAdaptor {
const phi::TensorBase*,
phi::TensorBase*,
paddle::small_vector<const phi::TensorBase*>,
paddle::small_vector<phi::TensorBase*>,
true>(
(*it), name_map, scope_, nullptr, op_yaml_info_parser, &kernel_ctx);
kernel_fn(&kernel_ctx);
......
......@@ -43,6 +43,9 @@
namespace ir {
using VariableNameMap =
std::unordered_map<const paddle::framework::Variable*, std::string>;
paddle::framework::Variable* CreateVar(ir::Value value,
const std::string& name,
paddle::framework::Scope* scope,
......@@ -89,6 +92,7 @@ void BuildValue(ir::Value value,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
VariableNameMap* variable_name_map,
int& count) { // NOLINT
auto inner_local_scope = local_scope != nullptr ? local_scope : scope;
std::string name;
......@@ -107,7 +111,7 @@ void BuildValue(ir::Value value,
} else if (value.type().isa<paddle::dialect::AllocatedSelectedRowsType>()) {
var->GetMutable<phi::SelectedRows>();
} else if (value.type().isa<ir::VectorType>()) {
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
for (size_t i = 0; i < value.type().dyn_cast<ir::VectorType>().size();
i++) {
PADDLE_ENFORCE(value.type()
......@@ -118,7 +122,9 @@ void BuildValue(ir::Value value,
"DenseTensorType"));
std::string name_i = "inner_var_" + std::to_string(count++);
auto var_i = CreateVar(value, name_i, scope, inner_local_scope);
tensor_array->emplace_back(var_i->GetMutable<phi::DenseTensor>());
var_i->GetMutable<phi::DenseTensor>();
tensor_array->emplace_back(var_i);
variable_name_map->emplace(var_i, name_i);
}
} else {
PADDLE_THROW(phi::errors::PreconditionNotMet(
......@@ -127,6 +133,7 @@ void BuildValue(ir::Value value,
}
void HandleForSpecialOp(ir::Operation* op,
const VariableNameMap& variable_name_map,
paddle::framework::Scope* scope,
paddle::framework::Scope* local_scope,
std::unordered_map<ir::Value, std::string>* name_map,
......@@ -180,7 +187,7 @@ void HandleForSpecialOp(ir::Operation* op,
}
auto var = CreateVar(out_value, name, scope, local_scope);
auto tensor_array = var->GetMutable<paddle::framework::TensorRefArray>();
auto tensor_array = var->GetMutable<paddle::framework::VariableRefArray>();
// clear tensor array
tensor_array->clear();
......@@ -192,8 +199,7 @@ void HandleForSpecialOp(ir::Operation* op,
true,
phi::errors::PreconditionNotMet("can not found input of combine op"));
tensor_array->emplace_back(
&(CreateVar(value, name_map->at(value), scope, local_scope)
->Get<phi::DenseTensor>()));
CreateVar(value, name_map->at(value), scope, local_scope));
}
}
......@@ -223,6 +229,34 @@ void HandleForSpecialOp(ir::Operation* op,
auto out_ptr = op->result(0);
name_map->emplace(out_ptr, param_name);
}
if (op_name == "builtin.slice") {
VLOG(6) << "Handle for builtin.slice";
auto out_value = op->result(0);
auto in_value = op->operand(0);
PADDLE_ENFORCE_EQ(name_map->count(in_value),
true,
phi::errors::PreconditionNotMet(
"input of buildin slice not in name map"));
int index =
op->attributes().at("index").dyn_cast<ir::Int32Attribute>().data();
auto in_var = scope->FindVar(name_map->at(in_value));
auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
PADDLE_ENFORCE_EQ(
variable_name_map.count(variable_array[index]),
true,
phi::errors::PreconditionNotMet("[%d] the variable in build slice "
"input MUST in variable name map",
index));
std::string var_name = variable_name_map.at(variable_array[index]);
name_map->emplace(out_value, var_name);
}
}
void HandleForInplaceOp(ir::Operation* op,
......@@ -242,7 +276,7 @@ void HandleForInplaceOp(ir::Operation* op,
paddle::dialect::OpYamlInfoParser yaml_parser(
op_info.GetInterfaceImpl<paddle::dialect::OpYamlInfoInterface>()
->get_op_info_());
VariableNameMap variable_name_map;
for (size_t i = 0; i < op->num_results(); ++i) {
ir::Value value = op->result(i);
std::string value_name = yaml_parser.OutputNames()[i];
......@@ -255,7 +289,8 @@ void HandleForInplaceOp(ir::Operation* op,
<< " (var: " << var_name << ")";
name_map->emplace(value, var_name);
} else {
BuildValue(value, scope, local_scope, name_map, count);
BuildValue(
value, scope, local_scope, name_map, &variable_name_map, count);
}
}
}
......@@ -273,8 +308,11 @@ void BuildScope(const ir::Block& block,
VLOG(6) << "Build: scope [" << scope << "] inner_local_scope ["
<< inner_local_scope << "]";
std::unordered_map<const paddle::framework::Variable*, std::string>
variable_name_map;
// int count = name_map->size();
int count = inner_local_scope->Size();
int count = name_map->size();
for (auto it = block.begin(); it != block.end(); ++it) {
ir::Operation* op = *it;
......@@ -288,9 +326,10 @@ void BuildScope(const ir::Block& block,
if (op_name == "pd.feed" || op_name == "pd.fetch" ||
op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
op_name == "builtin.get_parameter") {
VLOG(4) << "HandleForSpecialOp: " << op_name;
HandleForSpecialOp(op, scope, inner_local_scope, name_map, count);
op_name == "builtin.get_parameter" || op_name == "builtin.slice") {
VLOG(6) << "HandleForSpecialOp: " << op_name;
HandleForSpecialOp(
op, variable_name_map, scope, inner_local_scope, name_map, count);
continue;
}
......@@ -306,7 +345,12 @@ void BuildScope(const ir::Block& block,
continue;
} else {
for (size_t i = 0; i < op->num_results(); ++i) {
BuildValue(op->result(i), scope, local_scope, name_map, count);
BuildValue(op->result(i),
scope,
local_scope,
name_map,
&variable_name_map,
count);
}
}
}
......
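A condensed sketch of the builtin.slice path through HandleForSpecialOp, under the assumption that the scope and the two name maps behave like plain hash maps; every type below is a simplified stand-in for the framework classes:

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <vector>

struct Variable {};                 // stand-in for framework::Variable
using Value = int;                  // stand-in for ir::Value
using VariableRefArray = std::vector<const Variable*>;

int main() {
  Variable slot0, slot1;
  VariableRefArray combined = {&slot0, &slot1};

  // State maintained while building the scope (see BuildValue above).
  std::unordered_map<Value, std::string> name_map = {{0, "combined_var"}};
  std::unordered_map<std::string, const VariableRefArray*> scope = {
      {"combined_var", &combined}};
  std::unordered_map<const Variable*, std::string> variable_name_map = {
      {&slot0, "inner_var_0"}, {&slot1, "inner_var_1"}};

  // builtin.slice with index = 1: in_value is the combined input, out_value
  // is the slice result; the output simply reuses the element's name.
  Value in_value = 0, out_value = 1;
  int index = 1;
  const VariableRefArray& array = *scope.at(name_map.at(in_value));
  name_map.emplace(out_value, variable_name_map.at(array[index]));

  assert(name_map.at(out_value) == "inner_var_1");
  return 0;
}
```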
......@@ -75,7 +75,8 @@ void BuildScope(const ir::Block& block,
template <typename Context,
typename InType,
typename OutType,
typename ListType,
typename InListType,
typename OutListType,
bool is_kernel>
void BuildPhiContext(
ir::Operation* op,
......@@ -121,11 +122,12 @@ void BuildPhiContext(
if (var->IsType<phi::DenseTensor>()) {
const phi::TensorBase* tensor_in = &(var->Get<phi::DenseTensor>());
ctx->EmplaceBackInput(InType(tensor_in));
} else if (var->IsType<paddle::framework::TensorRefArray>()) {
ListType inputs;
auto& tensor_array = var->Get<paddle::framework::TensorRefArray>();
for (size_t i = 0; i < tensor_array.size(); ++i) {
inputs.emplace_back(InType(tensor_array[i]));
} else if (var->IsType<paddle::framework::VariableRefArray>()) {
InListType inputs;
auto& variable_array = var->Get<paddle::framework::VariableRefArray>();
for (size_t i = 0; i < variable_array.size(); ++i) {
inputs.emplace_back(InType(const_cast<phi::DenseTensor*>(
&(variable_array[i]->Get<phi::DenseTensor>()))));
}
ctx->EmplaceBackInputs(inputs);
} else {
......@@ -157,18 +159,21 @@ void BuildPhiContext(
VLOG(6) << "ctx->EmplaceBack mutable attr: " << t << "\t" << in_var_name;
if (tensor_attr_type == "paddle::dialect::IntArrayAttribute") {
if (ptr.type().isa<paddle::dialect::AllocatedDenseTensorType>()) {
phi::Attribute r1 = phi::TensorRef(
phi::Attribute attr = phi::TensorRef(
&(inner_scope->FindVar(in_var_name)->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(r1);
ctx->EmplaceBackAttr(attr);
} else if (ptr.type().isa<ir::VectorType>()) {
auto& tensor_array = inner_scope->FindVar(in_var_name)
->Get<paddle::framework::TensorRefArray>();
->Get<paddle::framework::VariableRefArray>();
if (tensor_array.size() == 1) {
ctx->EmplaceBackAttr(phi::TensorRef(tensor_array[0]));
phi::Attribute attr =
phi::TensorRef(&(tensor_array[0]->Get<phi::DenseTensor>()));
ctx->EmplaceBackAttr(attr);
} else {
std::vector<phi::TensorRef> vec_ref;
for (size_t i = 0; i < tensor_array.size(); ++i) {
vec_ref.emplace_back(phi::TensorRef(tensor_array[i]));
vec_ref.emplace_back(
phi::TensorRef(&(tensor_array[i]->Get<phi::DenseTensor>())));
}
ctx->EmplaceBackAttr(vec_ref);
}
......@@ -328,8 +333,18 @@ void BuildPhiContext(
} else if (out_type.isa<paddle::dialect::AllocatedSelectedRowsType>()) {
ctx->EmplaceBackOutput(OutType(const_cast<phi::SelectedRows*>(
&(scope->Var(name)->Get<phi::SelectedRows>()))));
} else if (out_type.isa<ir::VectorType>()) {
OutListType outputs;
auto& variable_array =
scope->Var(name)->Get<paddle::framework::VariableRefArray>();
for (size_t i = 0; i < variable_array.size(); ++i) {
outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
&(variable_array[i]->Get<phi::DenseTensor>()))));
}
ctx->EmplaceBackOutputs(outputs);
} else {
PADDLE_THROW("not support type");
PADDLE_THROW(
phi::errors::Unimplemented("only support DenseTensor and vector "));
}
if (output_map != nullptr) {
......
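A small sketch of how a vector-typed value is flattened when building the kernel context, assuming a Variable exposes its DenseTensor through a Get() accessor; these are stand-in types, not phi's, and the output side is analogous with a const_cast to obtain mutable pointers:

```cpp
#include <cassert>
#include <vector>

struct DenseTensor {};                        // stand-in for phi::DenseTensor
struct Variable {                             // stand-in for framework::Variable
  DenseTensor tensor;
  const DenseTensor& Get() const { return tensor; }
};

using VariableRefArray = std::vector<const Variable*>;

// Each Variable in the ref array contributes its DenseTensor to the
// kernel-context input list.
std::vector<const DenseTensor*> FlattenInputs(const VariableRefArray& array) {
  std::vector<const DenseTensor*> inputs;
  for (const Variable* var : array) {
    inputs.emplace_back(&var->Get());
  }
  return inputs;
}

int main() {
  Variable a, b;
  VariableRefArray array = {&a, &b};
  auto inputs = FlattenInputs(array);
  assert(inputs.size() == 2 && inputs[0] == &a.tensor);
  return 0;
}
```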
......@@ -955,6 +955,104 @@ struct FeedOpTranscriber : public OpTranscriber {
}
};
struct SplitOpTranscriber : public OpTranscriber {
std::vector<ir::OpResult> GenerateOperationInput(
ir::IrContext* ctx,
TranslationContext* param_map,
const OpDesc& op_desc,
const std::string& normalized_op_name,
const OpInputInfoList& input_infos,
ir::Program* program) override {
// input of split is [Tensor x, IntArray sections, Scalar(int) axis]
VLOG(10) << "[op:split][input] start";
std::vector<ir::OpResult> op_inputs;
// process first input
auto x_input_vars = op_desc.Input("X");
IR_ENFORCE(x_input_vars.size() == 1, "x input of split MUST be a tensor");
auto x_defining_info = (*param_map)[x_input_vars[0]];
op_inputs.push_back(x_defining_info.value);
// process sections
int num = paddle::get<int>(op_desc.GetAttr("num"));
if (num <= 0) {
if (op_desc.HasInput("SectionsTensorList")) {
// get SectionsTensorList from input
auto sec_tensor_list = op_desc.Input("SectionsTensorList");
auto* combine_op = InsertCombineOperationForTarget(
ctx, param_map, program, sec_tensor_list);
op_inputs.push_back(combine_op->result(0));
} else {
auto& attribute_translator = AttributeTranslator::instance();
ir::Attribute new_attr = attribute_translator(
"paddle::dialect::IntArrayAttribute", op_desc.GetAttr("sections"));
auto sec_defin_op =
InsertFullOperationForAttributeInput(ctx, program, new_attr);
op_inputs.push_back(sec_defin_op->result(0));
}
}
// process axis
if (op_desc.HasInput("AxisTensor") &&
op_desc.Input("AxisTensor").size() > 0) {
// get axis from input
auto axis_var_list = op_desc.Input("AxisTensor");
IR_ENFORCE(axis_var_list.size() == 1,
"axis tensor input of split MUST be a tensor");
auto axis_defining_info = (*param_map)[axis_var_list[0]];
op_inputs.push_back(axis_defining_info.value);
} else {
auto& attribute_translator = AttributeTranslator::instance();
ir::Attribute new_attr =
attribute_translator("ir::Int32Attribute", op_desc.GetAttr("axis"));
auto sec_defin_op =
InsertFullOperationForAttributeInput(ctx, program, new_attr);
op_inputs.push_back(sec_defin_op->result(0));
}
return op_inputs;
}
ir::AttributeMap TranslateOpAttribute(
ir::IrContext* ctx,
const std::string& normalized_op_name,
const OpAttributeInfoList& op_attr_infos,
const OpDesc& op_desc) override {
int num = paddle::get<int>(op_desc.GetAttr("num"));
if (num > 0) {
ir::AttributeMap attribute_map = {
{"num",
ir::Int32Attribute::get(ctx, op_desc.GetAttrIfExists<int>("num"))},
};
return attribute_map;
}
return {};
}
ir::OpInfo LoopkUpOpInfo(ir::IrContext* ctx, const OpDesc& op_desc) override {
int num = paddle::get<int>(op_desc.GetAttr("num"));
std::string target_op_name;
if (num > 0) {
target_op_name = "pd.split_with_num";
} else {
target_op_name = "pd.split";
}
const auto& op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW("Op assign_value should have corresponding OpInfo pd.split");
}
return op_info;
}
};
struct FetchOpTranscriber : public OpTranscriber {
ir::Operation* operator()(ir::IrContext* ctx,
TranslationContext* param_map,
......@@ -994,6 +1092,7 @@ OpTranslator::OpTranslator() {
special_handlers["feed"] = FeedOpTranscriber();
special_handlers["fetch_v2"] = FetchOpTranscriber();
special_handlers["cast"] = CastOpTranscriber();
special_handlers["split"] = SplitOpTranscriber();
special_handlers["lookup_table_v2"] = EmbeddingOpTranscriber();
special_handlers["lookup_table_v2_grad"] = EmbeddingGradOpTranscriber();
special_handlers["assign_value"] = AssignValueOpTranscriber();
......
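A sketch of the transcription decision made by SplitOpTranscriber, assuming the legacy split op carries an integer "num" attribute; `ChooseSplitTarget` is an illustrative helper, not part of the translator:

```cpp
#include <cassert>
#include <string>

struct SplitTarget {
  std::string op_name;
  bool sections_from_input;  // true when sections come from SectionsTensorList
};

SplitTarget ChooseSplitTarget(int num, bool has_sections_tensor_list) {
  if (num > 0) {
    // Fixed number of equal parts: lower to pd.split_with_num.
    return {"pd.split_with_num", false};
  }
  // Explicit sections: lower to pd.split; sections become either a combined
  // tensor-list input or a full op built from the "sections" attribute.
  return {"pd.split", has_sections_tensor_list};
}

int main() {
  assert(ChooseSplitTarget(3, false).op_name == "pd.split_with_num");
  assert(ChooseSplitTarget(0, true).op_name == "pd.split");
  return 0;
}
```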
......@@ -2532,7 +2532,17 @@
int_array:
sections :
data_type : int
tensor_name : AxesTensor
scalar :
axis :
data_type : int
support_tensor : true
- op : split_with_num
scalar :
axis :
data_type : int
support_tensor : true
tensor_name : AxisTensor
- op : sqrt
backward : sqrt_grad, sqrt_double_grad (sqrt_grad_grad)
......
......@@ -141,5 +141,27 @@ class TestAddGradOp(unittest.TestCase):
np.testing.assert_array_equal(out[0], gold_res)
class TestSplitOp(unittest.TestCase):
def test_with_new_ir(self):
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
main_program = paddle.static.Program()
new_scope = paddle.static.Scope()
with paddle.static.scope_guard(new_scope):
with paddle.static.program_guard(main_program):
x = paddle.static.data("x", [6, 2], dtype="float32")
out0, out1, out2 = paddle.split(x, num_or_sections=3, axis=0)
np_a = np.random.rand(6, 2).astype("float32")
out = exe.run(
main_program,
feed={"x": np_a},
fetch_list=[out0.name],
)
np.testing.assert_array_equal(out[0], np_a[0:2])
if __name__ == "__main__":
unittest.main()