Unverified · Commit 533b62ff authored by Charles-hit, committed by GitHub

[PRIM][IR]fix comment for vjp (#56137)

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* add vjp interface

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* [prim][newir] add basic framework for primitive

* support desctensor in new ir

* support vjp in new ir

* support vjp in new ir

* polish vjp interface

* fix stop_gradients set

* fix vjp dispatch

* add comment

* add vjp test for new ir

* add test for tanh vjp

* add eager and static backend to wrap lower level api

* support call_vjp pybind

* polish code and add test for vjp

* remove useless code

* polish code

* remove useless code

* support mean vjp

* add test for mean vjp and support has_vjp function

* fix call_vjp

* polish code

* add primitive ops set for backend

* add vjp test for tanh_

* fix inference CI

* fix inference ci

* modify fluid cmake

* remove useless deps

* add cmake

* fix comment

* fix test

* polish code

* modify backward stop_gradients

* modify static_backend.cc

* remove useless code

---------
Co-authored-by: cxxly <chenxx_id@163.com>
Co-authored-by: zhangbo9674 <zhangbo54@baidu.com>
Parent 24771dd6
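In short, the diff below changes the stop-gradient mask in the VJP machinery from std::vector<std::vector<int>> to std::vector<std::vector<bool>> (interface, primitive rules, pybind binding, and tests), passes mean's axis attribute down as IntArray, and switches "did we produce a gradient" checks from re-reading stop_gradients to checking the produced tensors themselves. The per-op declaration emitted by the generator change that follows now reads, reflowed here for readability (the generator emits it as a single string):

static std::vector<std::vector<ir::OpResult>> Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients);
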
@@ -40,5 +40,5 @@ def gen_exclusive_interface_str(op_info):
         " static void InferMeta( phi::InferMetaContext *infer_meta );"
     )
     if op_info.op_phi_name[0] in vjp_interface_gen_op_list:
-        exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<int>>& stop_gradients);"
+        exclusive_interface_str += "\n static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
     return exclusive_interface_str

@@ -72,7 +72,7 @@ ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) {
 ir::OpResult mean_grad(ir::OpResult x,
                        ir::OpResult out_grad,
-                       std::vector<int64_t> axis,
+                       const std::vector<int64_t>& axis,
                        bool keepdim,
                        bool reduce_all) {
   paddle::dialect::MeanGradOp mean_grad_op =
...
@@ -44,7 +44,7 @@ ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);
 ir::OpResult mean_grad(ir::OpResult x,
                        ir::OpResult out_grad,
-                       std::vector<int64_t> axis = {},
+                       const std::vector<int64_t>& axis = {},
                        bool keepdim = false,
                        bool reduce_all = false);
 }  // namespace dialect
...
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/ir/core/dialect.h"
 #include "paddle/ir/core/enforce.h"
+#include "paddle/ir/core/macros.h"
 #include "paddle/ir/core/parameter.h"
 #include "paddle/ir/core/program.h"
@@ -92,7 +93,7 @@ class APIBuilder {
     ctx_->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
   }
-  APIBuilder(const APIBuilder&) = delete;
+  DISABLE_COPY_AND_ASSIGN(APIBuilder);
   ir::IrContext* ctx_;
   std::shared_ptr<ir::Builder> builder_;
...
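For context: the hand-deleted copy constructor above is replaced by the DISABLE_COPY_AND_ASSIGN macro from the newly included paddle/ir/core/macros.h. A minimal sketch of what such a macro typically provides; the exact definition in macros.h is not shown in this diff and may also delete move construction/assignment:

// Assumed shape of the macro; illustrative only, not copied from macros.h.
#define DISABLE_COPY_AND_ASSIGN(classname)      \
  classname(const classname&) = delete;         \
  classname& operator=(const classname&) = delete
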
@@ -17,16 +17,19 @@
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/primitive/type/desc_tensor.h"
 #include "paddle/ir/core/op_base.h"
+#include "paddle/phi/common/int_array.h"
 // TODO(wanghao107)
 // this file will be generated in pd_op.cc
 namespace paddle {
 namespace dialect {
+using IntArray = paddle::experimental::IntArray;
 std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   TanhOp op_obj = op->dyn_cast<TanhOp>();
   Tensor out(
       std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
@@ -35,7 +38,7 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
   std::vector<std::vector<Tensor>> tensor_res =
       primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
-  if (!stop_gradients[0][0]) {
+  if (tensor_res[0][0].defined()) {
     res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
@@ -47,7 +50,7 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
 std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   // TODO(wanghao107)
   // we don't support inplace now,
   // so use the non-inplace version instead currently.
@@ -60,7 +63,7 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
   std::vector<std::vector<Tensor>> tensor_res =
       primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
-  if (!stop_gradients[0][0]) {
+  if (tensor_res[0][0].defined()) {
     res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
@@ -72,24 +75,22 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
 std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   MeanOp op_obj = op->dyn_cast<MeanOp>();
   Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
   Tensor out_grad(
       std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
-  std::vector<int64_t> axis =
-      op->attribute("axis")
-          .dyn_cast<paddle::dialect::IntArrayAttribute>()
-          .data()
-          .GetData();
+  IntArray axis = op->attribute("axis")
+                      .dyn_cast<paddle::dialect::IntArrayAttribute>()
+                      .data();
   bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
   bool reduce_all = false;
   std::vector<std::vector<Tensor>> tensor_res =
       primitive::experimental::mean_vjp(
           x, out_grad, axis, keepdim, reduce_all, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
-  if (!stop_gradients[0][0]) {
+  if (tensor_res[0][0].defined()) {
     res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
...
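The three Vjp bodies above follow one pattern: cast the operation, wrap its values as DescTensor-backed Tensors, call the corresponding primitive rule with the bool stop_gradients mask, and convert back only the gradients the rule actually produced, checked via tensor_res[i][j].defined() rather than by re-reading stop_gradients. A minimal sketch of that pattern for a hypothetical unary op, not part of this patch; ReluOp and primitive::experimental::relu_vjp are illustrative names only:

// Hypothetical op following the pattern above; assumes the same in-tree
// headers and using declarations as the Vjp bodies in this file.
std::vector<std::vector<ir::OpResult>> ReluOp::Vjp(
    ir::Operation* op,
    const std::vector<std::vector<ir::OpResult>>& out_grads,
    const std::vector<std::vector<bool>>& stop_gradients) {
  ReluOp op_obj = op->dyn_cast<ReluOp>();
  // Wrap the forward output and the incoming gradient as DescTensor-backed
  // Tensors so the primitive rule can consume them.
  Tensor out(
      std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
  Tensor grad_out(
      std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
  // relu_vjp is a hypothetical rule with the same shape as tanh_vjp above.
  std::vector<std::vector<Tensor>> tensor_res =
      primitive::experimental::relu_vjp(out, grad_out, stop_gradients);
  std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
  // Only convert gradients the rule actually produced; stopped gradients
  // stay as null OpResults.
  if (tensor_res[0][0].defined()) {
    res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                    tensor_res[0][0].impl())
                    ->getValue()
                    .dyn_cast<ir::OpResult>();
  }
  return res;
}
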
@@ -23,12 +23,12 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
     explicit Concept(std::vector<std::vector<ir::OpResult>> (*vjp)(
         ir::Operation* op,
         const std::vector<std::vector<ir::OpResult>>& out_grads,
-        const std::vector<std::vector<int>>& stop_gradients))
+        const std::vector<std::vector<bool>>& stop_gradients))
         : vjp_(vjp) {}
     std::vector<std::vector<ir::OpResult>> (*vjp_)(
         ir::Operation* op,
         const std::vector<std::vector<ir::OpResult>>& out_grads,
-        const std::vector<std::vector<int>>& stop_gradients);
+        const std::vector<std::vector<bool>>& stop_gradients);
   };
   template <class ConcreteOp>
@@ -36,7 +36,7 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
     static std::vector<std::vector<ir::OpResult>> Vjp(
         ir::Operation* op,
         const std::vector<std::vector<ir::OpResult>>& out_grads,
-        const std::vector<std::vector<int>>& stop_gradients) {
+        const std::vector<std::vector<bool>>& stop_gradients) {
       return ConcreteOp::Vjp(op, out_grads, stop_gradients);
     }
@@ -49,7 +49,7 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
   std::vector<std::vector<ir::OpResult>> Vjp(
       ir::Operation* op,
       const std::vector<std::vector<ir::OpResult>>& out_grads,
-      const std::vector<std::vector<int>>& stop_gradients) {
+      const std::vector<std::vector<bool>>& stop_gradients) {
     return impl_->vjp_(op, out_grads, stop_gradients);
   }
...
-if(NOT (NOT WITH_PYTHON AND ON_INFER))
+if(WITH_PYTHON OR NOT ON_INFER)
   cc_library(
     primitive_backend_eager_experimental
     SRCS eager_backend.cc
...
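The rewritten CMake guard is not a behavior change: by De Morgan's law the two conditions are equivalent, the new form is simply easier to read.

NOT (NOT WITH_PYTHON AND ON_INFER)  ==  WITH_PYTHON OR NOT ON_INFER
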
@@ -42,7 +42,7 @@ Tensor tanh_grad<DescTensor>(const Tensor& out, const Tensor& grad_out) {
 template <>
 Tensor mean_grad<DescTensor>(const Tensor& x,
                              const Tensor& out_grad,
-                             std::vector<int64_t> axis,
+                             const IntArray& axis,
                              bool keepdim,
                              bool reduce_all) {
   ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
@@ -54,7 +54,7 @@ Tensor mean_grad<DescTensor>(const Tensor& x,
                                 .dyn_cast<ir::OpResult>();
   ir::OpResult op_res = paddle::dialect::mean_grad(
-      x_res, out_grad_res, axis, keepdim, reduce_all);
+      x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
   return Tensor(std::make_shared<primitive::experimental::DescTensor>(op_res));
 }
...
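The static backend now keeps the axes as paddle::experimental::IntArray all the way down and only materializes a concrete std::vector<int64_t> at the generated dialect API via GetData(). A tiny standalone sketch of that conversion; names are illustrative and not from the patch:

#include <vector>
#include "paddle/phi/common/int_array.h"

using IntArray = paddle::experimental::IntArray;

void axis_conversion_example() {
  // Axes arrive as an IntArray (e.g. read from an IntArrayAttribute).
  IntArray axis(std::vector<int64_t>{0, 1});
  // GetData() yields the std::vector<int64_t> that dialect::mean_grad expects.
  const std::vector<int64_t>& concrete_axis = axis.GetData();
  (void)concrete_axis;
}
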
@@ -18,6 +18,7 @@
 #include <vector>
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/int_array.h"
 namespace paddle {
 namespace primitive {
@@ -25,6 +26,7 @@ namespace backend {
 namespace experimental {
 using Tensor = paddle::Tensor;
+using IntArray = paddle::experimental::IntArray;
 template <typename T>
 Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
@@ -32,7 +34,7 @@ Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
 template <typename T>
 Tensor mean_grad(const Tensor& x,
                  const Tensor& out_grad,
-                 std::vector<int64_t> axis = {},
+                 const IntArray& axis = {},
                  bool keepdim = false,
                  bool reduce_all = false);
 }  // namespace experimental
...
@@ -12,12 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
-#include <math.h>
-#include <vector>
+#include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/ir/dialect/pd_api.h"
 #include "paddle/fluid/primitive/backend/static_backend.h"
-#include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/primitive/type/desc_tensor.h"
 #include "paddle/ir/core/operation.h"
 // TODO(wanghao107):
@@ -26,10 +23,11 @@
 namespace paddle {
 namespace primitive {
 namespace experimental {
 std::vector<std::vector<paddle::Tensor>> tanh_vjp(
     const Tensor& out,
     const Tensor& grad_out,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   std::vector<std::vector<paddle::Tensor>> vjp_res(
       1, std::vector<paddle::Tensor>(1));
   // get tanh_grad res.
@@ -71,10 +69,10 @@ std::vector<std::vector<paddle::Tensor>> tanh_vjp(
 std::vector<std::vector<paddle::Tensor>> mean_vjp(
     const Tensor& x,
     const Tensor& out_grad,
-    std::vector<int64_t> axis,
+    const IntArray& axis,
     bool keepdim,
     bool reduce_all,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   std::vector<std::vector<paddle::Tensor>> vjp_res(
       1, std::vector<paddle::Tensor>(1));
   // get mean_grad res.
...
@@ -24,24 +24,27 @@
 #include "paddle/fluid/primitive/primitive/primitive.h"
 #include "paddle/ir/core/value.h"
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/int_array.h"
 namespace paddle {
 namespace primitive {
 namespace experimental {
+using IntArray = paddle::experimental::IntArray;
 // TODO(wanghao107):
 // op's vjp will be auto generated.
 std::vector<std::vector<paddle::Tensor>> tanh_vjp(
     const Tensor& out,
     const Tensor& grad_out,
-    const std::vector<std::vector<int>>& stop_gradients);
+    const std::vector<std::vector<bool>>& stop_gradients);
 std::vector<std::vector<paddle::Tensor>> mean_vjp(
     const Tensor& x,
     const Tensor& out_grad,
-    std::vector<int64_t> axis,
+    const IntArray& axis,
     bool keepdim,
     bool reduce_all,
-    const std::vector<std::vector<int>>& stop_gradients);
+    const std::vector<std::vector<bool>>& stop_gradients);
 namespace details {
 // NOTE: this namespace will store
...
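With the new signatures, a direct call to the primitive rule (normally reached through MeanOp::Vjp earlier in the diff) would look roughly as follows; x and out_grad are assumed to be DescTensor-backed paddle::Tensor values created under an active program, as in MeanOp::Vjp:

// Sketch only, not part of this patch.
std::vector<std::vector<bool>> stop_gradients{{false}};
std::vector<std::vector<paddle::Tensor>> grads =
    paddle::primitive::experimental::mean_vjp(
        x,
        out_grad,
        /*axis=*/paddle::experimental::IntArray(std::vector<int64_t>{0}),
        /*keepdim=*/false,
        /*reduce_all=*/false,
        stop_gradients);
// grads[0][0] is defined only when the corresponding stop_gradients entry is
// false; callers check Tensor::defined() before converting it back to an
// ir::OpResult.
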
@@ -43,14 +43,11 @@ class DescTensor : public phi::ExtendedTensor,
   ir::Value getValue() const { return value_; }
-  const phi::Place& place() const override { return place_; }
   bool initialized() const override { return value_.impl() != nullptr; }
  private:
   ir::Value value_;
   mutable phi::DDim dims_;
-  phi::Place place_;
 };
 }  // namespace experimental
...
@@ -693,7 +693,7 @@ void BindVjp(pybind11::module *m) {
       "call_vjp",
       [](ir::Operation &fwd_op,
          const std::vector<std::vector<ir::OpResult>> &out_grads,
-         const std::vector<std::vector<int>> &stop_gradients) {
+         const std::vector<std::vector<bool>> &stop_gradients) {
         py::list res;
         ir::IrContext *ctx = ir::IrContext::Instance();
         ir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name());
@@ -731,7 +731,7 @@ void BindVjp(pybind11::module *m) {
                 vjp_res[i].size()));
         py::list sub_res;
         for (size_t j = 0; j < vjp_res[i].size(); ++j) {
-          if (stop_gradients[i][j]) {
+          if (!vjp_res[i][j]) {
             sub_res.append(nullptr);
           } else {
             sub_res.append(vjp_res[i][j]);
...
@@ -377,9 +377,9 @@ def append_backward_ops(
         input_grad_stopgradient_list = []
         for input in op.operands_source():
             if input in no_grad_set:
-                input_grad_stopgradient_list.append([1])
+                input_grad_stopgradient_list.append([True])
             else:
-                input_grad_stopgradient_list.append([0])
+                input_grad_stopgradient_list.append([False])
         before_ops_num = len(block.ops)
         # prim should be a globel flag, it will make create_grad_op choose diffrient func
...
@@ -55,7 +55,7 @@ TEST(VJP, TanhBackwardTest) {
   paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
-  std::vector<std::vector<int>> stop_gradients{{0}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
   std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh");
@@ -109,7 +109,7 @@ TEST(VJP, Tanh_BackwardTest) {
   paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
-  std::vector<std::vector<int>> stop_gradients{{0}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
   std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh_");
@@ -163,7 +163,7 @@ TEST(VJP, MeanBackwardTest) {
   paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
-  std::vector<std::vector<int>> stop_gradients{{0}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
   std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.mean");
...
@@ -41,7 +41,7 @@ class TestTanhVjp(unittest.TestCase):
         tanh_op = newir_program.block().ops[-2]
         fill_constant_op = newir_program.block().ops[-1]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[0]]
+        stop_gradients = [[False]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
         self.assertEqual(
@@ -72,7 +72,7 @@ class TestTanhVjp(unittest.TestCase):
         tanh_op = newir_program.block().ops[-2]
         fill_constant_op = newir_program.block().ops[-1]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[1]]
+        stop_gradients = [[True]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
         self.assertEqual(grad_outs[0][0], None)
@@ -93,7 +93,7 @@ class TestMeanVjp(unittest.TestCase):
         fill_constant_op = newir_program.block().ops[-1]
         mean_op = newir_program.block().ops[-2]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[0]]
+        stop_gradients = [[False]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
         self.assertEqual(
@@ -133,7 +133,7 @@ class TestMeanVjp(unittest.TestCase):
         fill_constant_op = newir_program.block().ops[-1]
         mean_op = newir_program.block().ops[-2]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[1]]
+        stop_gradients = [[True]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
         self.assertEqual(grad_outs[0][0], None)
...