// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/jit/function_schema.h"
#include "paddle/fluid/jit/layer_utils.h"

namespace paddle {
namespace jit {

// Runs a jit::FunctionInfo's program through the ParallelExecutor,
// caching the executor per program so repeated calls reuse it.
class PEFunction : public BaseFunction {
 public:
  PEFunction(const std::shared_ptr<FunctionInfo> &info,
             const Name2VariableMap &params_dict,
             const phi::Place &place)
      : info_(info), place_(place) {
    ShareParamsIntoScope(info_->GetParamNames(), params_dict, &scope_);
    VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);
  }

  ~PEFunction() noexcept {}

  std::vector<Variable> operator()(const std::vector<Variable> &inputs) {
    // bool is_test = true;
    // Hash the serialized program to get a stable executor-cache key.
    std::string prog_string;
    std::hash<std::string> string_hash;
    auto &program_desc = info_->GetProgramDesc();
    const_cast<framework::ProgramDesc *>(&program_desc)
        ->Proto()
        ->SerializePartialToString(&prog_string);
    // program_desc.Proto()->SerializePartialToString(&prog_string);
    int64_t program_id = static_cast<int64_t>(string_hash(prog_string));

    const framework::BlockDesc &global_block = program_desc.Block(0);
    int64_t start_op_index = 0;
    int64_t end_op_index = static_cast<int64_t>(global_block.OpSize());

    ShareInputsIntoScope(info_->GetInputArgNames(), inputs, &scope_);
    std::vector<std::string> input_var_names = info_->GetInputArgNames();
    std::vector<std::string> output_var_names = info_->GetOutputArgNames();
    std::vector<std::string> dout_var_names;

    if (end_op_index > start_op_index) {
      // TODO(dev): support other devices
      auto cache_info = framework::GetExecutorInfoFromCache(program_desc,
                                                            place_,
                                                            start_op_index,
                                                            end_op_index,
                                                            /*is_grad=*/false,
                                                            program_id,
                                                            &scope_);
      auto &parallel_executor = cache_info.first;
      auto &skip_eager_delete_vars =
          framework::ExecutorInfoCache::Instance().SkipEagerDeleteVars(
              program_id, false);
      if (cache_info.second /*is_new_created*/) {
        // First run with this program: keep inputs and outputs alive so
        // the eager-deletion pass does not reclaim them mid-execution.
        parallel_executor->SkipMemoryReuse(/*scope_idx=*/0, input_var_names);
        skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
                                      output_var_names.begin(),
                                      output_var_names.end());
        skip_eager_delete_vars.insert(skip_eager_delete_vars.end(),
                                      dout_var_names.begin(),
                                      dout_var_names.end());
        framework::details::ParseSafeEagerDeletionSkipVars(
            program_desc,
            end_op_index,
            output_var_names,
            &skip_eager_delete_vars);
      }
      parallel_executor->RunWithoutFetch(skip_eager_delete_vars);
    }
    VLOG(6) << framework::GenScopeTreeDebugInfo(&scope_);

    std::vector<Variable> res;
    FetchVarsByNames(info_->GetOutputArgNames(), scope_, &res);
    return res;
  }

 private:
  std::shared_ptr<FunctionInfo> info_;
  framework::Scope scope_;
  phi::Place place_;
};

}  // namespace jit
}  // namespace paddle
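
// Example usage (a minimal sketch; where the FunctionInfo, parameter map,
// and input Variables come from depends on the surrounding jit::Layer
// loading code and is assumed here, not defined by this header):
//
//   std::shared_ptr<FunctionInfo> info = /* from a deserialized Layer */;
//   Name2VariableMap params_dict = /* parameter name -> Variable */;
//   phi::Place place = phi::CPUPlace();
//
//   PEFunction forward(info, params_dict, place);
//   std::vector<Variable> inputs = /* feed Variables, ordered as in
//                                     info->GetInputArgNames() */;
//   std::vector<Variable> outputs = forward(inputs);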