未验证 提交 1da20aa8 编写于 作者: X xiaoguoguo626807 提交者: GitHub

【New IR] Concat python api and vjp (#56316)

* support ir api form prim

* convert vector of int to intarray

* support ir api for prim

* Add more gen api

* concat python api to concat_grad

* fix gen conflict

* support vjp prim mode in new ir

* remove useless code

* add vjp autogen v1.0

* add test for prim

* resolve type conflict

* modify utils

* remove useless code

* add split op and modify some bug of vectorType

* fix conflict

* add concat python test

---------
Co-authored-by: Ncyber-pioneer <chenzhuo@tju.edu.cn>
Co-authored-by: NCharles-hit <wanghao107@baidu.com>
Co-authored-by: N0x45f <wangzhen45@baidu.com>
Co-authored-by: Nchenzhiyang <1792266893@qq.com>
Co-authored-by: NChen Zhiyang <chenzhiyang99@126.com>
上级 ac8aad59
......@@ -29,6 +29,7 @@ H_FILE_TEMPLATE = """
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/common/scalar.h"
#include "paddle/fluid/ir/dialect/pd_manual_api.h"
{body}
......
......@@ -21,5 +21,13 @@
# TODO(wanghao107)
# remove this file and support Vjp methods
# code gen.
vjp_interface_declare_gen_op_list = ["tanh", "mean", "divide", "sum", "add"]
vjp_interface_declare_gen_op_list = [
"tanh",
"mean",
"divide",
"sum",
"add",
"concat",
]
vjp_interface_implementation_gen_op_list = ["tanh", "mean", "divide", "add"]
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/ir/dialect/pd_manual_api.h"
#include "paddle/fluid/ir/dialect/pd_dialect.h"
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/ir/core/builder.h"
#include "paddle/ir/core/builtin_op.h"
namespace paddle {
namespace dialect {
std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
ir::OpResult out_grad,
ir::OpResult axis) {
auto combine_op =
APIBuilder::Instance().GetBuilder()->Build<ir::CombineOp>(x);
paddle::dialect::ConcatGradOp concat_grad_op =
APIBuilder::Instance().GetBuilder()->Build<paddle::dialect::ConcatGradOp>(
combine_op.out(), out_grad, axis);
auto split_op = APIBuilder::Instance().GetBuilder()->Build<ir::SplitOp>(
concat_grad_op.result(0));
return split_op.outputs();
}
} // namespace dialect
} // namespace paddle
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/ir/core/value.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
namespace paddle {
namespace dialect {
std::vector<ir::OpResult> concat_grad(std::vector<ir::OpResult> x,
ir::OpResult out_grad,
ir::OpResult axis);
} // namespace dialect
} // namespace paddle
......@@ -16,6 +16,7 @@
#include "paddle/fluid/ir/dialect/pd_op.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/lazy_tensor.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/common/int_array.h"
......@@ -24,6 +25,42 @@
namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<ir::OpResult>> ConcatOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<bool>>& stop_gradients) {
ConcatOp op_obj = op->dyn_cast<ConcatOp>();
ir::CombineOp combine_op_obj =
op_obj.x().GetDefiningOp()->dyn_cast<ir::CombineOp>();
std::vector<Tensor> x;
for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {
x.emplace_back(
std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
}
Tensor out_grad(std::make_shared<primitive::LazyTensor>(out_grads[0][0]));
Tensor axis(std::make_shared<primitive::LazyTensor>(op_obj.axis()));
std::vector<std::vector<Tensor>> tensor_res =
primitive::concat_vjp(x, out_grad, axis, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(tensor_res.size(),
std::vector<ir::OpResult>());
for (uint64_t i = 0; i < tensor_res.size(); i++) {
res[i].resize(tensor_res[i].size());
for (uint64_t j = 0; j < tensor_res[i].size(); j++) {
if (tensor_res[i][j].defined()) {
res[i][j] = std::static_pointer_cast<primitive::LazyTensor>(
tensor_res[i][j].impl())
->getValue()
.dyn_cast<ir::OpResult>();
}
}
}
return res;
}
std::vector<std::vector<ir::OpResult>> SumOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
......
......@@ -242,6 +242,38 @@ Tensor sum_grad<LazyTensor>(const Tensor& x,
x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
return Tensor(std::make_shared<LazyTensor>(op_res));
}
template <>
std::vector<Tensor> concat_grad<LazyTensor>(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis) {
std::vector<ir::OpResult> x_res;
for (uint64_t idx = 0; idx < x.size(); idx++) {
x_res.emplace_back(std::static_pointer_cast<LazyTensor>(x[idx].impl())
->getValue()
.dyn_cast<ir::OpResult>());
}
ir::OpResult out_grad_res =
std::static_pointer_cast<LazyTensor>(out_grad.impl())
->getValue()
.dyn_cast<ir::OpResult>();
ir::OpResult axis_res = std::static_pointer_cast<LazyTensor>(axis.impl())
->getValue()
.dyn_cast<ir::OpResult>();
std::vector<ir::OpResult> op_res =
paddle::dialect::concat_grad(x_res, out_grad_res, axis_res);
std::vector<Tensor> op_result;
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
op_result.emplace_back(
std::make_shared<primitive::LazyTensor>(op_res[idx]));
}
return op_result;
}
} // namespace backend
} // namespace primitive
} // namespace paddle
......@@ -37,6 +37,11 @@ Tensor mean_grad(const Tensor& x,
bool keepdim = false,
bool reduce_all = false);
template <typename T>
std::vector<Tensor> concat_grad(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis);
template <typename T>
std::tuple<Tensor, Tensor> add_grad(const Tensor& x,
const Tensor& y,
......
......@@ -148,6 +148,28 @@ std::vector<std::vector<paddle::Tensor>> add_vjp(
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> concat_vjp(
const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(2, std::vector<Tensor>());
// get concat_grad res.
std::vector<Tensor> op_res =
backend::concat_grad<primitive::LazyTensor>(x, out_grad, axis);
// construct vjp result by op result and stop_gradients info
vjp_res[0].resize(op_res.size());
for (uint64_t idx = 0; idx < op_res.size(); idx++) {
if (!stop_gradients[0][idx]) {
vjp_res[0][idx] = op_res[idx];
}
}
// vjp_res[1] is axis's grad which is attribute (no grad).
vjp_res[1].resize(1);
return vjp_res;
}
std::vector<std::vector<paddle::Tensor>> divide_vjp(
const Tensor& x,
const Tensor& y,
......
......@@ -38,6 +38,12 @@ std::vector<std::vector<paddle::Tensor>> mean_vjp(
bool reduce_all,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> concat_vjp(
const std::vector<Tensor>& x,
const Tensor& out_grad,
const Tensor& axis,
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> add_vjp(
const Tensor& x,
const Tensor& y,
......
......@@ -249,6 +249,7 @@ void BindValue(py::module *m) {
.def("get_defining_op",
&Value::GetDefiningOp,
return_value_policy::reference)
.def("first_use", &Value::first_use, return_value_policy::reference)
.def("__eq__", &Value::operator==)
.def("__eq__",
[](Value &self, OpResult &other) {
......@@ -272,9 +273,11 @@ void BindOpOperand(py::module *m) {
op_operand
.def("source",
[](OpOperand &self) { return self.source().dyn_cast<OpResult>(); })
.def("set_source", [](OpOperand &self, const OpResult &result) {
.def("set_source",
[](OpOperand &self, const OpResult &result) {
self.set_source(result);
});
})
.def("owner", &OpOperand::owner, return_value_policy::reference);
}
bool GetStopGradient(const OpResult &self) {
......@@ -331,6 +334,7 @@ void BindOpResult(py::module *m) {
.def("get_defining_op",
&OpResult::GetDefiningOp,
return_value_policy::reference)
.def("first_use", &OpResult::first_use, return_value_policy::reference)
.def("use_empty", &OpResult::use_empty)
.def("type", &OpResult::type)
.def_property(
......
......@@ -40,7 +40,11 @@ static PyObject *divide(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_divide(self, args, kwargs);
}
static PyMethodDef OpsAPI[] = {{"add_n", // NOLINT
static PyObject *concat(PyObject *self, PyObject *args, PyObject *kwargs) {
return static_api_concat(self, args, kwargs);
}
static PyMethodDef OpsAPI[] = {{"add_n",
(PyCFunction)(void (*)(void))add_n,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for add_n."},
......@@ -56,6 +60,10 @@ static PyMethodDef OpsAPI[] = {{"add_n", // NOLINT
(PyCFunction)(void (*)(void))divide,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for divide."},
{"concat",
(PyCFunction)(void (*)(void))concat,
METH_VARARGS | METH_KEYWORDS,
"C++ interface function for concat."},
{"full",
(PyCFunction)(void (*)(void))full,
METH_VARARGS | METH_KEYWORDS,
......
......@@ -109,6 +109,26 @@ PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs) {
}
}
PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add concat op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
auto x = CastPyArg2VectorOfOpResult("concat", x_obj, 0);
PyObject *axis_obj = PyTuple_GET_ITEM(args, 1);
paddle::experimental::Scalar axis = CastPyArg2Scalar(axis_obj, "concat", 1);
// Call ir static api
auto out = paddle::dialect::concat(x, axis.to<float>());
return ToPyObject(out);
} catch (...) {
ThrowExceptionToPython(std::current_exception());
return nullptr;
}
}
PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs) {
try {
VLOG(6) << "Add full op into program";
......
......@@ -28,6 +28,7 @@ PyObject *static_api_add_n(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_sum(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_divide(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_concat(PyObject *self, PyObject *args, PyObject *kwargs);
PyObject *static_api_full(PyObject *self, PyObject *args, PyObject *kwargs);
} // namespace pybind
......
......@@ -216,11 +216,11 @@ def prune_ops(total_ops, inputs_set, outputs_set, no_grad_set):
return effective_ops, uneffective_ops
def update_no_grad_set_after_purne(
def update_no_grad_set_after_prune(
block, effective_forward_op, no_grad_set, inputs, outputs
):
'''
update no_grad_set after forward purne
update no_grad_set after forward prune
from inputs to outputs add value not in the path to no_grad_set,
from outputs to inputs add value not in the path to no_grad_set,
......@@ -338,19 +338,19 @@ def append_backward_ops(
else continue to next op.
'''
def make_output_grad(op, split_op):
def make_output_grad(op):
zero_flag = [False] * op.num_results()
for i, value in enumerate(op.results()):
if (
value not in state.value_to_valuegrad
or state.value_to_valuegrad[value] is None
):
if split_op is not None and value == split_op.operand_source(0):
if value.first_use().owner().name() == "builtin.split":
# pattern case:
# this fwd_op's output is vectorType, it will split to
# Type by builtin.split op, so need get from split op's ouput
split_zero_flag, split_output_grad = make_output_grad(
split_op, None
value.first_use().owner()
)
zero_flag[i] = all(split_zero_flag)
grad_value = [op_list[0] for op_list in split_output_grad]
......@@ -400,11 +400,11 @@ def append_backward_ops(
output_grad = state.value_to_valuegrad[value][0]
return zero_flag, output_grad
def make_input_stopgradient(combine_op, op):
def make_input_stopgradient(op):
input_grad_stopgradient_list = []
for input in op.operands_source():
if combine_op is not None and input == combine_op.result(0):
stop_gradient = make_input_stopgradient(None, combine_op)
if input.get_defining_op().name() == "builtin.combine":
stop_gradient = make_input_stopgradient(input.get_defining_op())
input_grad_stopgradient_list.append(
[info[0] for info in stop_gradient]
)
......@@ -413,13 +413,14 @@ def append_backward_ops(
input_grad_stopgradient_list.append([True])
else:
input_grad_stopgradient_list.append([False])
return input_grad_stopgradient_list
def update_input_grad_map(combine_op, op, input_grad_list):
def update_input_grad_map(op, input_grad_list):
for i, input in enumerate(op.operands_source()):
if combine_op is not None and input == combine_op.reslut(0):
update_input_grad_map(None, combine_op, input_grad_list[i])
if input.get_defining_op().name() == "builtin.combine":
update_input_grad_map(
input.get_defining_op(), input_grad_list[i]
)
else:
input_grad = input_grad_list[i]
if isinstance(input_grad, list):
......@@ -427,48 +428,24 @@ def append_backward_ops(
else:
state.value_to_valuegrad[input].append([input_grad])
# make op to op pattern, there are four patterns:
# there are four patterns:
# [builtin.combine , op1] (op1's one input is vectorType, outputs are not vectorType)
# [op2 , builtin.split] (op2's inputs are not vectorType, one output is vectorType)
# [builtin.combine , op3 , buitin.split] (op3's one input and one output are vectorType)
# [op4] (op4's inputs and outputs are not vectorType)
# einsum has twp vectorType outputs, special pattern
pattern_effective_op_list = []
for idx, op in enumerate(effective_forward_op):
if op.name() == "builtin.combine":
pattern_effective_op_list.append([op])
pattern_effective_op_list[-1].append(effective_forward_op[idx + 1])
elif op.name() == "builtin.split":
pattern_effective_op_list[-1].append(op)
else:
if (
not pattern_effective_op_list
or op not in pattern_effective_op_list[-1]
):
pattern_effective_op_list.append([op])
for op_pattern in pattern_effective_op_list:
combine_op = None
split_op = None
if len(op_pattern) == 1:
op = op_pattern[0]
elif len(op_pattern) == 2:
if op_pattern[0] == 'builtin.combine':
combine_op = op_pattern[0]
op = op_pattern[1]
else:
op = op_pattern[0]
split_op = op_pattern[1]
else:
combine_op = op_pattern[0]
op = op_pattern[1]
split_op = op_pattern[2]
clear_effective_forward_op = []
for op in effective_forward_op:
if op.name() != "builtin.combine" and op.name() != "builtin.split":
clear_effective_forward_op.append(op)
for op in clear_effective_forward_op:
if paddle.framework.core.has_vjp(op):
# prepare output_grad
output_grad_list = [] # (opresult)
zero_flag, output_grad = make_output_grad(op, split_op)
zero_flag, output_grad = make_output_grad(op)
output_grad_list.append(output_grad)
# all(zero_flag) support this op has no contribution for grad
......@@ -477,9 +454,7 @@ def append_backward_ops(
continue
# prepare input_grad stop_gradient info.
input_grad_stopgradient_list = make_input_stopgradient(
combine_op, op
)
input_grad_stopgradient_list = make_input_stopgradient(op)
# create grad_op
before_ops_num = len(block.ops)
......@@ -495,7 +470,7 @@ def append_backward_ops(
)
# update input_grad map
update_input_grad_map(combine_op, op, input_grad_list)
update_input_grad_map(op, input_grad_list)
else:
if op.num_operands() == 0 and op.num_results() != 0:
......@@ -526,17 +501,23 @@ def append_backward_ops(
state.op_to_opgrad[op] = []
def create_backward_purne_set(inputs, outputs, no_grad_set, state):
def create_backward_prune_set(inputs, outputs, no_grad_set, state):
outputs_set = set()
for input in inputs:
if state.value_to_valuegrad[input] != []:
outputs_set.add(state.value_to_valuegrad[input][0][0])
for item in input.first_use().owner().operands_source():
if state.value_to_valuegrad[item] != []:
outputs_set.add(state.value_to_valuegrad[item][0][0])
inputs_set = set()
for output in outputs:
if state.value_to_valuegrad[output] != []:
inputs_set.add(state.value_to_valuegrad[output][0][0])
inputs_set_tmp = set()
for out_grad in inputs_set:
for item in out_grad.first_use().owner().operands_source():
inputs_set_tmp.add(item)
inputs_set.update(inputs_set_tmp)
no_gradvar_set = set() # grad_value of value in no_grad_set
for key in state.value_to_valuegrad:
if key in no_grad_set:
......@@ -590,31 +571,31 @@ def calc_gradient_helper(outputs, inputs, grad_outputs, no_grad_set):
effective_forward_op, _ = prune_ops(
block.ops, inputs_set, outputs_set, no_grad_set
)
update_no_grad_set_after_purne(
update_no_grad_set_after_prune(
block, effective_forward_op, no_grad_set, inputs, complete_outputs
)
sorted_effective_forward_op = inverse_sort_op(effective_forward_op)
inverse_effective_forward_op = inverse_sort_op(effective_forward_op)
append_backward_ops(
block, sorted_effective_forward_op, no_grad_set, backward_ops, state
block, inverse_effective_forward_op, no_grad_set, backward_ops, state
)
# now value_to_valuegrad should be value <-> value (add sum op for the same values's gradvalue)
outputs_set, inputs_set, no_gradvar_set = create_backward_purne_set(
outputs_set, inputs_set, no_gradvar_set = create_backward_prune_set(
inputs, complete_outputs, no_grad_set, state
)
_, remove_ops = prune_ops(
backward_ops, inputs_set, outputs_set, no_gradvar_set
)
state.turn_map()
state.turn_map()
for bwd_op in inverse_sort_op(remove_ops):
remove_op(block, bwd_op, state)
state.turn_map()
input_grad_map = state.value_to_valuegrad
state.turn_map()
return input_grad_map
......
......@@ -1120,6 +1120,10 @@ def concat(x, axis=0, name=None):
input = [t for t in input if t.shape.count(0) == 0]
return _C_ops.concat(input, axis)
else:
if paddle.ir.core._use_new_ir_api():
if not isinstance(input, Variable):
input = [t for t in input if t.shape.count(0) == 0]
return paddle._ir_ops.concat(input, axis)
check_type(input, 'input', (list, tuple, Variable), 'concat')
if not isinstance(input, Variable):
for id, x in enumerate(input):
......
......@@ -205,6 +205,68 @@ TEST(VJP, MeanBackwardTest) {
ASSERT_EQ(grad_out_tensor.data<float>()[3], 0.25);
}
TEST(VJP, ConcatBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
paddle::dialect::APIBuilder::Instance().SetProgram(&program);
std::shared_ptr<ir::Builder> builder =
paddle::dialect::APIBuilder::Instance().GetBuilder();
paddle::dialect::FullOp op1 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<ir::OpResult> combine_input{{op1.out(), op1.out()}};
ir::CombineOp op2 = builder->Build<ir::CombineOp>(combine_input);
paddle::dialect::ConcatOp op3 =
builder->Build<paddle::dialect::ConcatOp>(op2.out(), 0);
paddle::dialect::FullOp op4 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<std::vector<bool>> stop_gradients{{false, false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op4.out()}};
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.concat");
auto concat_vjp_interface_impl =
op2_info.GetInterfaceImpl<paddle::dialect::VjpInterface>();
concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);
auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
auto place = platform::CPUPlace();
Scope scope;
ProgramDesc prog_desc;
InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
std::stringstream os;
os << reinterpret_cast<NewIRInterpreter*>(
const_cast<InterpreterBaseImpl*>(test_core.Impl()));
std::string prefix_str = os.str();
test_core.SetSkipGcVars({prefix_str + "_inner_var_3",
prefix_str + "_inner_var_7",
prefix_str + "_inner_var_8"});
test_core.Run({});
auto out_tensor =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_3")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_3")
->Get<phi::DenseTensor>();
auto grad_out_tensor_0 =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_7")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_7")
->Get<phi::DenseTensor>();
auto grad_out_tensor_1 =
test_core.local_scope() == nullptr
? scope.FindVar(prefix_str + "_inner_var_8")->Get<phi::DenseTensor>()
: test_core.local_scope()
->FindVar(prefix_str + "_inner_var_8")
->Get<phi::DenseTensor>();
ASSERT_EQ(out_tensor.data<float>()[0], 2.0);
ASSERT_EQ(grad_out_tensor_0.data<float>()[0], 1.0);
ASSERT_EQ(grad_out_tensor_0.data<float>()[1], 1.0);
ASSERT_EQ(grad_out_tensor_1.data<float>()[0], 1.0);
ASSERT_EQ(grad_out_tensor_1.data<float>()[1], 1.0);
}
TEST(VJP, AddBackwardTest) {
ir::IrContext* ctx = ir::IrContext::Instance();
ir::Program program((ctx));
......
......@@ -102,5 +102,24 @@ class TestBuildOp3(unittest.TestCase):
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
class TestBuildOp4(unittest.TestCase):
def test_build_concat_op(self):
newir_program = get_ir_program()
tanh_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.concat([tanh_out, tanh_out], 0)
self.assertEqual(out.get_defining_op().name(), "pd.concat")
self.assertEqual(
out.get_defining_op()
.operands()[0]
.source()
.get_defining_op()
.name(),
"builtin.combine",
)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
if __name__ == "__main__":
unittest.main()
......@@ -116,7 +116,6 @@ def get_ir_program_1():
class TesBackward_2(unittest.TestCase):
def test_add_n(self):
# test add_n op
newir_program = get_ir_program_1()
input_x = newir_program.block().ops[-3].operand(0).source()
......@@ -130,6 +129,43 @@ class TesBackward_2(unittest.TestCase):
self.assertEqual(
newir_program.block().ops[-2].name(), "builtin.combine"
)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
def test_concat(self):
newir_program = get_ir_program_1()
input_x = newir_program.block().ops[-3].operand(0).source()
add_out = newir_program.block().ops[-1].result(0)
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": True})
with paddle.ir.core.program_guard(newir_program):
out = paddle.concat([add_out, add_out])
input_grad = grad(out, input_x)
ops_name = [
"pd.data",
"pd.data",
"pd.tanh",
"pd.tanh",
"pd.add",
"builtin.combine",
"pd.full",
"pd.concat",
"pd.full",
"builtin.combine",
"pd.concat_grad",
"builtin.split",
"builtin.combine",
"pd.add_n",
"pd.add_grad",
"pd.tanh_grad",
"pd.tanh_grad",
"builtin.combine",
"pd.add_n",
]
for i, op in enumerate(newir_program.block().ops):
self.assertEqual(op.name(), ops_name[i])
paddle.framework.set_flags({"FLAGS_enable_new_ir_api": False})
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册