Unverified commit 533b62ff authored by Charles-hit, committed by GitHub

[PRIM][IR]fix comment for vjp (#56137)

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* add vjp interface

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backend to wrap lower level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add primitive ops set for backend

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* remove useless deps

* add cmake

* fix comment

* fix test

* polish code

* modify backward stop_gradients

* modify static_backend.cc

* remove useless code

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Parent 24771dd6
......@@ -40,5 +40,5 @@ def gen_exclusive_interface_str(op_info):
" static void InferMeta( phi::InferMetaContext *infer_meta );"
)
if op_info.op_phi_name[0] in vjp_interface_gen_op_list:
exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<int>>& stop_gradients);"
exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
return exclusive_interface_str
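
For readability, the escaped string above expands to the following declaration in every op listed in vjp_interface_gen_op_list; the only change in this hunk is that stop_gradients is now a nested vector of bool rather than int:

```cpp
// Generated exclusive-interface declaration (reconstructed from the string above).
static std::vector<std::vector<ir::OpResult>> Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients);
```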
......@@ -72,7 +72,7 @@ ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) {
ir::OpResult mean_grad(ir::OpResult x,
ir::OpResult out_grad,
std::vector<int64_t> axis,
const std::vector<int64_t>& axis,
bool keepdim,
bool reduce_all) {
paddle::dialect::MeanGradOp mean_grad_op =
......
......@@ -44,7 +44,7 @@ ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);
ir::OpResult mean_grad(ir::OpResult x,
ir::OpResult out_grad,
std::vector<int64_t> axis = {},
const std::vector<int64_t>& axis = {},
bool keepdim = false,
bool reduce_all = false);
} // namespace dialect
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/variable.h"
#include "paddle/ir/core/dialect.h"
#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/macros.h"
#include "paddle/ir/core/parameter.h"
#include "paddle/ir/core/program.h"
......@@ -92,7 +93,7 @@ class APIBuilder {
ctx_->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
}
APIBuilder(const APIBuilder&) = delete;
DISABLE_COPY_AND_ASSIGN(APIBuilder);
ir::IrContext* ctx_;
std::shared_ptr<ir::Builder> builder_;
......
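
The hand-written deleted copy constructor is replaced by the DISABLE_COPY_AND_ASSIGN macro from the newly included paddle/ir/core/macros.h. A minimal sketch of what such a macro typically expands to, assuming it deletes both copy operations (the authoritative definition lives in that header):

```cpp
// Hedged sketch only; see paddle/ir/core/macros.h for the real definition.
// Assumed to delete both the copy constructor and copy assignment, which is
// strictly more than the single `= delete` line it replaces.
#define DISABLE_COPY_AND_ASSIGN(classname)      \
  classname(const classname&) = delete;         \
  classname& operator=(const classname&) = delete
```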
......@@ -17,16 +17,19 @@
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/desc_tensor.h"
#include "paddle/ir/core/op_base.h"
#include "paddle/phi/common/int_array.h"
// TODO(wanghao107)
// this file will be generated in pd_op.cc
namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;
std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
TanhOp op_obj = op->dyn_cast<TanhOp>();
Tensor out(
std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
......@@ -35,7 +38,7 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (!stop_gradients[0][0]) {
if (tensor_res[0][0].defined()) {
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[0][0].impl())
->getValue()
......@@ -47,7 +50,7 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
// TODO(wanghao107)
// we don't support inplace now,
// so use the non-inplace version instead currently.
......@@ -60,7 +63,7 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (!stop_gradients[0][0]) {
if (tensor_res[0][0].defined()) {
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[0][0].impl())
->getValue()
......@@ -72,24 +75,22 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
MeanOp op_obj = op->dyn_cast<MeanOp>();
Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
Tensor out_grad(
std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
std::vector<int64_t> axis =
op->attribute("axis")
IntArray axis = op->attribute("axis")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data()
.GetData();
.data();
bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
bool reduce_all = false;
std::vector<std::vector<Tensor>> tensor_res =
primitive::experimental::mean_vjp(
x, out_grad, axis, keepdim, reduce_all, stop_gradients);
std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
if (!stop_gradients[0][0]) {
if (tensor_res[0][0].defined()) {
res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
tensor_res[0][0].impl())
->getValue()
......
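
The op-level wrappers above now decide whether to fill a gradient slot from `tensor_res[i][j].defined()` instead of re-reading `stop_gradients`. A hedged sketch of the contract this assumes inside the VJP rules (names taken from the surrounding diff; the actual rule bodies are in vjp.cc further below):

```cpp
// Hedged sketch: when stop_gradients[0][0] is true, the rule leaves the slot
// as a default-constructed (undefined) paddle::Tensor, so callers can key off
// defined() rather than duplicating the stop_gradients bookkeeping.
std::vector<std::vector<paddle::Tensor>> vjp_res(1, std::vector<paddle::Tensor>(1));
if (!stop_gradients[0][0]) {
  vjp_res[0][0] =
      backend::experimental::tanh_grad<primitive::experimental::DescTensor>(
          out, grad_out);
}
return vjp_res;  // under this sketch, vjp_res[0][0].defined() == !stop_gradients[0][0]
```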
......@@ -23,12 +23,12 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
explicit Concept(std::vector<std::vector<ir::OpResult>> (*vjp)(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients))
const std::vector<std::vector<bool>>& stop_gradients))
: vjp_(vjp) {}
std::vector<std::vector<ir::OpResult>> (*vjp_)(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients);
const std::vector<std::vector<bool>>& stop_gradients);
};
template <class ConcreteOp>
......@@ -36,7 +36,7 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
static std::vector<std::vector<ir::OpResult>> Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
return ConcreteOp::Vjp(op, out_grads, stop_gradients);
}
......@@ -49,7 +49,7 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
std::vector<std::vector<ir::OpResult>> Vjp(
ir::Operation* op,
const std::vector<std::vector<ir::OpResult>>& out_grads,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
return impl_->vjp_(op, out_grads, stop_gradients);
}
......
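
A hedged usage sketch of VjpInterface after the type change (the dyn_cast spelling on ir::Operation is an assumption; the member Vjp wrapper comes from the declarations above, and this is what the pybind "call_vjp" binding further below does in spirit):

```cpp
// Hedged sketch: query the interface on an operation and invoke it generically.
auto vjp_iface = op->dyn_cast<paddle::dialect::VjpInterface>();
if (vjp_iface) {
  std::vector<std::vector<bool>> stop_gradients{{false}};
  std::vector<std::vector<ir::OpResult>> grads =
      vjp_iface.Vjp(op, out_grads, stop_gradients);
}
```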
if(NOT (NOT WITH_PYTHON AND ON_INFER))
if(WITH_PYTHON OR NOT ON_INFER)
cc_library(
primitive_backend_eager_experimental
SRCS eager_backend.cc
......
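
The CMake guard change is a pure De Morgan simplification: `NOT (NOT WITH_PYTHON AND ON_INFER)` is equivalent to `WITH_PYTHON OR NOT ON_INFER`. A tiny self-contained check of the equivalence (illustrative only, not part of the build):

```cpp
#include <cassert>

int main() {
  // Enumerate all four flag combinations and confirm both forms agree.
  for (bool with_python : {false, true}) {
    for (bool on_infer : {false, true}) {
      assert((!(!with_python && on_infer)) == (with_python || !on_infer));
    }
  }
  return 0;
}
```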
......@@ -42,7 +42,7 @@ Tensor tanh_grad<DescTensor>(const Tensor& out, const Tensor& grad_out) {
template <>
Tensor mean_grad<DescTensor>(const Tensor& x,
const Tensor& out_grad,
std::vector<int64_t> axis,
const IntArray& axis,
bool keepdim,
bool reduce_all) {
ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
......@@ -54,7 +54,7 @@ Tensor mean_grad<DescTensor>(const Tensor& x,
.dyn_cast<ir::OpResult>();
ir::OpResult op_res = paddle::dialect::mean_grad(
x_res, out_grad_res, axis, keepdim, reduce_all);
x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
}
......
......@@ -18,6 +18,7 @@
#include <vector>
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {
namespace primitive {
......@@ -25,6 +26,7 @@ namespace backend {
namespace experimental {
using Tensor = paddle::Tensor;
using IntArray = paddle::experimental::IntArray;
template <typename T>
Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
......@@ -32,7 +34,7 @@ Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
template <typename T>
Tensor mean_grad(const Tensor& x,
const Tensor& out_grad,
std::vector<int64_t> axis = {},
const IntArray& axis = {},
bool keepdim = false,
bool reduce_all = false);
} // namespace experimental
......
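
Passing axis as `const IntArray&` (paddle::experimental::IntArray) keeps the old call sites working, since IntArray can be built from a brace list or a std::vector<int64_t>, while the `= {}` default is preserved. A minimal sketch, assuming the usual IntArray constructors and GetData() accessor:

```cpp
#include <vector>
#include "paddle/phi/common/int_array.h"

using IntArray = paddle::experimental::IntArray;

void demo() {
  IntArray empty_axes;                                   // same shape as the `= {}` default
  IntArray from_list({0, 1});                            // from a brace-init list of axes
  IntArray from_vec(std::vector<int64_t>{0, 1});         // from an existing vector
  const std::vector<int64_t>& raw = from_vec.GetData();  // plain vector, as forwarded to
                                                         // paddle::dialect::mean_grad above
  (void)empty_axes; (void)from_list; (void)raw;
}
```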
......@@ -12,12 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <math.h>
#include <vector>
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/ir/dialect/pd_api.h"
#include "paddle/fluid/primitive/backend/static_backend.h"
#include "paddle/fluid/primitive/rule/vjp/vjp.h"
#include "paddle/fluid/primitive/type/desc_tensor.h"
#include "paddle/ir/core/operation.h"
// TODO(wanghao107):
......@@ -26,10 +23,11 @@
namespace paddle {
namespace primitive {
namespace experimental {
std::vector<std::vector<paddle::Tensor>> tanh_vjp(
const Tensor& out,
const Tensor& grad_out,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
// get tanh_grad res.
......@@ -71,10 +69,10 @@ std::vector<std::vector<paddle::Tensor>> tanh_vjp(
std::vector<std::vector<paddle::Tensor>> mean_vjp(
const Tensor& x,
const Tensor& out_grad,
std::vector<int64_t> axis,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<int>>& stop_gradients) {
const std::vector<std::vector<bool>>& stop_gradients) {
std::vector<std::vector<paddle::Tensor>> vjp_res(
1, std::vector<paddle::Tensor>(1));
// get mean_grad res.
......
......@@ -24,24 +24,27 @@
#include "paddle/fluid/primitive/primitive/primitive.h"
#include "paddle/ir/core/value.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/int_array.h"
namespace paddle {
namespace primitive {
namespace experimental {
using IntArray = paddle::experimental::IntArray;
// TODO(wanghao107):
// op's vjp will be auto generated.
std::vector<std::vector<paddle::Tensor>> tanh_vjp(
const Tensor& out,
const Tensor& grad_out,
const std::vector<std::vector<int>>& stop_gradients);
const std::vector<std::vector<bool>>& stop_gradients);
std::vector<std::vector<paddle::Tensor>> mean_vjp(
const Tensor& x,
const Tensor& out_grad,
std::vector<int64_t> axis,
const IntArray& axis,
bool keepdim,
bool reduce_all,
const std::vector<std::vector<int>>& stop_gradients);
const std::vector<std::vector<bool>>& stop_gradients);
namespace details {
// NOTE: this namespace will store
......
......@@ -43,14 +43,11 @@ class DescTensor : public phi::ExtendedTensor,
ir::Value getValue() const { return value_; }
const phi::Place& place() const override { return place_; }
bool initialized() const override { return value_.impl() != nullptr; }
private:
ir::Value value_;
mutable phi::DDim dims_;
phi::Place place_;
};
} // namespace experimental
......
......@@ -693,7 +693,7 @@ void BindVjp(pybind11::module *m) {
"call_vjp",
[](ir::Operation &fwd_op,
const std::vector<std::vector<ir::OpResult>> &out_grads,
const std::vector<std::vector<int>> &stop_gradients) {
const std::vector<std::vector<bool>> &stop_gradients) {
py::list res;
ir::IrContext *ctx = ir::IrContext::Instance();
ir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name());
......@@ -731,7 +731,7 @@ void BindVjp(pybind11::module *m) {
vjp_res[i].size()));
py::list sub_res;
for (size_t j = 0; j < vjp_res[i].size(); ++j) {
if (stop_gradients[i][j]) {
if (!vjp_res[i][j]) {
sub_res.append(nullptr);
} else {
sub_res.append(vjp_res[i][j]);
......
......@@ -377,9 +377,9 @@ def append_backward_ops(
input_grad_stopgradient_list = []
for input in op.operands_source():
if input in no_grad_set:
input_grad_stopgradient_list.append([1])
input_grad_stopgradient_list.append([True])
else:
input_grad_stopgradient_list.append([0])
input_grad_stopgradient_list.append([False])
before_ops_num = len(block.ops)
# prim should be a global flag, it will make create_grad_op choose a different func
......
......@@ -55,7 +55,7 @@ TEST(VJP, TanhBackwardTest) {
paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<std::vector<int>> stop_gradients{{0}};
std::vector<std::vector<bool>> stop_gradients{{false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh");
......@@ -109,7 +109,7 @@ TEST(VJP, Tanh_BackwardTest) {
paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<std::vector<int>> stop_gradients{{0}};
std::vector<std::vector<bool>> stop_gradients{{false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh_");
......@@ -163,7 +163,7 @@ TEST(VJP, MeanBackwardTest) {
paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
std::vector<int64_t>{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
std::vector<std::vector<int>> stop_gradients{{0}};
std::vector<std::vector<bool>> stop_gradients{{false}};
std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.mean");
......
......@@ -41,7 +41,7 @@ class TestTanhVjp(unittest.TestCase):
tanh_op = newir_program.block().ops[-2]
fill_constant_op = newir_program.block().ops[-1]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[0]]
stop_gradients = [[False]]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
self.assertEqual(
......@@ -72,7 +72,7 @@ class TestTanhVjp(unittest.TestCase):
tanh_op = newir_program.block().ops[-2]
fill_constant_op = newir_program.block().ops[-1]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[1]]
stop_gradients = [[True]]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
self.assertEqual(grad_outs[0][0], None)
......@@ -93,7 +93,7 @@ class TestMeanVjp(unittest.TestCase):
fill_constant_op = newir_program.block().ops[-1]
mean_op = newir_program.block().ops[-2]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[0]]
stop_gradients = [[False]]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
self.assertEqual(
......@@ -133,7 +133,7 @@ class TestMeanVjp(unittest.TestCase):
fill_constant_op = newir_program.block().ops[-1]
mean_op = newir_program.block().ops[-2]
out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[1]]
stop_gradients = [[True]]
with paddle.ir.core.program_guard(newir_program):
grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
self.assertEqual(grad_outs[0][0], None)
......