Unverified commit 7995a389, authored by xiaoguoguo626807, committed by GitHub

[NewIR]Split python api and vjp (#56518)

* support ir api for prim

* convert vector of int to intarray

* add reference of lbfgs

* add reference of lbfgs

* support ir api for prim

* Add more gen api

* concat python api to concat_grad

* fix gen conflict

* support vjp prim mode in new ir

* remove useless code

* add vjp autogen v1.0

* add test for prim

* resolve type conflict

* modify utils

* remove useless code

* add split op and fix some bugs of VectorType

* fix conflict

* add concat python test

* add split python api to vjp

* fix build bug

* fix run bug

* fix conflict bug

* build bug fix

* fix python api bug

* modify test

* fix conflict

* fluid backward recover

* recover conflict

* address review comments

* modify opruntimeinfo num

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Co-authored-by: chenzhiyang <1792266893@qq.com>
Co-authored-by: Chen Zhiyang <chenzhiyang99@126.com>
Parent cf80a66b
@@ -172,7 +172,7 @@ scalar_type_maps = {
    'bool': 'ir::BoolAttribute',
}

-_NO_NEED_GEN_OPS = {'add_n'}
+_NO_NEED_GEN_OPS = {'add_n', 'split_grad'}


def to_phi_and_fluid_op_name(op_item):
...
@@ -29,6 +29,7 @@ vjp_interface_declare_gen_op_list = [
    "sum",
    "add",
    "concat",
    "split",
]
vjp_interface_implementation_gen_op_list = [
    "tanh",
...
@@ -48,7 +48,7 @@ void PaddleDialect::initialize() {
#define GET_OP_LIST
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"  // NOLINT
      >();
-  RegisterOp<paddle::dialect::AddNOp>();
+  RegisterOps<paddle::dialect::AddNOp, paddle::dialect::SplitGradOp>();
  RegisterInterfaces<ParameterConvertInterface>();
}
...
@@ -18,5 +18,16 @@
#include "paddle/ir/core/builtin_op.h"

namespace paddle {
-namespace dialect {}  // namespace dialect
+namespace dialect {
ir::OpResult split_grad(std::vector<ir::OpResult> out_grads,
ir::OpResult axis) {
auto combine_op =
APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(out_grads);
paddle::dialect::SplitGradOp split_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::SplitGradOp>(
combine_op.out(), axis);
return split_grad_op.x_grad();
}
} // namespace dialect
}  // namespace paddle
@@ -21,5 +21,9 @@
#include "paddle/phi/common/place.h"

namespace paddle {
-namespace dialect {}  // namespace dialect
+namespace dialect {

ir::OpResult split_grad(std::vector<ir::OpResult> out_grads, ir::OpResult axis);

}  // namespace dialect
}  // namespace paddle
@@ -14,6 +14,7 @@
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_manual_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_attribute.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
@@ -145,7 +146,221 @@ void AddNOp::InferMeta(phi::InferMetaContext *infer_meta) {
  fn(infer_meta);
}
const char *SplitGradOp::attributes_name[1] = {"axis"};
OpInfoTuple SplitGradOp::GetOpInfo() {
std::vector<paddle::dialect::OpInputInfo> inputs = {
OpInputInfo("out_grad",
"ir::VectorType<paddle::dialect::DenseTensorType>",
false,
false,
false),
OpInputInfo(
"axis", "paddle::dialect::ScalarAttribute", false, false, true)};
std::vector<paddle::dialect::OpAttributeInfo> attributes = {};
std::vector<paddle::dialect::OpOutputInfo> outputs = {
OpOutputInfo("x_grad", "paddle::dialect::DenseTensorType", false, false)};
paddle::dialect::OpRunTimeInfo run_time_info =
OpRunTimeInfo("ConcatInferMeta",
{"out_grad", "axis"},
{"concat"},
{"out_grad", "axis"},
{"out_grad"},
{},
{},
{});
return std::make_tuple(
inputs, attributes, outputs, run_time_info, "split_grad");
}
void SplitGradOp::Build(ir::Builder &builder,
ir::OperationArgument &argument,
ir::OpResult out_grad_,
float axis) {
// Generate scalar mutable attribute: axis
paddle::dialect::FullOp full_axis_op = builder.Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, axis, phi::DataType::FLOAT32, phi::CPUPlace());
ir::OpResult axis_ = full_axis_op->result(0);
VLOG(4) << "Builder construction inputs";
std::vector<ir::OpResult> argument_inputs = {out_grad_, axis_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
VLOG(4) << "Builder construction attributes";
VLOG(4) << "Builder construction outputs";
ir::VectorType out_grad = out_grad_.type().dyn_cast<ir::VectorType>();
std::vector<phi::DenseTensor> vec_dense_out_grad;
for (size_t i = 0; i < static_cast<size_t>(out_grad.size()); i++) {
vec_dense_out_grad.push_back(phi::DenseTensor(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(
paddle::dialect::TransToPhiDataType(
out_grad[i]
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()),
out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
out_grad[i]
.dyn_cast<paddle::dialect::DenseTensorType>()
.data_layout(),
out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
out_grad[i]
.dyn_cast<paddle::dialect::DenseTensorType>()
.offset())));
}
std::vector<phi::MetaTensor> vec_meta_out_grad;
for (size_t i = 0; i < vec_dense_out_grad.size(); i++) {
vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i]));
}
std::vector<const phi::MetaTensor *> meta_out_grad;
for (size_t i = 0; i < static_cast<size_t>(vec_meta_out_grad.size()); i++) {
meta_out_grad.push_back(&vec_meta_out_grad[i]);
}
phi::DenseTensor dense_x_grad;
phi::MetaTensor meta_x_grad(&dense_x_grad);
phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad);
std::vector<ir::Type> argument_outputs;
ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
paddle::dialect::TransToIrDataType(dense_x_grad.dtype()),
dense_x_grad.dims(),
dense_x_grad.layout(),
dense_x_grad.lod(),
dense_x_grad.offset());
argument_outputs.push_back(x_grad_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
void SplitGradOp::Build(ir::Builder &builder,
ir::OperationArgument &argument,
ir::OpResult out_grad_,
ir::OpResult axis_) {
VLOG(4) << "Builder construction inputs";
std::vector<ir::OpResult> argument_inputs = {out_grad_, axis_};
argument.AddOperands(argument_inputs.begin(), argument_inputs.end());
VLOG(4) << "Builder construction attributes";
VLOG(4) << "Builder construction outputs";
ir::VectorType out_grad = out_grad_.type().dyn_cast<ir::VectorType>();
int axis = axis_.owner()
->dyn_cast<paddle::dialect::FullOp>()
.attributes()
.at("value")
.dyn_cast<paddle::dialect::ScalarAttribute>()
.data()
.to<int>();
std::vector<phi::DenseTensor> vec_dense_out_grad;
for (size_t i = 0; i < static_cast<size_t>(out_grad.size()); i++) {
vec_dense_out_grad.push_back(phi::DenseTensor(
std::make_unique<paddle::experimental::DefaultAllocator>(
paddle::platform::CPUPlace())
.get(),
phi::DenseTensorMeta(
TransToPhiDataType(out_grad[i]
.dyn_cast<paddle::dialect::DenseTensorType>()
.dtype()),
out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().dims(),
out_grad[i]
.dyn_cast<paddle::dialect::DenseTensorType>()
.data_layout(),
out_grad[i].dyn_cast<paddle::dialect::DenseTensorType>().lod(),
out_grad[i]
.dyn_cast<paddle::dialect::DenseTensorType>()
.offset())));
}
std::vector<phi::MetaTensor> vec_meta_out_grad;
for (size_t i = 0; i < vec_dense_out_grad.size(); i++) {
vec_meta_out_grad.push_back(phi::MetaTensor(&vec_dense_out_grad[i]));
}
std::vector<const phi::MetaTensor *> meta_out_grad;
for (size_t i = 0; i < static_cast<size_t>(vec_meta_out_grad.size()); i++) {
meta_out_grad.push_back(&vec_meta_out_grad[i]);
}
phi::DenseTensor dense_x_grad;
phi::MetaTensor meta_x_grad(&dense_x_grad);
phi::ConcatInferMeta(meta_out_grad, axis, &meta_x_grad);
std::vector<ir::Type> argument_outputs;
ir::Type x_grad_dense_tensor_type = paddle::dialect::DenseTensorType::get(
ir::IrContext::Instance(),
TransToIrDataType(dense_x_grad.dtype()),
dense_x_grad.dims(),
dense_x_grad.layout(),
dense_x_grad.lod(),
dense_x_grad.offset());
argument_outputs.push_back(x_grad_dense_tensor_type);
argument.AddOutputs(argument_outputs.begin(), argument_outputs.end());
}
void SplitGradOp::Verify() {
VLOG(4) << "Start Verifying inputs, outputs and attributes for: SplitGradOp.";
VLOG(4) << "Verifying inputs:";
{
auto input_size = num_operands();
PADDLE_ENFORCE_EQ(
input_size,
2u,
phi::errors::PreconditionNotMet(
"The size %d of inputs must be equal to 2.", input_size));
if (auto vec_type =
(*this)->operand_source(0).type().dyn_cast<ir::VectorType>()) {
for (size_t i = 0; i < vec_type.size(); ++i) {
PADDLE_ENFORCE(vec_type[i].isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
}
} else {
PADDLE_ENFORCE((*this)
->operand_source(0)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th input."));
}
PADDLE_ENFORCE((*this)
->operand_source(1)
.type()
.isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 1th input."));
}
VLOG(4) << "Verifying attributes:";
{
// Attributes num is 0, not need to check attributes type.
}
VLOG(4) << "Verifying outputs:";
{
auto output_size = num_results();
PADDLE_ENFORCE_EQ(
output_size,
1u,
phi::errors::PreconditionNotMet(
"The size %d of outputs must be equal to 1.", output_size));
PADDLE_ENFORCE(
(*this)->result(0).type().isa<paddle::dialect::DenseTensorType>(),
phi::errors::PreconditionNotMet(
"Type validation failed for the 0th output."));
}
VLOG(4) << "End Verifying for: SplitGradOp.";
}
void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
auto fn = PD_INFER_META(phi::ConcatInferMeta);
fn(infer_meta);
}
}  // namespace dialect
}  // namespace paddle

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)
@@ -14,7 +14,7 @@
#ifdef GET_MANUAL_OP_LIST
#undef GET_MANUAL_OP_LIST

-paddle::dialect::AddNOp
+paddle::dialect::AddNOp, paddle::dialect::SplitGradOp

#else
@@ -51,9 +51,33 @@ class AddNOp : public ir::Op<AddNOp, OpYamlInfoInterface> {
  static void InferMeta(phi::InferMetaContext *infer_meta);
};
class SplitGradOp : public ir::Op<SplitGradOp, OpYamlInfoInterface> {
public:
using Op::Op;
static const char *name() { return "pd.split_grad"; }
static const char *attributes_name[1];
static constexpr uint32_t attributes_num = 1;
static OpInfoTuple GetOpInfo();
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult x_,
float axis = 0);
static void Build(ir::Builder &builder, // NOLINT
ir::OperationArgument &argument, // NOLINT
ir::OpResult out_grad_,
ir::OpResult axis_);
void Verify();
ir::Value out_grad() { return operand_source(0); }
ir::Value axis() { return operand_source(1); }
ir::OpResult x_grad() { return result(0); }
static void InferMeta(phi::InferMetaContext *infer_meta);
};
}  // namespace dialect
}  // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::SplitGradOp)

#endif
@@ -53,5 +53,39 @@ std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
  }
  return res;
}
std::vector<std::vector<ir::OpResult>> SplitOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
SplitOp op_obj = op->dyn_cast<SplitOp>();
Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
std::vector<Tensor> out_grads_;
for (size_t idx = 0; idx < out_grads[0].size(); idx++) {
out_grads_.emplace_back(
std::make_shared<primitive::LazyTensor>(out_grads[0][idx]));
}
std::vector<std::vector<Tensor>> tensor_res =
primitive::split_vjp(out_grads_, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
std::vector<ir::OpResult>());
for (uint64_t i = 0; i < tensor_res.size(); i++) {
res[i].resize(tensor_res[i].size());
for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
}
return res;
}
}  // namespace dialect
}  // namespace paddle
@@ -33,6 +33,9 @@ std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
                                const Tensor& out_grad,
                                const Tensor& axis);
template <typename T>
Tensor split_grad(const std::vector<Tensor>& out_grads, const Tensor& axis);
}  // namespace backend
}  // namespace primitive
}  // namespace paddle
@@ -54,6 +54,23 @@ std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
  return op_result;
}
template <>
Tensor split_grad<LazyTensor>(const std::vector<Tensor>& out_grads,
const Tensor& axis) {
std::vector<ir::OpResult> out_grads_res;
for (uint64_t idx = 0; idx < out_grads.size(); idx++) {
out_grads_res.emplace_back(
std::static_pointer_cast<LazyTensor>(out_grads[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}
ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::split_grad(out_grads_res, axis_res);
return Tensor(std::make_shared<primitive::LazyTensor>(op_res));
}
}  // namespace backend
}  // namespace primitive
}  // namespace paddle
@@ -48,5 +48,25 @@ std::vector<std::vector<paddle::Tensor>> concat_vjp(
  return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> split_vjp(
const std::vector<Tensor>& out_grads,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(3, std::vector<Tensor>(1));
  // get split_grad res.
Tensor op_res = backend::split_grad<primitive::LazyTensor>(out_grads, axis);
// construct vjp result by op result and stop_gradients info
if (!stop_gradients[0][0]) {
vjp_res[0][0] = op_res;
}
// vjp_res[1] is sections's grad which is attribute (no grad).
// vjp_res[2] is axis's grad which is attribute (no grad).
vjp_res[1].resize(stop_gradients[1].size());
vjp_res[2].resize(stop_gradients[2].size());
return vjp_res;
}
}  // namespace primitive
}  // namespace paddle
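For reference, split's forward inputs are (x, sections, axis), so split_vjp above hands back three gradient slots and only the first can ever carry a real tensor. A rough Python restatement of that bookkeeping (illustrative sketch, not Paddle source):

def split_vjp_layout(stop_gradients, x_grad):
    # One inner list per forward input: x, sections, axis.
    vjp_res = [
        [None],
        [None] * len(stop_gradients[1]),  # sections is an attribute: no grad
        [None] * len(stop_gradients[2]),  # axis is an attribute: no grad
    ]
    if not stop_gradients[0][0]:
        vjp_res[0][0] = x_grad  # only x receives a gradient
    return vjp_res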
@@ -30,5 +30,10 @@ std::vector<std::vector<paddle::Tensor>> concat_vjp(
    const Tensor& axis,
    const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> split_vjp(
const std::vector<Tensor>& out_grads,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients);
}  // namespace primitive
}  // namespace paddle
@@ -52,6 +52,10 @@ static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) {
  return static_api_concat(self, args, kwargs);
}
static PyObject *split(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_split(self, args, kwargs);
}
static PyMethodDef OpsAPI[] = {{"add_n",
                                (PyCFunction)(void (*)(void))add_n,
                                METH_VARARGS | METH_KEYWORDS,
@@ -76,6 +80,10 @@ static PyMethodDef OpsAPI[] = {{"add_n",
                                (PyCFunction)(void (*)(void))full,
                                METH_VARARGS | METH_KEYWORDS,
                                "C++ interface function for full."},
{"split",
(PyCFunction)(void (*)(void))split,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for split."},
{"data", {"data",
(PyCFunction)(void (*)(void))data, (PyCFunction)(void (*)(void))data,
METH_VARARGS | METH_KEYWORDS, METH_VARARGS | METH_KEYWORDS,
......
@@ -201,7 +201,6 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
                outputs_set.add(operand)
        else:
            relevant_op_flags[i] = False
    # recover full op or full_Intarray op created by mutable attribute.
    total_ops_list = list(total_ops)
    for i, op in enumerate(total_ops_list):
@@ -354,12 +353,16 @@ def append_backward_ops(
    def make_output_grad(op):
        zero_flag = [False] * op.num_results()
output_grads = []
        for i, value in enumerate(op.results()):
            if (
                value not in state.value_to_valuegrad
                or state.value_to_valuegrad[value] is None
            ):
-                if value.first_use().owner().name() == "builtin.split":
+                if (
+                    not value.use_empty()
+                    and value.first_use().owner().name() == "builtin.split"
+                ):
                    # pattern case:
                    # this fwd_op's output is vectorType, it will split to
                    # Type by builtin.split op, so need get from split op's ouput
@@ -367,7 +370,7 @@ def append_backward_ops(
                        value.first_use().owner()
                    )
                    zero_flag[i] = all(split_zero_flag)
-                    grad_value = [op_list[0] for op_list in split_output_grad]
+                    state.value_to_valuegrad[value] = [split_output_grad]
                else:
                    # first case:
                    # this fwd_op's output didn't used by other fwd_op,
@@ -388,7 +391,7 @@ def append_backward_ops(
                        )
                        zero_flag[i] = True

                    state.value_to_valuegrad[value] = [[grad_value]]

            if len(state.value_to_valuegrad[value]) > 1:
                # one value is input of more than one fwd_op,
@@ -411,8 +414,8 @@ def append_backward_ops(
                    value
                ]

-            output_grad = state.value_to_valuegrad[value][0]
-        return zero_flag, output_grad
+            output_grads.append(state.value_to_valuegrad[value][0][0])
+        return zero_flag, output_grads
    def make_input_stopgradient(op):
        input_grad_stopgradients = []
@@ -529,8 +532,9 @@ def create_backward_prune_set(inputs, outputs, no_grad_set, state):
        inputs_set.add(state.value_to_valuegrad[output][0][0])

    inputs_set_tmp = set()
    for out_grad in inputs_set:
-        for item in out_grad.first_use().owner().operands_source():
-            inputs_set_tmp.add(item)
+        if not out_grad.use_empty():
+            for item in out_grad.first_use().owner().operands_source():
+                inputs_set_tmp.add(item)
    inputs_set.update(inputs_set_tmp)

    no_gradvar_set = set()  # grad_value of value in no_grad_set
...
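The behavioural change in make_output_grad above is that a forward result is only routed through its builtin.split consumer when it actually has a use. A standalone restatement of that guard (illustrative only):

def routed_through_builtin_split(value):
    # Mirrors the condition added in make_output_grad: only look at the consumer
    # when the result has at least one use, and only when that consumer is the
    # builtin.split op that unpacks a VectorType forward output.
    return (
        not value.use_empty()
        and value.first_use().owner().name() == "builtin.split"
    )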
@@ -1976,6 +1976,14 @@ def split(x, num_or_sections, axis=0, name=None):
        else:
            return _C_ops.split(input, num_or_sections, dim)
    else:
if paddle.ir.core._use_new_ir_api():
if not isinstance(num_or_sections, int):
return paddle._ir_ops.split(input, num_or_sections, dim)
else:
raise NotImplementedError(
"_ir_ops.split_with_num is not implemented, please change sections as list"
)
        check_variable_and_dtype(
            input,
            'input',
...
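With the ops_api binding and this dispatch in place, the new-IR path of paddle.split can be exercised the same way the tests below do. A condensed sketch (get_ir_program() is the helper defined in the test file, not shown in this diff):

import paddle

newir_program = get_ir_program()  # test helper; builds a new-IR program ending in pd.tanh
tanh_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
    # num_or_sections must be a list here; an int raises NotImplementedError,
    # since _ir_ops.split_with_num is not implemented yet (see the branch above).
    out = paddle.split(tanh_out, [2, 2], 0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})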
@@ -402,5 +402,74 @@ TEST(VJP, Add_BackwardTest) {
  ASSERT_EQ(dx.data<float>()[0], 1.0);
  ASSERT_EQ(dy.data<float>()[0], 1.0);
}
TEST(VJP, SplitBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
std::shared_ptr<ir::Builder> builder =
paddle::dialect::APIBuilder::Instance().GetBuilder();
paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
paddle::dialect::SplitOp op2 = builder->Build<paddle::dialect::SplitOp>(
op1.out(), std::vector<int64_t>{1, 1}, 0);
ir::SplitOp op3 = builder->Build<ir::SplitOp>(op2.out());
paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<std::vector<bool>> stop_gradients{{false}, {true}, {true}};
std::vector<std::vector<ir::OpResult>> out_grads{{op3.result(0), op4.out()}};
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.split");
auto concat_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
concat_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
auto place = platform::CPUPlace();
Scope scope;
ProgramDesc prog_desc;
InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
test_core.SetSkipGcVars({prefix_str + "_inner_var_4",
prefix_str + "_inner_var_5",
prefix_str + "_inner_var_8"});
test_core.Run({});
auto out_tensor_0 =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_4")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_4")
->Get<phi::DenseTensor>();
auto out_tensor_1 =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_5")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_5")
->Get<phi::DenseTensor>();
auto grad_out_tensor_0 =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_8")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_8")
->Get<phi::DenseTensor>();
ASSERT_EQ(out_tensor_0.data<float>()[0], 2.0);
ASSERT_EQ(out_tensor_0.data<float>()[1], 2.0);
ASSERT_EQ(out_tensor_1.data<float>()[0], 2.0);
ASSERT_EQ(out_tensor_1.data<float>()[1], 2.0);
ASSERT_EQ(grad_out_tensor_0.data<float>()[0], 2.0);
ASSERT_EQ(grad_out_tensor_0.data<float>()[1], 2.0);
ASSERT_EQ(grad_out_tensor_0.data<float>()[2], 1.0);
ASSERT_EQ(grad_out_tensor_0.data<float>()[3], 1.0);
}
}  // namespace framework
}  // namespace paddle
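To sanity-check the values asserted above: the forward op splits a 2x2 tensor of 2.0s into two 1x2 pieces, and the grads fed back are the first split output (all 2.0) plus a 1x2 tensor of 1.0s, so the reconstructed x_grad is their concatenation along axis 0. A quick NumPy check (illustration only, not part of this diff):

import numpy as np

out_grad_0 = np.full((1, 2), 2.0)  # op3.result(0): first output of the split, value 2.0
out_grad_1 = np.full((1, 2), 1.0)  # op4: full op producing 1.0s
x_grad = np.concatenate([out_grad_0, out_grad_1], axis=0)
assert x_grad.flatten().tolist() == [2.0, 2.0, 1.0, 1.0]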
@@ -121,5 +121,25 @@ class TestBuildOp4(unittest.TestCase):
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
class TestBuildOp5(unittest.TestCase):
def test_build_split_op(self):
newir_program = get_ir_program()
tanh_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.split(tanh_out, [2, 2], 0)
self.assertEqual(out[0].get_defining_op().name(), "builtin.split")
self.assertEqual(
out[0]
.get_defining_op()
.operands()[0]
.source()
.get_defining_op()
.name(),
"pd.split",
)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
if __name__ == "__main__":
    unittest.main()
@@ -94,6 +94,32 @@ class TesBackward_1(unittest.TestCase):
        self.assertEqual(newir_program.block().ops[-1].name(), "pd.mean")
        paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
def test_split(self):
# test create output_grad in backward use full op
newir_program = get_ir_program_0()
input = newir_program.block().ops[-1].operand(0).source()
tanh_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.split(tanh_out, [2, 2], 0)
input_grad = grad(out, input)
ops_name = [
"pd.data",
"pd.tanh",
"pd.full_int_array",
"pd.full",
"pd.split",
"builtin.split",
"pd.full",
"builtin.combine",
"pd.split_grad",
"pd.tanh_grad",
]
for i, op in enumerate(newir_program.block().ops):
self.assertEqual(op.name(), ops_name[i])
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
def get_ir_program_1():
    x = paddle.randn([2, 2])
...