// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" namespace paddle { namespace distributed { namespace { using OperatorBase = TaskNode::OperatorBase; } TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank, int64_t task_id, int64_t max_run_times) : program_(program), rank_(rank), task_id_(task_id), max_run_times_(max_run_times) { // TODO(liyurui): Will be removed when execute program is supported. Init(); } TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank) : program_(program), rank_(rank), task_id_(rank) { max_run_times_ = 1; LOG(INFO) << "Constructing TaskNode for DistModelInf. The TaskNode's id is: " << rank << ". And the TaskNode's max_run_time and max_slot_num will be set to 1."; } void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) { program_ = program; } void TaskNode::SetVarsToDtype( const std::map& vars_to_dtype) { vars_to_dtype_ = vars_to_dtype; } void TaskNode::SetVarsToShape( const std::map>& vars_to_shape) { vars_to_shape_ = vars_to_shape; } void TaskNode::Init(bool use_feed_fetch_ops) { if (!use_feed_fetch_ops) { VLOG(3) << "TaskNode will be inited without feed and fetch ops"; } if (ops_.empty()) { // Q (for fleet executor dev): should we need another reset funct? VLOG(3) << "Task node will be inited by calling Init()."; for (const auto& op_desc : program_->Block(0).AllOps()) { if (!use_feed_fetch_ops && (op_desc->Type() == "feed" || op_desc->Type() == "fetch")) { VLOG(3) << "TaskNode will skip [" << op_desc->Input("X")[0] << "], " << op_desc->Type() << " -> " << op_desc->Output("Out")[0]; continue; } ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*op_desc)); } for (const auto& op : ops_vec_) { ops_.emplace_back(op.get()); } } } TaskNode::TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times) : rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {} TaskNode::TaskNode(int32_t role, const std::vector& op_descs, int64_t rank, int64_t task_id, int64_t max_run_times) : role_(role), rank_(rank), task_id_(task_id), max_run_times_(max_run_times) { if (op_descs.empty()) { return; } VLOG(3) << "Task node will be inited by providing list of ops."; for (const auto& desc : op_descs) { ops_vec_.emplace_back(framework::OpRegistry::CreateOp(*desc)); } for (const auto& op : ops_vec_) { ops_.emplace_back(op.get()); } } TaskNode::TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, int64_t max_run_times) : ops_(ops), role_(role), rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {} TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times) : role_(role), rank_(rank), task_id_(task_id), max_run_times_(max_run_times) {} bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size, DependType type) { const auto& ret = upstream_.emplace(task_id, buff_size); id_to_dep_type_.emplace(task_id, type); return ret.second; } bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size, DependType type) { const auto& ret = downstream_.emplace(task_id, buff_size); id_to_dep_type_.emplace(task_id, type); return ret.second; } std::string TaskNode::DebugString() const { std::ostringstream os; os << "role: " << role_ << ", task_id: " << task_id_ << "\n"; for (std::size_t i = 0; i < ops_.size(); ++i) { os << ops_[i]->Type() << " "; } os << "\n"; return os.str(); } void TaskNode::SetRunPerSteps(int64_t value) { PADDLE_ENFORCE_GE(value, 1, platform::errors::InvalidArgument( "run_per_steps must >= 1, but received %ld", value)); run_per_steps_ = value; } void TaskNode::SetRunAtOffset(int64_t value) { PADDLE_ENFORCE_GE(value, 0, platform::errors::InvalidArgument( "run_at_offset must >= 0, but received %ld", value)); run_at_offset_ = value; } void TaskNode::SetReplyUpPerSteps(int64_t value) { PADDLE_ENFORCE_GE( value, 1, platform::errors::InvalidArgument( "reply_up_per_steps must >= 1, but received %ld", value)); reply_up_per_steps_ = value; } void TaskNode::SetSendDownPerSteps(int64_t value) { PADDLE_ENFORCE_GE( value, 1, platform::errors::InvalidArgument( "send_down_per_steps must >= 1, but received %ld", value)); send_down_per_steps_ = value; } } // namespace distributed } // namespace paddle