Unverified commit 6de9a8d3 authored by xiaoguoguo626807, committed by GitHub

[new ir] fix builtin_split bug (#56463)

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* add vjp interface

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backend to wrap lower level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* backward origin code

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add attrs and dtype interface

* add primitive ops set for backend

* fix compile bugs

* fix some bugs

* fix windows bugs

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* origin test of tanh and mean passed

* fix conflict

* modify stop_gradient

* remove useless deps

* add cmake

* modify block.ops

* modify test

* fix conflict

* reply to review comments

* polish code

* fix comment

* fix test

* polish code

* modify backward stop_gradients

* modify static_backend.cc

* refactor grad_op

* support add and add_inplace vjp

* remove useless code

* remove cout

* modify add_n

* modify add_n with add_vjp test

* fix conflict and concat call_vjp

* modify backward test

* Add more gen api

* modify split kernel pass

* modify concat api

* modify builtin split bug

* delete vlog

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: Charles-hit <wanghao107@baidu.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Co-authored-by: YuanRisheng <yuanrisheng@baidu.com>
Co-authored-by: 0x45f <wangzhen45@baidu.com>
Parent fb891776
@@ -592,7 +592,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
     }
     if (op_item->name() == "builtin.split") {
-      phi::Place out_place = place;
+      std::vector<phi::Place> out_places(op_item->num_results());
       // Copy op inputs
       std::vector<ir::OpResult> vec_inputs;
       if (op_item->num_operands() > 0) {
@@ -613,10 +613,12 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
         if (new_in.type().isa<ir::VectorType>()) {
           auto vec_types = new_in.type().dyn_cast<ir::VectorType>().data();
-          out_place =
-              vec_types[0]
-                  .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
-                  .place();
+          for (uint64_t idx = 0; idx < vec_types.size(); idx++) {
+            out_places[idx] =
+                vec_types[idx]
+                    .dyn_cast<paddle::dialect::AllocatedDenseTensorType>()
+                    .place();
+          }
         } else {
           PADDLE_THROW(
               phi::errors::Unimplemented("only support vector type for now"));
@@ -634,7 +636,7 @@ std::unique_ptr<ir::Program> PdOpLowerToKernelPass(ir::Program* prog,
           auto allocated_dense_tensor_dtype =
               paddle::dialect::AllocatedDenseTensorType::get(
                   ctx,
-                  out_place,
+                  out_places[i],
                   result_type.dyn_cast<dialect::DenseTensorType>());
           op_output_types.push_back(allocated_dense_tensor_dtype);
         } else {
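A note on the fix above: before this change the builtin.split lowering tracked a single out_place taken from vec_types[0], so every result of the op inherited the first element's place; the new out_places vector records one place per result. A minimal, self-contained sketch of the difference follows. The Place struct, the two functions, and the device strings are illustrative stand-ins, not Paddle code:

#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-in for phi::Place; illustration only.
struct Place {
  std::string device;
};

// Old behavior: a single place derived from element 0 labels every result.
std::vector<Place> SplitPlacesOld(const std::vector<Place> &inputs) {
  Place out_place = inputs[0];
  return std::vector<Place>(inputs.size(), out_place);
}

// Fixed behavior: one place per result, mirroring the out_places vector above.
std::vector<Place> SplitPlacesNew(const std::vector<Place> &inputs) {
  std::vector<Place> out_places(inputs.size());
  for (size_t idx = 0; idx < inputs.size(); ++idx) {
    out_places[idx] = inputs[idx];
  }
  return out_places;
}

int main() {
  std::vector<Place> mixed = {{"gpu:0"}, {"cpu"}};
  assert(SplitPlacesOld(mixed)[1].device == "gpu:0");  // second result mislabeled
  assert(SplitPlacesNew(mixed)[1].device == "cpu");    // second result correct
  return 0;
}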
@@ -162,6 +162,17 @@ void CombineOp::Verify() const {
 }
 
 const char *SliceOp::attributes_name[attributes_num] = {"index"};  // NOLINT
+
+void SliceOp::Build(Builder &builder,
+                    OperationArgument &argument,
+                    const ir::OpResult &input,
+                    int index) {
+  argument.inputs = {input};
+  argument.output_types.emplace_back(input.type()
+                                         .dyn_cast<ir::VectorType>()
+                                         .data()[static_cast<size_t>(index)]);
+}
+
 void SliceOp::Verify() const {
   // inputs.size() == 1
   auto input_size = num_operands();
@@ -207,6 +218,17 @@ void SliceOp::Verify() const {
       output_type);
 }
 
+void SplitOp::Build(Builder &builder,
+                    OperationArgument &argument,
+                    const ir::OpResult &input) {
+  argument.inputs = {input};
+  for (size_t idx = 0; idx < input.type().dyn_cast<ir::VectorType>().size();
+       ++idx) {
+    argument.output_types.emplace_back(
+        input.type().dyn_cast<ir::VectorType>().data()[idx]);
+  }
+}
+
 void SplitOp::Verify() const {
   // inputs.size() == 1
   IR_ENFORCE(num_operands() == 1u, "The size of inputs must be equal to 1.");
@@ -235,18 +257,6 @@ void SplitOp::Verify() const {
   }
 }
 
-void SplitOp::Build(Builder &builder,
-                    OperationArgument &argument,
-                    const ir::OpResult &input) {
-  argument.inputs = {input};
-  std::vector<ir::Type> outputs_types;
-  for (size_t idx = 0; idx < input.type().dyn_cast<ir::VectorType>().size();
-       ++idx) {
-    argument.output_types.emplace_back(
-        input.type().dyn_cast<ir::VectorType>()[idx]);
-  }
-}
-
 const char *ConstantOp::attributes_name[attributes_num] = {"value"};  // NOLINT
 void ConstantOp::Build(Builder &builder,
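The two Build overloads added above both derive output types from the input's VectorType: SliceOp::Build keeps exactly one element type, chosen by index, while SplitOp::Build fans out every element type in order. The old SplitOp::Build removed above had two defects this rewrite fixes: it indexed the VectorType itself rather than its data() array, and it declared an unused outputs_types vector; it was also defined after SplitOp::Verify instead of before it. A small sketch of that type-derivation logic using stand-in types follows. ElemType, VectorOfTypes, and both helper functions are illustrative, not the ir:: API:

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative stand-in for an ir::VectorType's element-type list.
using ElemType = int;
using VectorOfTypes = std::vector<ElemType>;

// SliceOp-style derivation: one output type, chosen by index.
ElemType SliceOutputType(const VectorOfTypes &input_types, int index) {
  return input_types[static_cast<size_t>(index)];
}

// SplitOp-style derivation: one output type per element, in order.
VectorOfTypes SplitOutputTypes(const VectorOfTypes &input_types) {
  VectorOfTypes output_types;
  for (size_t idx = 0; idx < input_types.size(); ++idx) {
    output_types.emplace_back(input_types[idx]);
  }
  return output_types;
}

int main() {
  VectorOfTypes types = {10, 20, 30};
  assert(SliceOutputType(types, 1) == 20);   // slice keeps exactly one type
  assert(SplitOutputTypes(types) == types);  // split keeps all, in order
  return 0;
}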
@@ -118,7 +118,8 @@ class IR_API SliceOp : public ir::Op<SliceOp> {
   static void Build(Builder &builder,             // NOLINT
                     OperationArgument &argument,  // NOLINT
-                    const ir::OpResult &input);
+                    const ir::OpResult &input,
+                    int index);
 
   void Verify() const;
   ir::Value input() { return operand_source(0); }
@@ -1121,7 +1121,7 @@ def concat(x, axis=0, name=None):
             return _C_ops.concat(input, axis)
     else:
         if paddle.ir.core._use_new_ir_api():
-            if not isinstance(input, Variable):
+            if not isinstance(input, paddle.ir.Value):
                 input = [t for t in input if t.shape.count(0) == 0]
                 return paddle._ir_ops.concat(input, axis)
         check_type(input, 'input', (list, tuple, Variable), 'concat')