// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/controlflow/while_op_helper.h" #include #include #include #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/operators/controlflow/op_variant.h" #include "paddle/fluid/string/string_helper.h" namespace paddle { namespace operators { // Set skip variables of while_op and while_grad_op // These variables should be skipped when eager deletion enables. // It is because: // 1. while_grad_op needs some variables defined in while_op. // 2. while_grad_op needs variables from the previous time step. static void SetSkipVars(const OpVariant &op, std::vector attr) { auto &attrs = const_cast(op.Attrs()); VLOG(2) << "Prepare to skip " << attr.size() << " var(s): " << string::join_strings(attr, ' '); attrs[kSkipEagerDeletionVars] = std::move(attr); } // Check whether the forward while_op and while_grad_op match // The program may have many while_ops. static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op, const OpVariant &grad_op) { return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) && fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs); } // Test whether the variable is skippable in forward while_op // The variable is skippable in while_op when the variable used in while_grad // is not from grad_block. static bool IsSkippableVar(const std::string &name, framework::BlockDesc *grad_block) { return name != framework::kEmptyVarName && !grad_block->HasVar(name); } static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op, const OpVariant &bwd_op) { auto *grad_block = bwd_op.Attr(kStepBlock); // Find all skippable variables in forward while_op std::unordered_set forward_skip_vars; for (auto *op_desc : grad_block->AllOps()) { for (auto &in_arg_name : op_desc->InputArgumentNames()) { if (IsSkippableVar(in_arg_name, grad_block)) { forward_skip_vars.insert(in_arg_name); } } for (auto &out_arg_name : op_desc->OutputArgumentNames()) { if (IsSkippableVar(out_arg_name, grad_block)) { forward_skip_vars.insert(out_arg_name); } } } SetSkipVars(fwd_op, std::vector(forward_skip_vars.begin(), forward_skip_vars.end())); // Find all skippable variables in while_grad_op // The skipped variables are those which would be used across time steps. auto &fwd_input = fwd_op.Inputs().at(kX); auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX)); PADDLE_ENFORCE_EQ( fwd_input.size(), in_grads.size(), "Backward input gradient number does not match forward input number."); std::unordered_set backward_skip_vars; for (size_t i = 0; i < in_grads.size(); ++i) { if (in_grads[i] == framework::kEmptyVarName) { continue; } backward_skip_vars.insert(in_grads[i]); backward_skip_vars.insert(framework::GradVarName(fwd_input[i])); } SetSkipVars(bwd_op, std::vector(backward_skip_vars.begin(), backward_skip_vars.end())); } // Find all while_ops and while_grad_ops in the graph or program // The while_grad_op and while_op may located in different blocks // So we should traverse all blocks in the program and find them out. static void FindAllWhileAndWhileGradOp(const framework::ProgramDesc &program, std::vector *while_ops, std::vector *while_grad_ops) { PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size()); for (size_t i = 1; i < program.Size(); ++i) { auto &block = program.Block(i); for (size_t j = 0; j < block.OpSize(); ++j) { auto *op = block.Op(j); if (op->Type() == "while") { while_ops->emplace_back(op); } else if (op->Type() == "while_grad") { while_grad_ops->emplace_back(op); } } } PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(), "There are extra while_grad ops in the graph or program"); } static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl( const framework::ProgramDesc &program, std::vector *while_ops, std::vector *while_grad_ops) { FindAllWhileAndWhileGradOp(program, while_ops, while_grad_ops); VLOG(2) << "Found while op num: " << while_ops->size() << ", while grad op num: " << while_grad_ops->size(); if (while_grad_ops->empty()) { return; } std::unordered_set while_op_set( while_ops->begin(), while_ops->end()); for (auto &bwd_op : *while_grad_ops) { const OpVariant *matched_fwd_op = nullptr; for (auto &fwd_op : while_op_set) { if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) { PADDLE_ENFORCE(matched_fwd_op == nullptr, "Found multiple matched while ops"); matched_fwd_op = &fwd_op; } } PADDLE_ENFORCE_NOT_NULL(matched_fwd_op, "Cannot find matched forward while op."); ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op); while_op_set.erase(*matched_fwd_op); } } void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( const framework::ProgramDesc &program, int block_id, const std::vector> &all_ops) { // If block_id is not 0, returns // This is because all while_ops and while_grad_ops in the whole program // would be processed when block_id is 0 (i.e. when Executor::Run() or // ParallelExecutor constructs). // What's more, all while_ops and while_grad_ops must be processed when // block_id is zero. If not, while_op may run first and erase variables // used in while_grad_op, and in this moment, while_grad_ops may be not // constructed yet. if (block_id != 0) return; std::vector fwd_ops, bwd_ops; for (auto &op : all_ops) { if (op->Type() == "while") { fwd_ops.emplace_back(op.get()); } else if (op->Type() == "while_grad") { bwd_ops.emplace_back(op.get()); } } PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops, &bwd_ops); } void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( const framework::ProgramDesc &program, const std::vector &while_ops, const std::vector &while_grad_ops) { std::vector fwd_ops, bwd_ops; fwd_ops.reserve(while_ops.size()); for (auto *op : while_ops) { fwd_ops.emplace_back(op); } bwd_ops.reserve(while_grad_ops.size()); for (auto *op : while_grad_ops) { bwd_ops.emplace_back(op); } PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(program, &fwd_ops, &bwd_ops); } } // namespace operators } // namespace paddle