From 3e9601ba4943b36da375fdf50238474da760abab Mon Sep 17 00:00:00 2001 From: 0x45f <23097963+0x45f@users.noreply.github.com> Date: Wed, 9 Mar 2022 15:50:42 +0800 Subject: [PATCH] adapt run_program OP for eager (#40198) * adapt run_program OP for eager * fix program_id * refine code * fix test --- .../auto_code_generator/eager_generator.cc | 12 +- .../final_state_generator/eager_gen.py | 2 + .../eager/to_static/run_program_op_func.h | 82 +++ .../eager/to_static/run_program_op_node.h | 468 ++++++++++++++++++ .../fluid/pybind/custom_handwrite_op_funcs.h | 51 ++ .../pybind/eager_op_function_generator.cc | 24 +- paddle/fluid/pybind/eager_utils.cc | 60 +++ paddle/fluid/pybind/eager_utils.h | 7 + paddle/fluid/pybind/pybind.cc | 7 +- .../tests/unittests/test_eager_run_program.py | 120 +++++ 10 files changed, 823 insertions(+), 10 deletions(-) create mode 100644 paddle/fluid/eager/to_static/run_program_op_func.h create mode 100644 paddle/fluid/eager/to_static/run_program_op_node.h create mode 100644 paddle/fluid/pybind/custom_handwrite_op_funcs.h create mode 100644 python/paddle/fluid/tests/unittests/test_eager_run_program.py diff --git a/paddle/fluid/eager/auto_code_generator/eager_generator.cc b/paddle/fluid/eager/auto_code_generator/eager_generator.cc index 2fc846cccc..dc79a8a45a 100644 --- a/paddle/fluid/eager/auto_code_generator/eager_generator.cc +++ b/paddle/fluid/eager/auto_code_generator/eager_generator.cc @@ -47,6 +47,9 @@ std::unordered_map> static std::unordered_map operators_with_attrs = {}; +/* --- Black Ops list that's NO NEED to apply code generation --- */ +static std::unordered_set black_ops_list = {"run_program"}; + static std::string LegalizeVariableName(const std::string& var_name) { std::string ret = var_name; std::replace(ret.begin(), ret.end(), '-', '_'); // replace all '-' to '_' @@ -73,12 +76,6 @@ static bool IgnoreGradAttribute(const std::string& op_type, } static void PrepareAttrMapForOps() { - // Handle "run_program_op" - static framework::ProgramDesc fake_prog; - operators_with_attrs["run_program"] = {}; - operators_with_attrs["run_program"]["global_block"] = - fake_prog.MutableBlock(0); - // Handle "fused_elemwise_add_activation" std::vector functor_list = {"a", "b"}; operators_with_attrs["fused_elemwise_add_activation"] = {}; @@ -2349,6 +2346,9 @@ static void DygraphCodeGeneration(const std::string& output_dir) { if (!CheckOpProto(op_proto)) continue; const std::string& op_type = op_proto->type(); + if (black_ops_list.count(op_type)) { + continue; + } /* ----------------------------- */ /* ---- Collect Information ---- */ diff --git a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py index 81d0c9b7be..b594faa80a 100644 --- a/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py +++ b/paddle/fluid/eager/auto_code_generator/final_state_generator/eager_gen.py @@ -1000,6 +1000,7 @@ def GenerateNodeCCFile(filepath, node_definition_str): #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h" +#include "paddle/fluid/eager/to_static/run_program_op_node.h" """ file_contents += node_definition_str @@ -1042,6 +1043,7 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str): #include "paddle/phi/api/all.h" #include "paddle/fluid/eager/utils.h" #include "paddle/fluid/framework/op_registry.h" +#include 
"paddle/fluid/eager/to_static/run_program_op_func.h" """ file_contents += GenerateCoreOpInfoDeclaration() diff --git a/paddle/fluid/eager/to_static/run_program_op_func.h b/paddle/fluid/eager/to_static/run_program_op_func.h new file mode 100644 index 0000000000..6f8bccd64e --- /dev/null +++ b/paddle/fluid/eager/to_static/run_program_op_func.h @@ -0,0 +1,82 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/eager/autograd_meta.h" +#include "paddle/fluid/eager/eager_tensor.h" +#include "paddle/fluid/eager/to_static/run_program_op_node.h" +#include "paddle/fluid/eager/utils.h" + +inline void run_program_dygraph_function( + const std::vector& x, + const std::vector& params, + std::vector& out, // NOLINT + std::vector& step_scope, // NOLINT + std::vector& dout, // NOLINT + const paddle::framework::AttributeMap& attrs) { + VLOG(2) << "start run run_program"; + // Call forward function + RunProgramAPI(x, params, out, step_scope, dout, attrs); + VLOG(2) << "start run run_program grad"; + + // Prepare Autograd Meta + auto deref_out = details::DereferenceTensors(out); + std::vector p_autograd_x = + egr::EagerUtils::nullable_autograd_meta(x); + std::vector p_autograd_params = + egr::EagerUtils::nullable_autograd_meta(params); + std::vector p_autograd_outs = + egr::EagerUtils::nullable_autograd_meta(deref_out); + + bool trace_backward = egr::Controller::Instance().HasGrad(); + bool require_any_grad = egr::EagerUtils::ComputeRequireGrad( + trace_backward, &p_autograd_x, &p_autograd_params); + + if (require_any_grad) { + std::vector out_names; + for (auto& t : deref_out) { + out_names.emplace_back(t.name()); + } + + egr::EagerUtils::PassStopGradient(false, &p_autograd_outs); + // Create GradOpNode (1 means [out_grad], 2 means [x_grad, paramx_grad]) + auto grad_node = std::make_shared(1, 2); + + grad_node->SetFwdOutNames(out_names); + // Set Attributes + grad_node->SetAttrMap(attrs); + // Set TensorWrappers + grad_node->SetFwdX(x); + grad_node->SetFwdParams(params); + grad_node->SetStepScope(step_scope); + + // Set Grad out rank as same as fwd input and set stop gradient to bwd + grad_node->SetGradOutMeta(&p_autograd_x, /*slot id*/ 0); + grad_node->SetGradOutMeta(&p_autograd_params, /*slot id*/ 1); + + grad_node->SetGradInMeta(&p_autograd_outs, 0); + // Set Next Edges + grad_node->AddEdges(&p_autograd_x, /*slot id*/ 0); + grad_node->AddEdges(&p_autograd_params, /*slot id*/ 1); + + egr::EagerUtils::SetOutRankWithSlot(&p_autograd_outs, 0); + + // Set History for output set current Grad Node for + egr::EagerUtils::SetHistory(&p_autograd_outs, grad_node); + egr::EagerUtils::CheckAndRetainGrad(deref_out); + } +} diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h new file mode 100644 index 0000000000..ae5d86664a --- /dev/null +++ b/paddle/fluid/eager/to_static/run_program_op_node.h @@ -0,0 +1,468 @@ +// Copyright (c) 
2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/eager/api/utils/global_utils.h" +#include "paddle/fluid/eager/grad_node_info.h" +#include "paddle/fluid/eager/tensor_wrapper.h" + +#include "paddle/fluid/operators/run_program_op.h" +#include "paddle/fluid/platform/enforce.h" + +namespace details { +using Tensor = paddle::experimental::Tensor; + +static std::vector DereferenceTensors( + const std::vector &tensor_ptr) { + std::vector res; + for (auto *t : tensor_ptr) { + res.emplace_back(*t); + } + return res; +} + +static std::vector GetTensorsName(const std::vector &ins) { + std::vector in_names; + for (auto &in_t : ins) { + in_names.emplace_back(in_t.name()); + } + return in_names; +} + +static std::vector GetTensorsName( + const std::vector &ins) { + std::vector in_names; + for (auto *in_t : ins) { + in_names.emplace_back(in_t->name()); + } + return in_names; +} + +static void CheckInputVarStatus(const Tensor &tensor) { + PADDLE_ENFORCE_EQ( + tensor.defined() && phi::DenseTensor::classof(tensor.impl().get()), true, + paddle::platform::errors::InvalidArgument( + "The input tensor %s of " + "RunProgram(Grad)Op holds " + "wrong type. Expect type is DenseTensor.", + tensor.name())); + + PADDLE_ENFORCE_EQ(tensor.initialized(), true, + paddle::platform::errors::InvalidArgument( + "The tensor in input tensor %s of " + "RunProgram(Grad)Op " + "is not initialized.", + tensor.name())); +} + +static void CheckOutputVarStatus(const paddle::framework::Variable &src_var, + const Tensor &dst_tensor) { + auto name = dst_tensor.name(); + PADDLE_ENFORCE_EQ(dst_tensor.defined(), true, + paddle::platform::errors::InvalidArgument( + "dst_tensor shall be defined.")); + + if (phi::DenseTensor::classof(dst_tensor.impl().get())) { + auto &src_tensor = src_var.Get(); + PADDLE_ENFORCE_EQ(phi::DenseTensor::classof(&src_tensor), true, + paddle::platform::errors::InvalidArgument( + "The output tensor %s get from " + "RunProgram(Grad)Op's internal scope holds " + "wrong type. Expect type is DenseTensor", + name)); + PADDLE_ENFORCE_EQ(src_tensor.initialized(), true, + paddle::platform::errors::InvalidArgument( + "The tensor in output tensor %s get from " + "RunProgram(Grad)Op's internal " + "scope is not initialized.", + name)); + } else if (phi::SelectedRows::classof(dst_tensor.impl().get())) { + auto &src_tensor = src_var.Get(); + PADDLE_ENFORCE_EQ(phi::SelectedRows::classof(&src_tensor), true, + paddle::platform::errors::InvalidArgument( + "The output tensodfr %s get from " + "RunProgram(Grad)Op's internal scope holds " + "wrong type. 
Expect type is SelectedRows", + name)); + PADDLE_ENFORCE_EQ(src_tensor.initialized(), true, + paddle::platform::errors::InvalidArgument( + "The tensor in output tensor %s get from " + "RunProgram(Grad)Op's " + "internal scope is not initialized.", + name)); + + } else { + PADDLE_THROW(paddle::platform::errors::InvalidArgument( + "The RunProgram(Grad)Op only support output " + "variable of type LoDTensor or SelectedRows", + name)); + } +} + +static void ShareTensorsIntoScope(const std::vector &tensors, + paddle::framework::Scope *scope) { + for (size_t i = 0; i < tensors.size(); ++i) { + auto name = tensors[i].name(); + if (name == "Fake_var" || !tensors[i].is_initialized()) { + continue; + } + auto *var = scope->Var(name); + CheckInputVarStatus(tensors[i]); + // share tensor + auto tensor_base = tensors[i].impl(); + if (phi::DenseTensor::classof(tensor_base.get())) { + auto *dst_tensor = var->GetMutable(); + auto t = std::dynamic_pointer_cast(tensor_base); + *dst_tensor = *t; + } else if (phi::SelectedRows::classof(tensor_base.get())) { + auto *dst_tensor = var->GetMutable(); + auto t = std::dynamic_pointer_cast(tensor_base); + *dst_tensor = *t; + } + } +} + +static void ShareTensorsFromScope( + const std::vector &tensors, + const paddle::framework::BlockDesc &global_block, + paddle::framework::Scope *scope) { + for (size_t i = 0; i < tensors.size(); ++i) { + // NOTE: In case of setting out_tmp.stop_gradient = True in model code, all + // parameters before generating out_tmp have no @GRAD, it will raise error + // because we can't find them in scope. So we skip sharing these vars or + // var@GRAD if they don't appear in global block. + auto &name = tensors[i]->name(); + if (name == paddle::framework::kEmptyVarName || name == "Fake_var" || + !global_block.HasVar(name)) { + VLOG(2) << "find tensor name is " << name << ", skip it!"; + continue; + } + // NOTE: Here skip not found var is dangerous, if a bug is caused here, + // the result is grad calculation error, which will be very hidden! + auto *var = scope->FindVar(name); + PADDLE_ENFORCE_NOT_NULL(var, paddle::platform::errors::NotFound( + "The output tensor %s is not in " + "RunProgram(Grad)Op'" + "s internal scope.", + name)); + CheckOutputVarStatus(*var, *tensors[i]); + // share tensor + // TODO(dev): Determine Tensor type by scope.var + // auto tensor_base = tensors[i]->impl(); + // if (phi::DenseTensor::classof(tensor_base.get())) { + if (var->IsType()) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast(tensors[i]->impl().get())); + VLOG(2) << "share " << name << " from scope"; + *dst_tensor = src_tensor; + } else if (var->IsType()) { + // } else if (phi::SelectedRows::classof(tensor_base.get())) { + auto &src_tensor = var->Get(); + auto *dst_tensor = const_cast( + dynamic_cast(tensors[i]->impl().get())); + *dst_tensor = src_tensor; + } + } +} + +} // namespace details + +inline void RunProgramAPI( + const std::vector &x, + const std::vector ¶ms, + std::vector &out, // NOLINT + std::vector &step_scope, // NOLINT + std::vector &dout, // NOLINT + const paddle::framework::AttributeMap &attrs) { + VLOG(2) << "RunProgramOpKernel Compute"; + auto start_op_index = BOOST_GET_CONST(int64_t, attrs.at("start_op_index")); + auto end_op_index = BOOST_GET_CONST(int64_t, attrs.at("end_op_index")); + auto is_test = BOOST_GET_CONST(bool, attrs.at("is_test")); + auto program_id = BOOST_GET_CONST(int64_t, attrs.at("program_id")); + + // NOTE(chenweihang): In order not to add new variable type, use vector + // here. 
Originally, here can use scope directly. + auto *out_scope_vec = &step_scope; + PADDLE_ENFORCE_EQ( + out_scope_vec->size(), 1, + paddle::platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should only hold one scope.")); + + // Step 2. prepare executor and init persistable variables + + // NOTE(Aurelius84): While training some models, forward can be called many + // times and then apply backpropagation all at once, such as Reinforcement + // Learning. Tensor data in multi-step training should be saved into single + // scope separately. Otherwise, the gradients can be miscalculated because + // always using the Tensor data of the last step in forward. + paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); + VLOG(2) << "The number of sub scopes before forward: " + << out_scope_vec->front()->kids().size(); + paddle::framework::Scope &scope = global_inner_scope->NewScope(); + + // share input_vars & parameters into scope + details::ShareTensorsIntoScope(x, &scope); + details::ShareTensorsIntoScope(params, &scope); + + auto *global_block = + BOOST_GET_CONST(paddle::framework::BlockDesc *, attrs.at("global_block")); + const auto &place = egr::Controller::Instance().GetExpectedPlace(); + + if (end_op_index > start_op_index) { + auto input_names = details::GetTensorsName(x); + auto output_names = details::GetTensorsName(out); + auto dout_names = details::GetTensorsName(dout); + auto *program = global_block->Program(); + + auto cache_info = paddle::framework::GetExecutorInfoFromCache( + *program, place, start_op_index, end_op_index, + /*is_grad=*/false, program_id, &scope); + auto ¶llel_executor = cache_info.first; + // all out_vars are skip_eager_var + auto &skip_eager_delete_vars = + paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( + program_id, false); + if (cache_info.second /*is_new_created*/) { + parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_names); + skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), + output_names.begin(), output_names.end()); + skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), + dout_names.begin(), dout_names.end()); + paddle::framework::details::ParseSafeEagerDeletionSkipVars( + *program, end_op_index, output_names, &skip_eager_delete_vars); + } + + // Step 3. run ops + parallel_executor->RunWithoutFetch(skip_eager_delete_vars); + } + // Step 4. Get Output + details::ShareTensorsFromScope(out, *global_block, &scope); + details::ShareTensorsFromScope(dout, *global_block, &scope); + + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + // Step 5. Drop all children scopes while testing. + if (is_test) { + out_scope_vec->front()->DropKids(); + } + VLOG(2) << "The number of sub scopes after forward: " + << out_scope_vec->front()->kids().size(); + // #ifdef PADDLE_WITH_MKLDNN + // if (FLAGS_use_mkldnn) paddle::platform::DontClearMKLDNNCache(place); + // #endif +} + +inline void RunProgramGradAPI( + const std::vector &x, + const std::vector ¶ms, + const std::vector &out_grad, + const std::vector &step_scope, // NOLINT + const paddle::framework::AttributeMap &attrs, + std::vector &x_grad, // NOLINT + std::vector ¶ms_grad // NOLINT + ) { + // if all output vars are set to stop_gradient, grad op no need to executed + if (x_grad.empty() && params_grad.empty()) return; + + // TODO(dev): Remove this line hard code. And need to deal with the out_grad + // name problem. 
+ // const_cast(out_grad[0]) + // .set_name("matmul_v2_0.tmp_0@GRAD"); + + auto *global_block = + BOOST_GET_CONST(paddle::framework::BlockDesc *, attrs.at("global_block")); + auto orig_end_op_index = BOOST_GET_CONST(int64_t, attrs.at("end_op_index")); + + auto program_id = BOOST_GET_CONST(int64_t, attrs.at("program_id")); + // NOTE: skip `shape` and `fill_constant` op created by + // fluid.backward.gradients, one forward output will generate one `shape` + // and `fill_constant` + int64_t start_op_index = orig_end_op_index + (out_grad.size() * 2); + int64_t end_op_index = global_block->OpSize(); + + auto *out_scope_vec = &step_scope; + PADDLE_ENFORCE_EQ( + out_scope_vec->size(), 1, + paddle::platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should only hold one scope.")); + + paddle::framework::Scope *global_inner_scope = out_scope_vec->front(); + auto sub_scope_num = global_inner_scope->kids().size(); + VLOG(2) << "The number of sub scopes before backward: " << sub_scope_num; + PADDLE_ENFORCE_GT(sub_scope_num, 0, + paddle::platform::errors::InvalidArgument( + "The OutScope of RunProgramGradOp should hold at " + "least one sub scope.")); + + auto &scope = *(global_inner_scope->kids().front()); + const auto &place = egr::Controller::Instance().GetExpectedPlace(); + + if (end_op_index > start_op_index) { + auto out_grad_names = details::GetTensorsName(out_grad); + // NOTE: after PR22939 [Add double grad] merged, the grad op maker's + // SetOutput will set to None if the input var stop_gradient=True, + // it will cause an NotFound error when ctx.OutputNames() is called + std::vector x_grad_names; + std::vector param_grad_names; + if (!x_grad.empty()) { + x_grad_names = details::GetTensorsName(x_grad); + } + if (!params_grad.empty()) { + param_grad_names = details::GetTensorsName(params_grad); + } + + // Step 2. prepare executor and scope + auto *program = global_block->Program(); + auto cache_info = paddle::framework::GetExecutorInfoFromCache( + *program, place, start_op_index, end_op_index, + /*is_grad*/ true, program_id, &scope); + auto ¶llel_executor = cache_info.first; + + auto &skip_eager_delete_vars = + paddle::framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars( + program_id, true); + if (cache_info.second /*is_new_created*/) { + parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, out_grad_names); + + skip_eager_delete_vars.insert(skip_eager_delete_vars.end(), + x_grad_names.begin(), x_grad_names.end()); + paddle::framework::details::AppendSkipDeletionVars( + param_grad_names, &skip_eager_delete_vars); + } + + details::ShareTensorsIntoScope(out_grad, &scope); + // Debug info: scope info when run end + VLOG(3) << paddle::framework::GenScopeTreeDebugInfo(out_scope_vec->front()); + + // Step 3. run ops + parallel_executor->RunWithoutFetch( + /*skip_eager_delete_vars=*/skip_eager_delete_vars); + } + + // Step 4. get outputs + details::ShareTensorsFromScope(x_grad, *global_block, &scope); + details::ShareTensorsFromScope(params_grad, *global_block, &scope); + + // Step5. 
drop current scope + // global_inner_scope->DeleteScope(&scope); + VLOG(2) << "The number of sub scopes after backward: " + << global_inner_scope->kids().size(); +} + +class GradNodeRunProgram : public egr::GradNodeBase { + public: + GradNodeRunProgram(size_t bwd_in_slot_num, size_t bwd_out_slot_num) + : egr::GradNodeBase(bwd_in_slot_num, bwd_out_slot_num) {} + + ~GradNodeRunProgram() override = default; + // Functor: perform backward computations + virtual std::vector> operator()( + const std::vector> &grads) + override { + VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram"; + PADDLE_ENFORCE_EQ( + grads.size(), 1, + paddle::platform::errors::InvalidArgument( + "The out_grads.size() of RunProgramGradOp should be equal to 1.")); + + VLOG(3) << "out_grads[0].size() : " << grads[0].size(); + std::vector x_grad; + std::vector params_grad; + ConstructGradTensors(x_, &x_grad); + ConstructGradTensors(params_, ¶ms_grad); + std::vector x_grad_ptr; + std::vector params_grad_ptr; + for (auto &i : x_grad) { + x_grad_ptr.emplace_back(&i); + } + for (auto &i : params_grad) { + params_grad_ptr.emplace_back(&i); + } + + // auto x_grad_ptr = ConstructGradTensors(x_); + // auto params_grad_ptr = ConstructGradTensors(params_); + + PADDLE_ENFORCE_EQ( + grads[0].size(), fwd_out_names_.size(), + paddle::platform::errors::InvalidArgument( + "The grads[0].size() and fwd_out_names_.size() should be equal.")); + for (size_t i = 0; i < fwd_out_names_.size(); ++i) { + const_cast(grads[0][i]) + .set_name(fwd_out_names_[i] + "@GRAD"); + } + + RunProgramGradAPI(x_, params_, grads[0], step_scope_, attrs_, x_grad_ptr, + params_grad_ptr); + VLOG(3) << "End Eager Backward Node: GradNodeRunProgram"; + return {x_grad, params_grad}; + // return {x_grad, details::DereferenceTensors(params_grad_ptr)}; + } + + // SetAttrMap + void SetAttrMap(const paddle::framework::AttributeMap &attrs) { + attrs_ = attrs; + } + + void SetFwdX(const std::vector &tensors) { + x_ = tensors; + } + + void SetFwdParams(const std::vector &tensors) { + params_ = tensors; + } + + void SetStepScope(const std::vector &scopes) { + step_scope_ = scopes; + } + + void SetFwdOutNames(std::vector out_names) { + fwd_out_names_ = out_names; + } + + protected: + void ConstructGradTensors( + const std::vector &fwd_tensors, + std::vector *grad_tensors) { + // TODO(dev): Need an elegant way to determine inforamtion of grad_tensor, + // such as: name, tensor type(DenseTensor or SelectedRows). + VLOG(3) << "fwd_tensors.size(): " << fwd_tensors.size(); + for (auto &fwd_t : fwd_tensors) { + grad_tensors->emplace_back(fwd_t.impl()); + auto &grad_t = grad_tensors->back(); + grad_t.set_name(fwd_t.name() + "@GRAD"); + } + } + + void ConstructGradTensors( + const std::vector &fwd_tensors) { + VLOG(3) << "fwd_tensors.size(): " << fwd_tensors.size(); + for (auto &fwd_t : fwd_tensors) { + auto grad_tesnor = egr::EagerUtils::unsafe_autograd_meta(fwd_t)->Grad(); + grad_tesnor.set_name(fwd_t.name() + "@GRAD"); + } + } + + private: + // TensorWrappers + std::vector x_; + std::vector params_; + std::vector step_scope_; + + std::vector fwd_out_names_; + + // Attribute Map + paddle::framework::AttributeMap attrs_; +}; diff --git a/paddle/fluid/pybind/custom_handwrite_op_funcs.h b/paddle/fluid/pybind/custom_handwrite_op_funcs.h new file mode 100644 index 0000000000..7a276df0d5 --- /dev/null +++ b/paddle/fluid/pybind/custom_handwrite_op_funcs.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include + +static PyObject *eager_api_run_program(PyObject *self, PyObject *args, + PyObject *kwargs) { + PyThreadState *tstate = nullptr; + try { + auto X = GetTensorListFromArgs("run_program", "X", args, 0, false); + auto Params = GetTensorListFromArgs("run_program", "Params", args, 1, true); + auto Out = GetTensorPtrListFromArgs("run_program", "Out", args, 2, false); + auto OutScope = + GetScopePtrListFromArgs("run_program", "OutScope", args, 3, false); + auto DOut = GetTensorPtrListFromArgs("run_program", "DOut", args, 4, true); + framework::AttributeMap attrs; + ConstructAttrMapFromPyArgs("run_program", args, 5, PyTuple_GET_SIZE(args), + attrs); + + tstate = PyEval_SaveThread(); + run_program_dygraph_function(X, Params, Out, OutScope, DOut, attrs); + std::cout << "end run_program_dygraph_function" << std::endl; + PyEval_RestoreThread(tstate); + tstate = nullptr; + } catch (...) { + if (tstate) { + PyEval_RestoreThread(tstate); + } + ThrowExceptionToPython(std::current_exception()); + } + Py_RETURN_NONE; +} + +static PyMethodDef CustomEagerFinalStateMethods[] = { + {"run_program", (PyCFunction)(void (*)(void))eager_api_run_program, + METH_VARARGS | METH_KEYWORDS, + "C++ interface function for run_program in dygraph."}, + + {nullptr, nullptr, 0, nullptr}}; diff --git a/paddle/fluid/pybind/eager_op_function_generator.cc b/paddle/fluid/pybind/eager_op_function_generator.cc index c15c171799..102cdbb91a 100644 --- a/paddle/fluid/pybind/eager_op_function_generator.cc +++ b/paddle/fluid/pybind/eager_op_function_generator.cc @@ -17,6 +17,7 @@ #include #include #include +#include #ifndef _WIN32 #include #endif @@ -129,6 +130,12 @@ static PyObject * %s(PyObject *self, PyObject *args, PyObject *kwargs) const char* PYBIND_ITEM_TEMPLATE = R"( {"%s", (PyCFunction)(void(*)(void))%s, METH_VARARGS | METH_KEYWORDS, "C++ interface function for %s in dygraph."},)"; +// These operators will skip automatical code generatrion and +// need to be handwritten in CUSTOM_HANDWRITE_OP_FUNC_FILE +std::unordered_set CUSTOM_HANDWRITE_OPS_SET = {"run_program"}; +const char* CUSTOM_HANDWRITE_OP_FUNC_FILE = + "#include \"paddle/fluid/pybind/custom_handwrite_op_funcs.h\"\n"; + // clang-format on static inline bool FindInsMap(const std::string& op_type, const std::string& in_name) { @@ -355,7 +362,7 @@ GenerateOpFunctions() { std::vector op_function_list, bind_function_list; auto& all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels(); - + bool append_custom_head_file = false; for (auto& pair : op_info_map) { auto& op_info = pair.second; auto op_proto = op_info.proto_; @@ -363,7 +370,12 @@ GenerateOpFunctions() { continue; } auto& op_type = op_proto->type(); - // Skip ooerator which is not inherit form OperatorWithKernel, like while, + // Skip operators that will be handwriten in CUSTOM_HANDWRITE_OP_FUNC_FILE. 
+ if (CUSTOM_HANDWRITE_OPS_SET.count(op_type)) { + append_custom_head_file = true; + continue; + } + // Skip operator which is not inherit form OperatorWithKernel, like while, // since only OperatorWithKernel can run in dygraph mode. // if the phi lib contains op kernel, we still generate ops method if (!all_kernels.count(op_type) && @@ -380,6 +392,9 @@ GenerateOpFunctions() { op_function_list.emplace_back(std::move(op_function_str)); bind_function_list.emplace_back(std::move(bind_function_str)); } + if (append_custom_head_file) { + op_function_list.emplace_back(CUSTOM_HANDWRITE_OP_FUNC_FILE); + } return std::make_tuple(op_function_list, bind_function_list); } @@ -449,6 +464,11 @@ int main(int argc, char* argv[]) { << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " "core.eager.ops failed!\"));\n" << " }\n\n" + << " if (PyModule_AddFunctions(m.ptr(), CustomEagerFinalStateMethods) < " + "0) {\n" + << " PADDLE_THROW(platform::errors::Fatal (\"Add functions to " + "core.eager.ops failed!\"));\n" + << " }\n\n" << "}\n\n" << "} // namespace pybind\n" << "} // namespace paddle\n"; diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 0cfb08345b..f4e148cf8d 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/eager/api/all.h" #include "paddle/fluid/eager/autograd_meta.h" #include "paddle/fluid/framework/convert_utils.h" +#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope_guard.h" #include "paddle/fluid/memory/allocation/allocator.h" #include "paddle/fluid/operators/py_func_op.h" @@ -35,6 +36,7 @@ namespace pybind { extern PyTypeObject* p_tensor_type; +extern PyTypeObject* g_framework_scope_pytype; extern PyTypeObject* g_vartype_pytype; extern PyTypeObject* g_place_pytype; extern PyTypeObject* g_cudaplace_pytype; @@ -830,6 +832,64 @@ paddle::experimental::ScalarArray CastPyArg2ScalarArray( return paddle::experimental::ScalarArray({1}); } +paddle::framework::Scope* CastPyArg2ScopePtr(PyObject* obj) { + if (PyObject_IsInstance( + obj, reinterpret_cast(g_framework_scope_pytype))) { + return ::pybind11::handle(obj).cast(); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "PyObject can not be cast into framework::Scope")); + } +} + +std::vector GetScopePtrListFromArgs( + const std::string& op_type, const std::string& arg_name, PyObject* args, + ssize_t arg_idx, bool dispensable) { + PyObject* list = PyTuple_GET_ITEM(args, arg_idx); + if (list == nullptr) { + if (!dispensable) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of scope, but got " + "None", + op_type, arg_name, arg_idx)); + } + } + + std::vector result; + if (PyList_Check(list)) { + Py_ssize_t len = PyList_Size(list); + if (len == 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of scope, but got " + "empty list", + op_type, arg_name, arg_idx)); + } + for (Py_ssize_t i = 0; i < len; i++) { + result.emplace_back(CastPyArg2ScopePtr(PyList_GetItem(list, i))); + } + } else if (PyTuple_Check(list)) { + Py_ssize_t len = PyTuple_Size(list); + if (len == 0) { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of scope, but got " + "empty list", + op_type, arg_name, arg_idx)); + } + for (Py_ssize_t i = 0; i < len; i++) { + result.emplace_back(CastPyArg2ScopePtr(PyList_GetItem(list, i))); + } + 
} else if (list == Py_None) { + return {}; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "%s(): argument '%s' (position %d) must be list of Tensors, but got " + "%s", + op_type, arg_name, arg_idx, + (reinterpret_cast(list->ob_type))->tp_name)); + } + return result; +} + paddle::experimental::Backend CastPyArg2Backend(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { diff --git a/paddle/fluid/pybind/eager_utils.h b/paddle/fluid/pybind/eager_utils.h index c5da1bb37a..966a920377 100644 --- a/paddle/fluid/pybind/eager_utils.h +++ b/paddle/fluid/pybind/eager_utils.h @@ -20,6 +20,10 @@ limitations under the License. */ #include "pybind11/pybind11.h" #include "pybind11/stl.h" namespace paddle { +namespace framework { +class Scope; +} + namespace pybind { typedef struct { @@ -134,6 +138,9 @@ std::vector GetTensorPtrListFromArgs( ssize_t arg_idx, bool dispensable = false); // end of Slice related methods +std::vector GetScopePtrListFromArgs( + const std::string& op_type, const std::string& arg_name, PyObject* args, + ssize_t arg_idx, bool dispensable); } // namespace pybind } // namespace paddle diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index fcfc3e6a37..566e38b7a2 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -175,6 +175,7 @@ namespace paddle { namespace pybind { PyTypeObject *g_place_pytype = nullptr; +PyTypeObject *g_framework_scope_pytype = nullptr; PyTypeObject *g_cudaplace_pytype = nullptr; PyTypeObject *g_cpuplace_pytype = nullptr; PyTypeObject *g_xpuplace_pytype = nullptr; @@ -1352,7 +1353,7 @@ All parameter, weight, gradient are variables in Paddle. BindReader(&m); - py::class_(m, "_Scope", R"DOC( + py::class_ _Scope(m, "_Scope", R"DOC( Scope is an association of a name to Variable. All variables belong to Scope. Variables in a parent scope can be retrieved from local scope. @@ -1372,7 +1373,9 @@ All parameter, weight, gradient are variables in Paddle. param_array = np.full((height, row_numel), 5.0).astype("float32") param.set(param_array, place) - )DOC") + )DOC"); + g_framework_scope_pytype = reinterpret_cast(_Scope.ptr()); + _Scope .def("_remove_from_pool", [](Scope &self) { ScopePool::Instance().Remove(&self); }) .def("var", diff --git a/python/paddle/fluid/tests/unittests/test_eager_run_program.py b/python/paddle/fluid/tests/unittests/test_eager_run_program.py new file mode 100644 index 0000000000..fc6a5d60ec --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_eager_run_program.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import numpy as np +from paddle import _C_ops +from paddle.fluid.framework import _test_eager_guard, Variable +from paddle.fluid import core +from paddle.fluid.layers.utils import _hash_with_id +import paddle.compat as cpt + +import unittest + + +def _append_backward_desc(main_program, outs): + # make sure all status of is_test are False in train mode. 
+ program = main_program.clone() + targets = [] + for out in outs: + if isinstance(out, Variable): + targets.append(program.global_block().var(out.name)) + + if targets: + paddle.fluid.backward.gradients(targets=targets, inputs=[]) + + return program + + +# def _set_grad_type(params, train_program): +# # NOTE: if user set sparse gradient mode, the param's gradient +# # will be SelectedRows, not LoDTensor. But tracer will just +# # set param grad VarBase by forward VarBase(LoDTensor) +# # If we don't change grad_var type here, RunProgramOp need +# # transform SelectedRows to LoDTensor forcibly, it may not +# # be user wanted result. +# for param in params: +# grad_name = param.name + core.grad_var_suffix() +# grad_var = train_program.desc.block(0).find_var( +# cpt.to_bytes(grad_name)) +# # NOTE: cannot find var desc maybe no problem, such as in batch_norm +# if grad_var is None: +# continue +# param._set_grad_type(grad_var.type()) + + +def _create_out(var): + assert isinstance(var, Variable) + var_desc = var.desc + varbase = None + if not core._in_eager_mode(): + var_base = core.VarBase(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), var_desc.type(), False) + else: + var_base = core.eager.Tensor(var_desc.dtype(), + var_desc.shape(), + var_desc.name(), var_desc.type(), False) + return var_base + + +class TestRunProgram(unittest.TestCase): + def test_eager(self): + paddle.set_device('cpu') + paddle.enable_static() + # step 1: construct program + x = paddle.static.data(shape=[2, 4], name='x') + x.stop_gradient = False + y = paddle.static.data(shape=[4, 2], name='y') + y.stop_gradient = False + out = paddle.matmul(x, y) + + main_program = paddle.static.default_main_program() + program = _append_backward_desc(main_program, [out]) + + paddle.disable_static('cpu') + # step 2: call run_program in eager mode + with _test_eager_guard(): + x_t = paddle.ones([2, 4]) + x_t.name = "x" + x_t.stop_gradient = False + y_t = paddle.ones([4, 2]) + y_t.name = "y" + y_t.stop_gradient = False + + fake_var = paddle.zeros([1]) + fake_var.name = 'Fake_var' + + out_t = _create_out(out) + + scope = core.Scope() + attrs = ('global_block', program.desc.block(0), 'start_op_index', 0, + 'end_op_index', main_program.desc.block(0).op_size(), + 'is_test', False, 'program_id', _hash_with_id(program)) + + _C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope], + [fake_var], *attrs) + + loss = paddle.mean(out_t) + loss.backward() + + self.assertTrue(np.array_equal(np.ones([2, 2]) * 4, out_t.numpy())) + self.assertTrue( + np.array_equal(np.ones([2, 4]) * 0.5, x_t.grad.numpy())) + self.assertTrue( + np.array_equal(np.ones([4, 2]) * 0.5, y_t.grad.numpy())) + + +if __name__ == '__main__': + unittest.main() -- GitLab
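
For reference, a minimal sketch (not a documented API) of how the handwritten eager binding registered above is driven from Python. The positional layout follows eager_api_run_program in custom_handwrite_op_funcs.h -- X, Params, Out, OutScope, DOut, then flattened attribute name/value pairs -- and the setup mirrors test_eager_run_program.py; the static `program`, `main_program`, eager tensors `x_t`, `y_t`, `out_t` and the placeholder `fake_var` are assumed to have been built exactly as in TestRunProgram.test_eager above.

    # Assumes `program`, `main_program`, `x_t`, `y_t`, `out_t`, `fake_var` and the
    # eager guard are prepared as in TestRunProgram.test_eager above.
    from paddle import _C_ops
    from paddle.fluid import core
    from paddle.fluid.layers.utils import _hash_with_id

    scope = core.Scope()  # becomes the single entry of the OutScope list

    # Attributes are passed as a flat (name, value, name, value, ...) tuple and
    # reassembled into an AttributeMap by ConstructAttrMapFromPyArgs.
    attrs = (
        'global_block', program.desc.block(0),                 # BlockDesc the executor runs
        'start_op_index', 0,                                    # forward op range starts here
        'end_op_index', main_program.desc.block(0).op_size(),   # ops after this are the appended grad ops
        'is_test', False,                                       # False keeps sub-scopes alive for backward
        'program_id', _hash_with_id(program),                   # cache key for GetExecutorInfoFromCache
    )

    # Positional args map to: X, Params, Out, OutScope, DOut, *attrs
    _C_ops.run_program([x_t, y_t], [fake_var], [out_t], [scope], [fake_var], *attrs)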