// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tests for the VJP interface implementations registered for several "pd.*"
// ops. Every test follows the same recipe:
//   1. build a small forward program through the APIBuilder singleton,
//   2. build an extra `full` op whose output is passed as the output gradient,
//   3. look up the op's registered info by name and call `vjp_` on its
//      interface implementation,
//   4. lower the program with PdOpLowerToKernelPass and run it with an
//      InterpreterCore on a CPU place,
//   5. read variables named "<prefix>_inner_var_<N>" from the scope and assert
//      on their first elements.
//
// NOTE(review): template argument lists look like they were stripped from this
// copy of the file (bare `#include`, `Build(`, `std::vector>`, `Get()`,
// `data()`, `reinterpret_cast(`). Confirm against the upstream source before
// relying on the exact spelling of those expressions.

// NOTE(review): the header name of the next include is missing in this copy.
#include
#include "paddle/fluid/framework/new_executor/new_ir_interpreter.h"
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/api_builder.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_dialect.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_op.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/ir/pd_type.h"
#include "paddle/fluid/ir/dialect/paddle_dialect/utils/utils.h"
#include "paddle/fluid/ir/transforms/pd_op_to_kernel_pass.h"
#include "paddle/fluid/platform/init_phi.h"
#include "paddle/ir/core/block.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/builtin_op.h"
#include "paddle/ir/core/ir_context.h"
#include "paddle/ir/core/program.h"
#include "paddle/ir/core/utils.h"

DECLARE_FILE_SYMBOLS(kernel_dialect);

// CPU kernels declared for use by this translation unit.
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(tanh_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(mean_grad, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(add_grad, CPU, ALL_LAYOUT);

namespace paddle {
namespace framework {

// Runs the VJP registered for "pd.tanh" on a one-element input filled with
// 1.0, using an output gradient filled with 2.0.
TEST(VJP, TanhBackwardTest) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  ir::Program program((ctx));
  // Point the global APIBuilder at this test's program so that the builder
  // fetched below inserts ops into it.
  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
  std::shared_ptr builder =
      paddle::dialect::APIBuilder::Instance().GetBuilder();
  // Forward input: full op, shape {1}, value 1.0.
  paddle::dialect::FullOp op1 = builder->Build(
      std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::TanhOp op2 = builder->Build(op1.out());
  // Output gradient fed to the VJP: full op, shape {1}, value 2.0.
  paddle::dialect::FullOp op3 = builder->Build(
      std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  std::vector> stop_gradients{{false}};
  std::vector> out_grads{{op3.out()}};
  ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh");
  auto tanh_vjp_interface_impl = op2_info.GetInterfaceImpl();
  tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
  auto place = platform::CPUPlace();
  Scope scope;
  // NOTE(review): prog_desc is not referenced again in this test.
  ProgramDesc prog_desc;
  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
  // The printed address of the interpreter implementation is used as the
  // prefix of the variable names looked up below.
  std::stringstream os;
  os << reinterpret_cast(
      const_cast(test_core.Impl()));
  std::string prefix_str = os.str();
  // Presumably keeps these variables from being garbage-collected during Run
  // so they can be read afterwards -- confirm against SetSkipGcVars.
  test_core.SetSkipGcVars(
      {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"});
  test_core.Run({});
  // Look the variables up in the interpreter's local scope when it has one,
  // otherwise in the scope passed to the constructor.
  // NOTE(review): the FindVar results are dereferenced without a null check.
  auto out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_1")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_1")
                ->Get();
  auto grad_out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_3")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_3")
                ->Get();
  // 0.76159 ~= tanh(1); 0.83995 ~= 2 * (1 - tanh(1)^2).
  ASSERT_NEAR(out_tensor.data()[0], 0.76159, 1e-5);
  ASSERT_NEAR(grad_out_tensor.data()[0], 0.83995, 1e-5);
}

// Same scenario as TanhBackwardTest, but built with Tanh_Op and the VJP
// registered for "pd.tanh_". The results are read from "_inner_var_0" and
// "_inner_var_2" instead of 1 and 3 (presumably because the in-place op
// reuses its input variable -- confirm).
TEST(VJP, Tanh_BackwardTest) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  ir::Program program((ctx));
  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
  std::shared_ptr builder =
      paddle::dialect::APIBuilder::Instance().GetBuilder();
  // Forward input: full op, shape {1}, value 1.0.
  paddle::dialect::FullOp op1 = builder->Build(
      std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::Tanh_Op op2 = builder->Build(op1.out());
  // Output gradient fed to the VJP: full op, shape {1}, value 2.0.
  paddle::dialect::FullOp op3 = builder->Build(
      std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  std::vector> stop_gradients{{false}};
  std::vector> out_grads{{op3.out()}};
  ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.tanh_");
  auto tanh_vjp_interface_impl = op2_info.GetInterfaceImpl();
  tanh_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
  auto place = platform::CPUPlace();
  Scope scope;
  // NOTE(review): prog_desc is not referenced again in this test.
  ProgramDesc prog_desc;
  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
  // Printed interpreter address -> prefix of the variable names read below.
  std::stringstream os;
  os << reinterpret_cast(
      const_cast(test_core.Impl()));
  std::string prefix_str = os.str();
  test_core.SetSkipGcVars(
      {prefix_str + "_inner_var_0", prefix_str + "_inner_var_2"});
  test_core.Run({});
  // Prefer the interpreter's local scope; fall back to `scope` when it is
  // null.
  auto out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_0")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_0")
                ->Get();
  auto grad_out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_2")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_2")
                ->Get();
  // 0.76159 ~= tanh(1); 0.83995 ~= 2 * (1 - tanh(1)^2).
  ASSERT_NEAR(out_tensor.data()[0], 0.76159, 1e-5);
  ASSERT_NEAR(grad_out_tensor.data()[0], 0.83995, 1e-5);
}

// Runs the VJP registered for "pd.mean" on a {2, 2} input filled with 2.0,
// using an output gradient built with an empty shape vector and value 1.0.
TEST(VJP, MeanBackwardTest) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  ir::Program program((ctx));
  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
  std::shared_ptr builder =
      paddle::dialect::APIBuilder::Instance().GetBuilder();
  // Forward input: full op, shape {2, 2}, value 2.0.
  paddle::dialect::FullOp op1 = builder->Build(
      std::vector{2, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::MeanOp op2 = builder->Build(op1.out());
  // Output gradient: full op with an empty shape vector, value 1.0.
  paddle::dialect::FullOp op3 = builder->Build(
      std::vector{}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  std::vector> stop_gradients{{false}};
  std::vector> out_grads{{op3.out()}};
  ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.mean");
  auto mean_vjp_interface_impl = op2_info.GetInterfaceImpl();
  mean_vjp_interface_impl->vjp_(op2.operation(), out_grads, stop_gradients);
  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
  auto place = platform::CPUPlace();
  Scope scope;
  // NOTE(review): prog_desc is not referenced again in this test.
  ProgramDesc prog_desc;
  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
  // Printed interpreter address -> prefix of the variable names read below.
  std::stringstream os;
  os << reinterpret_cast(
      const_cast(test_core.Impl()));
  std::string prefix_str = os.str();
  test_core.SetSkipGcVars(
      {prefix_str + "_inner_var_1", prefix_str + "_inner_var_3"});
  test_core.Run({});
  // Prefer the interpreter's local scope; fall back to `scope` when it is
  // null.
  auto out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_1")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_1")
                ->Get();
  auto grad_out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_3")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_3")
                ->Get();
  // Mean of four 2.0 values is 2.0; each of the four gradient elements is
  // expected to be 1.0 / 4 = 0.25.
  ASSERT_EQ(out_tensor.data()[0], 2.0);
  ASSERT_EQ(grad_out_tensor.data()[0], 0.25);
  ASSERT_EQ(grad_out_tensor.data()[1], 0.25);
  ASSERT_EQ(grad_out_tensor.data()[2], 0.25);
  ASSERT_EQ(grad_out_tensor.data()[3], 0.25);
}

// Runs the VJP registered for "pd.concat" on two copies of a {1, 2} input
// filled with 2.0 that are combined and concatenated, using a {2, 2} output
// gradient filled with 1.0.
TEST(VJP, ConcatBackwardTest) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  ir::Program program((ctx));
  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
  std::shared_ptr builder =
      paddle::dialect::APIBuilder::Instance().GetBuilder();
  // Forward input: full op, shape {1, 2}, value 2.0.
  paddle::dialect::FullOp op1 = builder->Build(
      std::vector{1, 2}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  // The same output is passed twice as the inputs to combine.
  std::vector combine_input{{op1.out(), op1.out()}};
  ir::CombineOp op2 = builder->Build(combine_input);
  // Second argument 0 is presumably the concat axis -- confirm against
  // ConcatOp's Build signature.
  paddle::dialect::ConcatOp op3 = builder->Build(op2.out(), 0);
  // Output gradient: full op, shape {2, 2}, value 1.0.
  paddle::dialect::FullOp op4 = builder->Build(
      std::vector{2, 2}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  // A single inner list with two flags, unlike the add tests below, which use
  // two inner lists with one flag each.
  std::vector> stop_gradients{{false, false}};
  std::vector> out_grads{{op4.out()}};
  // NOTE(review): despite the name `op2_info`, this holds the info registered
  // for "pd.concat"; the concat op itself is `op3`.
  ir::OpInfo op2_info = ctx->GetRegisteredOpInfo("pd.concat");
  auto concat_vjp_interface_impl = op2_info.GetInterfaceImpl();
  concat_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);
  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
  auto place = platform::CPUPlace();
  Scope scope;
  // NOTE(review): prog_desc is not referenced again in this test.
  ProgramDesc prog_desc;
  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
  // Printed interpreter address -> prefix of the variable names read below.
  std::stringstream os;
  os << reinterpret_cast(
      const_cast(test_core.Impl()));
  std::string prefix_str = os.str();
  test_core.SetSkipGcVars({prefix_str + "_inner_var_3",
                           prefix_str + "_inner_var_7",
                           prefix_str + "_inner_var_8"});
  test_core.Run({});
  // Prefer the interpreter's local scope; fall back to `scope` when it is
  // null.
  auto out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_3")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_3")
                ->Get();
  auto grad_out_tensor_0 =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_7")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_7")
                ->Get();
  auto grad_out_tensor_1 =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_8")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_8")
                ->Get();
  // Checks the first element of the forward result and the first two elements
  // of each of the two gradient tensors.
  ASSERT_EQ(out_tensor.data()[0], 2.0);
  ASSERT_EQ(grad_out_tensor_0.data()[0], 1.0);
  ASSERT_EQ(grad_out_tensor_0.data()[1], 1.0);
  ASSERT_EQ(grad_out_tensor_1.data()[0], 1.0);
  ASSERT_EQ(grad_out_tensor_1.data()[1], 1.0);
}

// Runs the VJP registered for "pd.add" on two one-element inputs filled with
// 2.0, using an output gradient filled with 1.0.
TEST(VJP, AddBackwardTest) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  ir::Program program((ctx));
  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
  std::shared_ptr builder =
      paddle::dialect::APIBuilder::Instance().GetBuilder();
  // Two forward inputs: full ops, shape {1}, value 2.0 each.
  paddle::dialect::FullOp op1 = builder->Build(
      std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::FullOp op2 = builder->Build(
      std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::AddOp op3 = builder->Build(op1.out(), op2.out());
  // Output gradient: full op, shape {1}, value 1.0.
  paddle::dialect::FullOp op4 = builder->Build(
      std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  // Two inner lists with one flag each.
  std::vector> stop_gradients{{false}, {false}};
  std::vector> out_grads{{op4.out()}};
  ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add");
  auto add_vjp_interface_impl = op3_info.GetInterfaceImpl();
  add_vjp_interface_impl->vjp_(op3.operation(), out_grads, stop_gradients);
  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
  auto place = platform::CPUPlace();
  Scope scope;
  // NOTE(review): prog_desc is not referenced again in this test.
  ProgramDesc prog_desc;
  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
  // Printed interpreter address -> prefix of the variable names read below.
  std::stringstream os;
  os << reinterpret_cast(
      const_cast(test_core.Impl()));
  std::string prefix_str = os.str();
  test_core.SetSkipGcVars({prefix_str + "_inner_var_2",
                           prefix_str + "_inner_var_4",
                           prefix_str + "_inner_var_5"});
  test_core.Run({});
  // Prefer the interpreter's local scope; fall back to `scope` when it is
  // null.
  auto out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_2")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_2")
                ->Get();
  auto dx = test_core.local_scope() == nullptr
                ? scope.FindVar(prefix_str + "_inner_var_4")->Get()
                : test_core.local_scope()
                      ->FindVar(prefix_str + "_inner_var_4")
                      ->Get();
  auto dy = test_core.local_scope() == nullptr
                ? scope.FindVar(prefix_str + "_inner_var_5")->Get()
                : test_core.local_scope()
                      ->FindVar(prefix_str + "_inner_var_5")
                      ->Get();
  // 2.0 + 2.0 = 4.0; both gradients are expected to equal the 1.0 output
  // gradient.
  ASSERT_EQ(out_tensor.data()[0], 4.0);
  ASSERT_EQ(dx.data()[0], 1.0);
  ASSERT_EQ(dy.data()[0], 1.0);
}

// Same scenario as AddBackwardTest, but built with Add_Op and the VJP
// registered for "pd.add_". The results are read from "_inner_var_0", "_3"
// and "_4" instead of 2, 4 and 5 (presumably because the in-place op reuses
// an input variable -- confirm).
TEST(VJP, Add_BackwardTest) {
  ir::IrContext* ctx = ir::IrContext::Instance();
  ctx->GetOrRegisterDialect();
  ir::Program program((ctx));
  paddle::dialect::APIBuilder::Instance().SetProgram(&program);
  std::shared_ptr builder =
      paddle::dialect::APIBuilder::Instance().GetBuilder();
  // Two forward inputs: full ops, shape {1}, value 2.0 each.
  paddle::dialect::FullOp op1 = builder->Build(
      std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::FullOp op2 = builder->Build(
      std::vector{1}, 2.0, phi::DataType::FLOAT32, phi::CPUPlace());
  paddle::dialect::Add_Op op3 = builder->Build(op1.out(), op2.out());
  // Output gradient: full op, shape {1}, value 1.0.
  paddle::dialect::FullOp op4 = builder->Build(
      std::vector{1}, 1.0, phi::DataType::FLOAT32, phi::CPUPlace());
  std::vector> stop_gradients{{false}, {false}};
  std::vector> out_grads{{op4.out()}};
  ir::OpInfo op3_info = ctx->GetRegisteredOpInfo("pd.add_");
  auto add_inplace_vjp_interface_impl = op3_info.GetInterfaceImpl();
  add_inplace_vjp_interface_impl->vjp_(
      op3.operation(), out_grads, stop_gradients);
  auto kernel_program = paddle::dialect::PdOpLowerToKernelPass(&program);
  auto place = platform::CPUPlace();
  Scope scope;
  // NOTE(review): prog_desc is not referenced again in this test.
  ProgramDesc prog_desc;
  InterpreterCore test_core(place, {}, std::move(kernel_program), &scope);
  // Printed interpreter address -> prefix of the variable names read below.
  std::stringstream os;
  os << reinterpret_cast(
      const_cast(test_core.Impl()));
  std::string prefix_str = os.str();
  test_core.SetSkipGcVars({prefix_str + "_inner_var_0",
                           prefix_str + "_inner_var_3",
                           prefix_str + "_inner_var_4"});
  test_core.Run({});
  // Prefer the interpreter's local scope; fall back to `scope` when it is
  // null.
  auto out_tensor =
      test_core.local_scope() == nullptr
          ? scope.FindVar(prefix_str + "_inner_var_0")->Get()
          : test_core.local_scope()
                ->FindVar(prefix_str + "_inner_var_0")
                ->Get();
  auto dx = test_core.local_scope() == nullptr
                ? scope.FindVar(prefix_str + "_inner_var_3")->Get()
                : test_core.local_scope()
                      ->FindVar(prefix_str + "_inner_var_3")
                      ->Get();
  auto dy = test_core.local_scope() == nullptr
                ? scope.FindVar(prefix_str + "_inner_var_4")->Get()
                : test_core.local_scope()
                      ->FindVar(prefix_str + "_inner_var_4")
                      ->Get();
  // 2.0 + 2.0 = 4.0; both gradients are expected to equal the 1.0 output
  // gradient.
  ASSERT_EQ(out_tensor.data()[0], 4.0);
  ASSERT_EQ(dx.data()[0], 1.0);
  ASSERT_EQ(dy.data()[0], 1.0);
}

}  // namespace framework
}  // namespace paddle