diff --git a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
index fb22aa2e9b25b95fb68001d9f6f2f5ebd70d4000..4833111c9d2ab3cb36254ca802fce4c501abfcb0 100644
--- a/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
+++ b/paddle/fluid/ir/dialect/op_generator/op_interface_gen.py
@@ -40,5 +40,5 @@ def gen_exclusive_interface_str(op_info):
         "  static void InferMeta( phi::InferMetaContext *infer_meta );"
     )
     if op_info.op_phi_name[0] in vjp_interface_gen_op_list:
-        exclusive_interface_str += "\n  static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<int>>& stop_gradients);"
+        exclusive_interface_str += "\n  static std::vector<std::vector<ir::OpResult>> Vjp(ir::Operation* op, const std::vector<std::vector<ir::OpResult>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
     return exclusive_interface_str
diff --git a/paddle/fluid/ir/dialect/pd_api.cc b/paddle/fluid/ir/dialect/pd_api.cc
index df88dd9cc734839f7b81c677935ee7bf43624bbd..f65b1e25f9c462e77ccad557aa7a4d1d71c6be0e 100644
--- a/paddle/fluid/ir/dialect/pd_api.cc
+++ b/paddle/fluid/ir/dialect/pd_api.cc
@@ -72,7 +72,7 @@ ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out) {
 
 ir::OpResult mean_grad(ir::OpResult x,
                        ir::OpResult out_grad,
-                       std::vector<int64_t> axis,
+                       const std::vector<int64_t>& axis,
                        bool keepdim,
                        bool reduce_all) {
   paddle::dialect::MeanGradOp mean_grad_op =
diff --git a/paddle/fluid/ir/dialect/pd_api.h b/paddle/fluid/ir/dialect/pd_api.h
index a44c8bb83a76a7bbd3604ad5da8f21ce32eada58..5d3b2376314e13aefc0dba073b114e38ef47ebba 100644
--- a/paddle/fluid/ir/dialect/pd_api.h
+++ b/paddle/fluid/ir/dialect/pd_api.h
@@ -44,7 +44,7 @@ ir::OpResult tanh_grad(ir::OpResult out, ir::OpResult grad_out);
 
 ir::OpResult mean_grad(ir::OpResult x,
                        ir::OpResult out_grad,
-                       std::vector<int64_t> axis = {},
+                       const std::vector<int64_t>& axis = {},
                        bool keepdim = false,
                        bool reduce_all = false);
 }  // namespace dialect
diff --git a/paddle/fluid/ir/dialect/pd_dialect.h b/paddle/fluid/ir/dialect/pd_dialect.h
index 1e43a40c55f6b5b0d32e43581bfdcaca5fb38606..4fa14d394248a0540e5f340bdb675d963d0f24f9 100644
--- a/paddle/fluid/ir/dialect/pd_dialect.h
+++ b/paddle/fluid/ir/dialect/pd_dialect.h
@@ -17,6 +17,7 @@
 #include "paddle/fluid/framework/variable.h"
 #include "paddle/ir/core/dialect.h"
 #include "paddle/ir/core/enforce.h"
+#include "paddle/ir/core/macros.h"
 #include "paddle/ir/core/parameter.h"
 #include "paddle/ir/core/program.h"
 
@@ -92,7 +93,7 @@ class APIBuilder {
     ctx_->GetOrRegisterDialect<paddle::dialect::PaddleDialect>();
   }
 
-  APIBuilder(const APIBuilder&) = delete;
+  DISABLE_COPY_AND_ASSIGN(APIBuilder);
 
   ir::IrContext* ctx_;
   std::shared_ptr<ir::Builder> builder_;
diff --git a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
index 42bb1556aa21109c70c10328c07c2e107a8d05e3..be43ddd60491c462553fd47983fa2c3de9d8b05b 100644
--- a/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
+++ b/paddle/fluid/ir/dialect/pd_op_vjp_manual.cc
@@ -17,16 +17,19 @@
 #include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/primitive/type/desc_tensor.h"
 #include "paddle/ir/core/op_base.h"
+#include "paddle/phi/common/int_array.h"
 
 // TODO(wanghao107)
 // this file will be generated in pd_op.cc
 
 namespace paddle {
 namespace dialect {
+using IntArray = paddle::experimental::IntArray;
+
 std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   TanhOp op_obj = op->dyn_cast<TanhOp>();
   Tensor out(
       std::make_shared<primitive::experimental::DescTensor>(op_obj.out()));
@@ -35,7 +38,7 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
   std::vector<std::vector<paddle::Tensor>> tensor_res =
       primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
-  if (!stop_gradients[0][0]) {
+  if (tensor_res[0][0].defined()) {
     res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
@@ -47,7 +50,7 @@ std::vector<std::vector<ir::OpResult>> TanhOp::Vjp(
 std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   // TODO(wanghao107)
   // we don't support inplace now,
   // so use the non-inplace version instead currently.
@@ -60,7 +63,7 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
   std::vector<std::vector<paddle::Tensor>> tensor_res =
       primitive::experimental::tanh_vjp(out, grad_out, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
-  if (!stop_gradients[0][0]) {
+  if (tensor_res[0][0].defined()) {
     res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
@@ -72,24 +75,22 @@ std::vector<std::vector<ir::OpResult>> Tanh_Op::Vjp(
 std::vector<std::vector<ir::OpResult>> MeanOp::Vjp(
     ir::Operation* op,
     const std::vector<std::vector<ir::OpResult>>& out_grads,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   MeanOp op_obj = op->dyn_cast<MeanOp>();
   Tensor x(std::make_shared<primitive::experimental::DescTensor>(op_obj.x()));
   Tensor out_grad(
       std::make_shared<primitive::experimental::DescTensor>(out_grads[0][0]));
-  std::vector<int64_t> axis =
-      op->attribute("axis")
-          .dyn_cast<paddle::dialect::IntArrayAttribute>()
-          .data()
-          .GetData();
+  IntArray axis = op->attribute("axis")
+                      .dyn_cast<paddle::dialect::IntArrayAttribute>()
+                      .data();
   bool keepdim = op->attribute("keepdim").dyn_cast<ir::BoolAttribute>().data();
   bool reduce_all = false;
   std::vector<std::vector<paddle::Tensor>> tensor_res =
       primitive::experimental::mean_vjp(
           x, out_grad, axis, keepdim, reduce_all, stop_gradients);
   std::vector<std::vector<ir::OpResult>> res(1, std::vector<ir::OpResult>(1));
-  if (!stop_gradients[0][0]) {
+  if (tensor_res[0][0].defined()) {
     res[0][0] = std::static_pointer_cast<primitive::experimental::DescTensor>(
                     tensor_res[0][0].impl())
                     ->getValue()
diff --git a/paddle/fluid/ir/interface/vjp.h b/paddle/fluid/ir/interface/vjp.h
index 07e64da142f7356e6ef30099174eb0f6d311b2dd..a373cd0bacca4e23c8ca9d21b2d826a43c66c905 100644
--- a/paddle/fluid/ir/interface/vjp.h
+++ b/paddle/fluid/ir/interface/vjp.h
@@ -23,12 +23,12 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
     explicit Concept(std::vector<std::vector<ir::OpResult>> (*vjp)(
         ir::Operation* op,
         const std::vector<std::vector<ir::OpResult>>& out_grads,
-        const std::vector<std::vector<int>>& stop_gradients))
+        const std::vector<std::vector<bool>>& stop_gradients))
         : vjp_(vjp) {}
     std::vector<std::vector<ir::OpResult>> (*vjp_)(
         ir::Operation* op,
         const std::vector<std::vector<ir::OpResult>>& out_grads,
-        const std::vector<std::vector<int>>& stop_gradients);
+        const std::vector<std::vector<bool>>& stop_gradients);
   };
 
   template <class ConcreteOp>
@@ -36,7 +36,7 @@ class VjpInterface : public ir::OpInterfaceBase<VjpInterface> {
     static std::vector<std::vector<ir::OpResult>> Vjp(
         ir::Operation* op,
         const std::vector<std::vector<ir::OpResult>>& out_grads,
-        const std::vector<std::vector<int>>& stop_gradients) {
+        const std::vector<std::vector<bool>>& stop_gradients) {
       return ConcreteOp::Vjp(op, out_grads, stop_gradients);
     }
 
@@ -49,7 +49,7 @@
   std::vector<std::vector<ir::OpResult>> Vjp(
       ir::Operation* op,
       const std::vector<std::vector<ir::OpResult>>& out_grads,
-      const std::vector<std::vector<int>>& stop_gradients) {
+      const std::vector<std::vector<bool>>& stop_gradients) {
     return impl_->vjp_(op, out_grads, stop_gradients);
   }
 
diff --git a/paddle/fluid/primitive/backend/CMakeLists.txt b/paddle/fluid/primitive/backend/CMakeLists.txt
index 75e59d0b881638da52b10afc1979a7341e4b293b..26855583b46f9c72640a7795284105f803b041bc 100644
--- a/paddle/fluid/primitive/backend/CMakeLists.txt
+++ b/paddle/fluid/primitive/backend/CMakeLists.txt
@@ -1,4 +1,4 @@
-if(NOT (NOT WITH_PYTHON AND ON_INFER))
+if(WITH_PYTHON OR NOT ON_INFER)
   cc_library(
     primitive_backend_eager_experimental
     SRCS eager_backend.cc
diff --git a/paddle/fluid/primitive/backend/static_backend.cc b/paddle/fluid/primitive/backend/static_backend.cc
index b0a515c0d75afe55b539819f562a47575b9bb29d..b041d3710c25d4e73e5bab3efdb3e071b106a2e4 100644
--- a/paddle/fluid/primitive/backend/static_backend.cc
+++ b/paddle/fluid/primitive/backend/static_backend.cc
@@ -42,7 +42,7 @@ Tensor tanh_grad<DescTensor>(const Tensor& out, const Tensor& grad_out) {
 template <>
 Tensor mean_grad<DescTensor>(const Tensor& x,
                              const Tensor& out_grad,
-                             std::vector<int64_t> axis,
+                             const IntArray& axis,
                              bool keepdim,
                              bool reduce_all) {
   ir::OpResult x_res = std::static_pointer_cast<DescTensor>(x.impl())
@@ -54,7 +54,7 @@ Tensor mean_grad<DescTensor>(const Tensor& x,
                            .dyn_cast<ir::OpResult>();
 
   ir::OpResult op_res = paddle::dialect::mean_grad(
-      x_res, out_grad_res, axis, keepdim, reduce_all);
+      x_res, out_grad_res, axis.GetData(), keepdim, reduce_all);
 
   return Tensor(std::make_shared<DescTensor>(op_res));
 }
diff --git a/paddle/fluid/primitive/backend/static_backend.h b/paddle/fluid/primitive/backend/static_backend.h
index bd1fb737b8658ab54407acd2cf31ffb7e47441cd..09835bb75967418f2bdac7d35f59f08f182776dc 100644
--- a/paddle/fluid/primitive/backend/static_backend.h
+++ b/paddle/fluid/primitive/backend/static_backend.h
@@ -18,6 +18,7 @@
 #include <vector>
 
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/int_array.h"
 
 namespace paddle {
 namespace primitive {
@@ -25,6 +26,7 @@ namespace backend {
 namespace experimental {
 
 using Tensor = paddle::Tensor;
+using IntArray = paddle::experimental::IntArray;
 
 template <typename T>
 Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
@@ -32,7 +34,7 @@ Tensor tanh_grad(const Tensor& out, const Tensor& grad_out);
 template <typename T>
 Tensor mean_grad(const Tensor& x,
                  const Tensor& out_grad,
-                 std::vector<int64_t> axis = {},
+                 const IntArray& axis = {},
                  bool keepdim = false,
                  bool reduce_all = false);
 }  // namespace experimental
diff --git a/paddle/fluid/primitive/rule/vjp/vjp.cc b/paddle/fluid/primitive/rule/vjp/vjp.cc
index 28ffff5d9c7017ddfa493ff1bade18c5b669fbab..b5f0acf98c1d8f52649a2f68a8bb6e73edc8d1c7 100644
--- a/paddle/fluid/primitive/rule/vjp/vjp.cc
+++ b/paddle/fluid/primitive/rule/vjp/vjp.cc
@@ -12,12 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <memory>
-#include <vector>
-
+#include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/ir/dialect/pd_api.h"
 #include "paddle/fluid/primitive/backend/static_backend.h"
-#include "paddle/fluid/primitive/rule/vjp/vjp.h"
 #include "paddle/fluid/primitive/type/desc_tensor.h"
 #include "paddle/ir/core/operation.h"
 // TODO(wanghao107):
@@ -26,10 +23,11 @@
 namespace paddle {
 namespace primitive {
 namespace experimental {
+
 std::vector<std::vector<paddle::Tensor>> tanh_vjp(
     const Tensor& out,
     const Tensor& grad_out,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   std::vector<std::vector<paddle::Tensor>> vjp_res(
       1, std::vector<paddle::Tensor>(1));
   // get tanh_grad res.
@@ -71,10 +69,10 @@ std::vector<std::vector<paddle::Tensor>> tanh_vjp(
 std::vector<std::vector<paddle::Tensor>> mean_vjp(
     const Tensor& x,
     const Tensor& out_grad,
-    std::vector<int64_t> axis,
+    const IntArray& axis,
     bool keepdim,
     bool reduce_all,
-    const std::vector<std::vector<int>>& stop_gradients) {
+    const std::vector<std::vector<bool>>& stop_gradients) {
   std::vector<std::vector<paddle::Tensor>> vjp_res(
       1, std::vector<paddle::Tensor>(1));
   // get mean_grad res.
diff --git a/paddle/fluid/primitive/rule/vjp/vjp.h b/paddle/fluid/primitive/rule/vjp/vjp.h
index 9da7d57429bc37ee9897283426e03da8e60b694a..48bc2affa9db4a7b6b5cb32f7b904b38c65a6187 100644
--- a/paddle/fluid/primitive/rule/vjp/vjp.h
+++ b/paddle/fluid/primitive/rule/vjp/vjp.h
@@ -24,24 +24,27 @@
 #include "paddle/fluid/primitive/primitive/primitive.h"
 #include "paddle/ir/core/value.h"
 #include "paddle/phi/api/include/tensor.h"
+#include "paddle/phi/common/int_array.h"
 
 namespace paddle {
 namespace primitive {
 namespace experimental {
+
+using IntArray = paddle::experimental::IntArray;
 // TODO(wanghao107):
 // op's vjp will be auto generated.
 std::vector<std::vector<paddle::Tensor>> tanh_vjp(
     const Tensor& out,
     const Tensor& grad_out,
-    const std::vector<std::vector<int>>& stop_gradients);
+    const std::vector<std::vector<bool>>& stop_gradients);
 
 std::vector<std::vector<paddle::Tensor>> mean_vjp(
     const Tensor& x,
     const Tensor& out_grad,
-    std::vector<int64_t> axis,
+    const IntArray& axis,
     bool keepdim,
     bool reduce_all,
-    const std::vector<std::vector<int>>& stop_gradients);
+    const std::vector<std::vector<bool>>& stop_gradients);
 
 namespace details {
 // NOTE: this namespace will store
diff --git a/paddle/fluid/primitive/type/desc_tensor.h b/paddle/fluid/primitive/type/desc_tensor.h
index 60dc4e01377ebdf78174dbc3b8a99d423a64fa61..650b00e58ba7d1045bb55670007cb6dae1eb69a6 100644
--- a/paddle/fluid/primitive/type/desc_tensor.h
+++ b/paddle/fluid/primitive/type/desc_tensor.h
@@ -43,14 +43,11 @@ class DescTensor : public phi::ExtendedTensor,
 
   ir::Value getValue() const { return value_; }
 
-  const phi::Place& place() const override { return place_; }
-
   bool initialized() const override { return value_.impl() != nullptr; }
 
  private:
   ir::Value value_;
   mutable phi::DDim dims_;
-  phi::Place place_;
 };
 
 }  // namespace experimental
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index bc0df124de9607c30638e8b2497d2be7c83138f8..4a005bc6dd37273bf09bd226b62c8fe6563807e6 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -693,7 +693,7 @@ void BindVjp(pybind11::module *m) {
       "call_vjp",
       [](ir::Operation &fwd_op,
         const std::vector<std::vector<ir::OpResult>> &out_grads,
-        const std::vector<std::vector<int>> &stop_gradients) {
+        const std::vector<std::vector<bool>> &stop_gradients) {
        py::list res;
        ir::IrContext *ctx = ir::IrContext::Instance();
        ir::OpInfo fwd_op_info = ctx->GetRegisteredOpInfo(fwd_op.name());
@@ -731,7 +731,7 @@ void BindVjp(pybind11::module *m) {
                        vjp_res[i].size()));
          py::list sub_res;
          for (size_t j = 0; j < vjp_res[i].size(); ++j) {
-           if (stop_gradients[i][j]) {
+           if (!vjp_res[i][j]) {
             sub_res.append(nullptr);
           } else {
             sub_res.append(vjp_res[i][j]);
diff --git a/python/paddle/autograd/backward.py b/python/paddle/autograd/backward.py
index ba9d8a7a3f2e057a09f429909079783fc644b898..6ea1c491d4ad559cace67726617310b647818e0f 100644
--- a/python/paddle/autograd/backward.py
+++ b/python/paddle/autograd/backward.py
@@ -377,9 +377,9 @@ def append_backward_ops(
         input_grad_stopgradient_list = []
         for input in op.operands_source():
             if input in no_grad_set:
-                input_grad_stopgradient_list.append([1])
+                input_grad_stopgradient_list.append([True])
             else:
-                input_grad_stopgradient_list.append([0])
+                input_grad_stopgradient_list.append([False])
 
         before_ops_num = len(block.ops)
         # prim should be a globel flag, it will make create_grad_op choose diffrient func
diff --git a/test/cpp/prim/test_vjp.cc b/test/cpp/prim/test_vjp.cc
index 49cb6e29ab12c33101113ede50d5ab124b0e6503..9f7633c008176c7aa817bfd6426125690374030d 100644
--- a/test/cpp/prim/test_vjp.cc
+++ b/test/cpp/prim/test_vjp.cc
@@ -55,7 +55,7 @@ TEST(VJP, TanhBackwardTest) {
 
   paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
 
-  std::vector<std::vector<int>> stop_gradients{{0}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
   std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
 
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh");
@@ -109,7 +109,7 @@ TEST(VJP, Tanh_BackwardTest) {
 
   paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
 
-  std::vector<std::vector<int>> stop_gradients{{0}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
   std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
 
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh_");
@@ -163,7 +163,7 @@ TEST(VJP, MeanBackwardTest) {
 
   paddle::dialect::FullOp op3 = builder->Build<paddle::dialect::FullOp>(
       std::vector<int64_t>{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
 
-  std::vector<std::vector<int>> stop_gradients{{0}};
+  std::vector<std::vector<bool>> stop_gradients{{false}};
   std::vector<std::vector<ir::OpResult>> out_grads{{op3.out()}};
 
   ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.mean");
diff --git a/test/ir/new_ir/test_ir_vjp.py b/test/ir/new_ir/test_ir_vjp.py
index 12931b89cca2a2aa013bc3d5685269f5d44c14db..45da7162664e493a3f5413ffac33a10816719ca8 100644
--- a/test/ir/new_ir/test_ir_vjp.py
+++ b/test/ir/new_ir/test_ir_vjp.py
@@ -41,7 +41,7 @@ class TestTanhVjp(unittest.TestCase):
         tanh_op = newir_program.block().ops[-2]
         fill_constant_op = newir_program.block().ops[-1]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[0]]
+        stop_gradients = [[False]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
         self.assertEqual(
@@ -72,7 +72,7 @@ class TestTanhVjp(unittest.TestCase):
         tanh_op = newir_program.block().ops[-2]
         fill_constant_op = newir_program.block().ops[-1]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[1]]
+        stop_gradients = [[True]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)
         self.assertEqual(grad_outs[0][0], None)
@@ -93,7 +93,7 @@ class TestMeanVjp(unittest.TestCase):
         fill_constant_op = newir_program.block().ops[-1]
         mean_op = newir_program.block().ops[-2]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[0]]
+        stop_gradients = [[False]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
         self.assertEqual(
@@ -133,7 +133,7 @@ class TestMeanVjp(unittest.TestCase):
         fill_constant_op = newir_program.block().ops[-1]
         mean_op = newir_program.block().ops[-2]
         out_grads = [[fill_constant_op.result(0)]]
-        stop_gradients = [[1]]
+        stop_gradients = [[True]]
         with paddle.ir.core.program_guard(newir_program):
             grad_outs = call_vjp(mean_op, out_grads, stop_gradients)
         self.assertEqual(grad_outs[0][0], None)
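
Usage note (not part of the patch): after this change, call_vjp takes nested lists of booleans for stop_gradients instead of 0/1 integers, and slots whose gradient is not needed come back as None. A minimal Python sketch of the new calling convention, mirroring the updated test/ir/new_ir/test_ir_vjp.py; the surrounding setup (newir_program, tanh_op, fill_constant_op, and the call_vjp import) is assumed to be the same as in that test file.

# Sketch only: assumes newir_program, tanh_op, fill_constant_op, and call_vjp
# are constructed/imported exactly as in test/ir/new_ir/test_ir_vjp.py.
import paddle

out_grads = [[fill_constant_op.result(0)]]
stop_gradients = [[True]]  # booleans now, not 0/1 ints

with paddle.ir.core.program_guard(newir_program):
    grad_outs = call_vjp(tanh_op, out_grads, stop_gradients)

# stop_gradients[0][0] is True, so no grad result is produced for that slot.
assert grad_outs[0][0] is None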