Unverified commit 14393611, authored by xiaoguoguo626807, committed by GitHub

[NewIR] Add builtin.split op (#56431)

* [prim][newir] add basic framework for primitive

* support DescTensor in new ir

* add vjp interface

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* [prim][newir] add basic framework for primitive

* support DescTensor in new ir

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backends to wrap the lower-level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* backward origin code

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add attrs and dtype interface

* add primitive ops set for backend

* fix compile bugs

* fix some bugs

* fix windows bugs

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* origin test of tanh and mean passed

* fix conflict

* modify stop_gradient

* remove useless deps

* add cmake

* modify block.ops

* modify test

* fix conflict

* reply to review comments

* reply to review comments

* polish code

* fix comment

* fix test

* polish code

* modify backward stop_gradients

* modify static_backend.cc

* refactor grad_op

* support add and add_inplace vjp

* remove useless code

* remove useless code

* remove cout

* modify add_n

* modify add_n with add_vjp test

* modify add_n with add_vjp test

* fix conflict and concat call_vjp

* modify backward test

* Add more gen api

* add builtin split op

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent 5c6d0e26
@@ -54,7 +54,10 @@ void AddNewData(ir::Value value,
                 std::string>* variable_2_var_name,
                 std::map<std::string, int>* var_name_2_id,
                 std::vector<paddle::framework::Variable*>* variable_list) {
-  value_2_var_name->emplace(value, name);
+  if (value_2_var_name->count(value) == 0) {
+    value_2_var_name->emplace(value, name);
+  }
   variable_2_var_name->emplace(var, name);
   if (var_name_2_id->count(name) == 0) {
     auto id = var_name_2_id->size();
@@ -174,7 +177,6 @@ void BuildValue(ir::Value value,
                var_name_2_id,
                variable_list);
   }
-
   // Only support DenseTensor or Vector<DenseTensor>
   if (!value.type()) {
     var->GetMutable<phi::DenseTensor>();
@@ -200,6 +202,7 @@ void BuildValue(ir::Value value,
                   variable_2_var_name,
                   var_name_2_id,
                   variable_list);
+
       var_i->GetMutable<phi::DenseTensor>();
       tensor_array->emplace_back(var_i);
     }
@@ -412,6 +415,30 @@ void HandleForSpecialOp(
     std::string var_name = variable_2_var_name->at(variable_array[index]);
     value_2_var_name->emplace(out_value, var_name);
   }
+
+  if (op_name == "builtin.split") {
+    VLOG(6) << "Handle for builtin.split";
+    auto in_value = op->operand_source(0);
+    PADDLE_ENFORCE_EQ(value_2_var_name->count(in_value),
+                      true,
+                      phi::errors::PreconditionNotMet(
+                          "input of builtin.split not in name map"));
+
+    auto in_var = inner_scope->FindVar(value_2_var_name->at(in_value));
+    auto variable_array = in_var->Get<paddle::framework::VariableRefArray>();
+
+    for (uint64_t idx = 0; idx < variable_array.size(); ++idx) {
+      auto out_value = op->result(idx);
+      PADDLE_ENFORCE_EQ(
+          variable_2_var_name->count(variable_array[idx]),
+          true,
+          phi::errors::PreconditionNotMet(
+              "[%d] the variable in builtin.split input MUST be in the "
+              "variable name map",
+              idx));
+      std::string var_name = variable_2_var_name->at(variable_array[idx]);
+      value_2_var_name->emplace(out_value, var_name);
+    }
+  }
 }

 void HandleForInplaceOp(
@@ -498,7 +525,8 @@ void BuildScope(const ir::Block& block,
     if (op_name == "pd.feed" || op_name == "pd.fetch" ||
         op_name == "builtin.combine" || op_name == "builtin.set_parameter" ||
         op_name == "builtin.get_parameter" || op_name == "builtin.slice" ||
-        op_name == "pd.data" || op_name == "pd.shadow_output") {
+        op_name == "builtin.split" || op_name == "pd.data" ||
+        op_name == "pd.shadow_output") {
       HandleForSpecialOp(op,
                          inner_scope,
                          var_name_prefix,
......
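Note on the hunks above: builtin.split is routed through HandleForSpecialOp because it allocates nothing. Each output Value simply takes over the variable name of the matching element of the input's VariableRefArray. A minimal standalone sketch of that aliasing rule, using plain containers as stand-ins for the interpreter's maps (all names below are illustrative, not Paddle APIs):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Stand-in for the input VariableRefArray resolved through
  // variable_2_var_name: element index -> existing variable name.
  std::vector<std::string> packed_input = {"var_a", "var_b"};

  // Stand-in for value_2_var_name: output index -> variable name.
  std::vector<std::string> out_names;
  for (size_t idx = 0; idx < packed_input.size(); ++idx) {
    out_names.push_back(packed_input[idx]);  // no new variable is created
  }

  assert(out_names[0] == "var_a" && out_names[1] == "var_b");
  return 0;
}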
@@ -311,7 +311,7 @@ void BuildPhiContext(ir::Operation* op,
                        ->Get<phi::SelectedRows>()))));
   } else if (out_type.isa<ir::VectorType>()) {
     OutListType outputs;
-    auto& variable_array = scope->FindVar(name_map.at(out_ptr))
+    auto& variable_array = inner_scope->FindVar(name_map.at(out_ptr))
                                ->Get<paddle::framework::VariableRefArray>();
     for (size_t i = 0; i < variable_array.size(); ++i) {
       outputs.emplace_back(OutType(const_cast<phi::DenseTensor*>(
......
@@ -54,6 +54,7 @@ const std::unordered_set<std::string> UnchangeOutputOps = {
     "pd.data",
     "builtin.combine",
     "builtin.slice",
+    "builtin.split",
     "pd.feed",
     "pd.fetch",
     "builtin.set_parameter",
@@ -523,7 +524,76 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
             op_output_types.push_back(allocated_dense_tensor_dtype);
           } else {
             PADDLE_THROW(phi::errors::Unimplemented(
-                "builtin.combine Result type only support DenseTensorType"));
+                "builtin.slice Result type only support DenseTensorType"));
+          }
+        }
+      }
+
+      // Get op info
+      ir::OpInfo op_info = ctx->GetRegisteredOpInfo(op_item->name());
+      // Generate new op
+      ir::Operation* op = ir::Operation::Create(
+          vec_inputs, op_item->attributes(), op_output_types, op_info);
+      program->block()->push_back(op);
+      map_op_pair[op_item] = op;
+      // Record the result mapping for every output
+      if (op_item->num_results() > 0) {
+        for (size_t i = 0; i < op_item->num_results(); ++i) {
+          map_value_pair[op_item->result(i)] = op->result(i);
+        }
+      }
+      VLOG(6) << "Deep copy a new builtin op: " << op_item->name();
+      continue;
+    }
+
+    if (op_item->name() == "builtin.split") {
+      phi::Place out_place = place;
+      // Copy op inputs
+      std::vector<ir::OpResult> vec_inputs;
+      if (op_item->num_operands() > 0) {
+        for (size_t i = 0; i < op_item->num_operands(); ++i) {
+          auto cur_in = op_item->operand_source(i);
+          if (!cur_in) {
+            vec_inputs.emplace_back();
+            continue;
+          }
+          PADDLE_ENFORCE_EQ(map_value_pair.count(cur_in),
+                            true,
+                            phi::errors::PreconditionNotMet(
+                                "[%d]'s input of [%s] op MUST be in map pair",
+                                i,
+                                op_item->name()));
+          auto new_in = map_value_pair.at(cur_in);
+          vec_inputs.push_back(new_in);
+
+          if (new_in.type().isa<ir::VectorType>()) {
+            auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
+            out_place =
+                vec_types[0]
+                    .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
+                    .place();
+          } else {
+            PADDLE_THROW(phi::errors::Unimplemented(
+                "only support vector type for now"));
+          }
+        }
+      }
+
+      // Copy op output type
+      std::vector<ir::Type> op_output_types;
+      if (op_item->num_results() > 0) {
+        for (size_t i = 0; i < op_item->num_results(); ++i) {
+          auto result_type = op_item->result(i).type();
+          if (!result_type) {
+            op_output_types.push_back(result_type);
+          } else if (result_type.isa<dialect::DenseTensorType>()) {
+            auto allocated_dense_tensor_dtype =
+                paddle::dialect::AllocatedDenseTensorType::get(
+                    ctx,
+                    out_place,
+                    result_type.dyn_cast<dialect::DenseTensorType>());
+            op_output_types.push_back(allocated_dense_tensor_dtype);
+          } else {
+            PADDLE_THROW(phi::errors::Unimplemented(
+                "builtin.split Result type only support DenseTensorType"));
           }
         }
       }
......
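Note on the lowering hunk above: the kernel pass cannot infer a place for builtin.split's outputs on its own, so it reads the place recorded on the first element of the already-lowered input VectorType and stamps it onto every DenseTensor result. A standalone sketch of that propagation, with simplified stand-in types rather than Paddle's:

#include <cassert>
#include <string>
#include <vector>

// Stand-in for AllocatedDenseTensorType: a type that carries a place.
struct AllocatedType {
  std::string place;
};

int main() {
  // Element types of the lowered input vector, all already placed.
  std::vector<AllocatedType> vec_types = {{"GPU:0"}, {"GPU:0"}};

  std::string out_place = "CPU";      // pass-wide default place
  if (!vec_types.empty()) {
    out_place = vec_types[0].place;   // override from element 0
  }

  std::vector<AllocatedType> out_types;
  for (size_t i = 0; i < vec_types.size(); ++i) {
    out_types.push_back({out_place});  // every result gets the same place
  }

  assert(out_types.back().place == "GPU:0");
  return 0;
}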
@@ -55,6 +55,7 @@ void BuiltinDialect::initialize() {
       SetParameterOp,
       CombineOp,
       SliceOp,
+      SplitOp,
       ConstantOp>();
 }
......
@@ -207,6 +207,46 @@ void SliceOp::Verify() const {
                  output_type);
 }
 
+void SplitOp::Verify() const {
+  // inputs.size() == 1
+  IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");
+
+  // input_type == Vector<Type>
+  auto input_type = (*this)->operand(0).type().dyn_cast<VectorType>();
+  IR_ENFORCE(input_type, "The type of inputs[0] must be equal to VectorType.");
+
+  // inputs[0].size() == outputs.size()
+  auto output_num = num_results();
+  IR_ENFORCE(input_type.size() == output_num,
+             "The size %d of outputs must be equal to the size %d of inputs.",
+             output_num,
+             input_type.size());
+
+  // for all i in outputs.size(): outputs[i].type == inputs[0][i].type
+  for (size_t i = 0; i < output_num; ++i) {
+    auto type = (*this)->result(i).type();
+    IR_ENFORCE(input_type[i] == type,
+               "The type %s of inputs[0][%d] must be "
+               "equal to type %s of outputs[%d].",
+               input_type[i],
+               i,
+               type,
+               i);
+  }
+}
+
+void SplitOp::Build(Builder &builder,
+                    OperationArgument &argument,
+                    const ir::OpResult &input) {
+  argument.inputs = {input};
+  for (size_t idx = 0; idx < input.type().dyn_cast<ir::VectorType>().size();
+       ++idx) {
+    argument.output_types.emplace_back(
+        input.type().dyn_cast<ir::VectorType>()[idx]);
+  }
+}
+
 const char *ConstantOp::attributes_name[attributes_num] = {"value"};  // NOLINT
 
 void ConstantOp::Build(Builder &builder,
@@ -232,5 +272,6 @@ IR_DEFINE_EXPLICIT_TYPE_ID(ir::GetParameterOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::SetParameterOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::CombineOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::SliceOp)
+IR_DEFINE_EXPLICIT_TYPE_ID(ir::SplitOp)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait)
 IR_DEFINE_EXPLICIT_TYPE_ID(ir::ConstantOp)
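The Verify/Build pair above encodes three invariants: exactly one operand, that operand carries a VectorType, and result i's type equals element i of that vector. A self-contained sketch of the same checks over plain string types (stand-ins for ir::Type, not Paddle code):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // Element types of the input VectorType (stand-ins for ir::Type).
  std::vector<std::string> input_elems = {"tensor<2xf32>", "tensor<4xf32>"};

  // Build(): one output type per input element, copied verbatim.
  std::vector<std::string> output_types(input_elems.begin(), input_elems.end());

  // Verify(): output count matches, and types match element-wise.
  assert(output_types.size() == input_elems.size());
  for (size_t i = 0; i < output_types.size(); ++i) {
    assert(output_types[i] == input_elems[i]);
  }
  return 0;
}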
@@ -93,6 +93,13 @@ class IR_API CombineOp : public ir::Op<CombineOp> {
                     const std::vector<ir::OpResult> &inputs);
 
   void Verify() const;
+  std::vector<ir::Value> inputs() {
+    std::vector<ir::Value> inputs;
+    for (uint32_t idx = 0; idx < num_operands(); idx++) {
+      inputs.push_back(operand_source(static_cast<int>(idx)));
+    }
+    return inputs;
+  }
   ir::OpResult out() { return result(0); }
 };
@@ -108,8 +115,41 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
   static constexpr uint32_t attributes_num = 1;
   static const char *attributes_name[attributes_num];
+  static void Build(Builder &builder,             // NOLINT
+                    OperationArgument &argument,  // NOLINT
+                    const ir::OpResult &input);
 
   void Verify() const;
-  ir::OpResult out() { return result(0); }
+  ir::Value input() { return operand_source(0); }
+};
+
+///
+/// \brief SplitOp: SplitOp(OpOperand)
+///
+class IR_API SplitOp : public ir::Op<SplitOp> {
+ public:
+  using Op::Op;
+  static const char *name() { return "builtin.split"; }
+  static constexpr uint32_t attributes_num = 0;
+  static constexpr const char **attributes_name = nullptr;
+
+  static void Build(Builder &builder,             // NOLINT
+                    OperationArgument &argument,  // NOLINT
+                    const ir::OpResult &input);
+
+  void Verify() const;
+  ir::Value input() { return operand_source(0); }
+  std::vector<ir::OpResult> outputs() {
+    std::vector<ir::OpResult> outputs;
+    for (uint32_t idx = 0; idx < num_results(); idx++) {
+      outputs.push_back(result(static_cast<int>(idx)));
+    }
+    return outputs;
+  }
 };
 
 class IR_API ConstantLikeTrait : public OpTraitBase<ConstantLikeTrait> {
@@ -146,5 +186,6 @@ IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::GetParameterOp)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SetParameterOp)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::CombineOp)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SliceOp)
+IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::SplitOp)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantLikeTrait)
 IR_EXPORT_DECLARE_EXPLICIT_TYPE_ID(ir::ConstantOp)
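For context, a hedged usage fragment (not from this PR; it assumes ir::Builder's templated Build<OpT>(...) forwarding used elsewhere in this dialect, and values x and y defined beforehand): builtin.split is the inverse of builtin.combine, and the new inputs()/outputs() accessors expose both ends of the round trip.

// Assumed fragment; builder, x, and y are presumed to already exist.
auto combine = builder.Build<ir::CombineOp>(std::vector<ir::OpResult>{x, y});
auto split = builder.Build<ir::SplitOp>(combine.out());
std::vector<ir::OpResult> outs = split.outputs();  // one result per element
// outs[i].type() equals the i-th element type of combine.out()'s VectorType,
// which is exactly what SplitOp::Verify() enforces.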