Unverified commit 5fb1e824, authored by Ruibiao Chen, committed by GitHub

Improve performance of coalesce_tensor and depend op in standalone executor (#47606)

* Dispatch computation OPs before communication in standalone executor

* Update code

* Fix CI errors

* Improve performance of coalesce_tensor and depend OP in standalone executor

* pre-commit check
Parent 7648f429
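In short, the standalone executor now classifies every instruction as host-dispatched (OpFuncType::kQueueSync) or device-dispatched (OpFuncType::kQueueAsync). A minimal standalone sketch of that policy, using simplified stand-in types rather than the real Paddle interfaces:

#include <iostream>
#include <string>

enum class OpFuncType { kQueueSync, kQueueAsync };

// Simplified stand-in for OpFuncNode/OperatorBase (hypothetical fields).
struct OpInfo {
  std::string type;
  bool on_cpu_place;
  bool set_constant;
  bool copy_data;
};

// Mirrors the policy this PR introduces: CPU ops, and coalesce_tensor ops
// that neither fill constants nor copy data (pure CPU bookkeeping, no CUDA
// kernel launch), go to the host thread; everything else stays async.
OpFuncType AnalyseOpFuncType(const OpInfo& op) {
  if (op.on_cpu_place) return OpFuncType::kQueueSync;
  if (op.type == "coalesce_tensor" && !op.set_constant && !op.copy_data) {
    return OpFuncType::kQueueSync;
  }
  return OpFuncType::kQueueAsync;
}

int main() {
  OpInfo fuse{"coalesce_tensor", /*on_cpu_place=*/false,
              /*set_constant=*/false, /*copy_data=*/false};
  std::cout << (AnalyseOpFuncType(fuse) == OpFuncType::kQueueSync) << "\n";  // prints 1
  return 0;
}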
@@ -84,7 +84,6 @@ bool DependencyBuilder::OpHappensBefore(int prior_op_idx,
 }

 void DependencyBuilder::AddDependencyForCoalesceTensorOp() {
-  const std::string kCoalesceTensor = "coalesce_tensor";
   for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
     if (instructions_->at(op_idx).OpBase()->Type() == kCoalesceTensor) {
       VLOG(4) << "Add depend for " << kCoalesceTensor << " " << op_idx;
...
@@ -300,6 +300,30 @@ void BuildVariableScope(const framework::BlockDesc& block,
   }
 }

+OpFuncType AnalyseOpFuncType(const OpFuncNode& op_func_node,
+                             const platform::Place& place) {
+  if (platform::is_cpu_place(place)) {
+    return OpFuncType::kQueueSync;
+  }
+
+  PADDLE_ENFORCE_EQ(IsSupportedHeterPlace(place),
+                    true,
+                    phi::errors::Fatal("Unsupported current place %s", place));
+
+  // Some GPU OPs do not launch CUDA kernels, but spend a lot of time on CPU
+  // computing. They execute serially in the device thread and block CUDA
+  // kernel launching in other GPU OPs. To improve performance, set them as
+  // kQueueSync so that they are dispatched to the host thread.
+  std::shared_ptr<OperatorBase> op = op_func_node.operator_base_;
+  if (op->Type() == kCoalesceTensor &&
+      op->Attr<bool>("set_constant") == false &&
+      op->Attr<bool>("copy_data") == false) {
+    return OpFuncType::kQueueSync;
+  }
+
+  return OpFuncType::kQueueAsync;
+}
+
 void CreateAllOps(const framework::BlockDesc& block,
                   std::vector<std::unique_ptr<OperatorBase>>* ops) {
   for (auto& op : block.AllOps()) {
...
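For context, coalesce_tensor fuses a list of tensors into one contiguous buffer. When set_constant and copy_data are both false, the op amounts to offset arithmetic and pointer bookkeeping on the host, which is why leaving it on the device thread only delayed real kernel launches. A rough illustration with hypothetical simplified types, not the real kernel:

#include <cstddef>
#include <vector>

// Hypothetical metadata for one tensor inside the fused buffer.
struct ChunkMeta {
  size_t numel;   // number of elements
  size_t offset;  // element offset into the fused buffer, computed below
};

// Metadata-only "coalescing": assign aligned offsets and return the total
// fused size. No data is touched, so no device kernel is needed.
size_t CoalesceMetadataOnly(std::vector<ChunkMeta>* chunks, size_t align) {
  size_t cursor = 0;
  for (ChunkMeta& c : *chunks) {
    c.offset = cursor;
    // round each chunk up to the alignment boundary
    cursor += (c.numel + align - 1) / align * align;
  }
  return cursor;
}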
@@ -448,14 +472,7 @@ void HandleOperatorBase(const platform::Place& place,
   auto* dev_ctx = pool.Get(place);

   // input, output is prepared. set the other attributes.
   op_func_node->operator_base_ = op_base;
-  if (IsSupportedHeterPlace(place)) {
-    op_func_node->type_ = OpFuncType::kQueueAsync;
-  } else if (platform::is_cpu_place(place)) {
-    op_func_node->type_ = OpFuncType::kQueueSync;
-  } else {
-    PADDLE_THROW(
-        platform::errors::Fatal("Unsupported current place %s", place));
-  }
+  op_func_node->type_ = AnalyseOpFuncType(*op_func_node, place);
   op_func_node->kernel_func_ = nullptr;
   op_base->Run(*local_scope, place);  // Run without data transformer.

   std::unordered_set<int> no_data_transform_index;
...
@@ -663,14 +680,9 @@ void BuildOpFuncList(const platform::Place& place,
           dev_ctx = pool.Get(kernel_type.place_);
         }
         op_func_node.dev_ctx_ = dev_ctx;
-        if (IsSupportedHeterPlace(kernel_type.place_)) {
-          op_func_node.type_ = OpFuncType::kQueueAsync;
-        } else if (platform::is_cpu_place(kernel_type.place_)) {
-          op_func_node.type_ = OpFuncType::kQueueSync;
-        } else {
-          PADDLE_THROW(platform::errors::Fatal("Unsupported current place %s",
-                                               kernel_type.place_));
-        }
+        op_func_node.type_ =
+            AnalyseOpFuncType(op_func_node, kernel_type.place_);
+
         VLOG(3) << op_with_kernel->Type()
                 << " : finally selected kernel_key: " << kernel_type;
...
@@ -420,7 +420,7 @@ void InterpreterCore::BuildInplace() {
   std::set<std::string> skip_inplace_outvars;
   for (Instruction& instr : vec_instruction_) {
     OperatorBase* op = instr.OpBase();
-    if (op->Type() == "coalesce_tensor") {
+    if (op->Type() == kCoalesceTensor) {
       const std::vector<std::string>& outputs =
           op->OutputVars(/*has_intermediate=*/false);
       skip_inplace_outvars.insert(outputs.begin(), outputs.end());
...
@@ -897,8 +897,9 @@ void InterpreterCore::RunNextInstructions(
   int64_t first_op = -1;
   for (auto next_id : direct_run_ops) {
     if (IsReady(next_id)) {
-      // only keep one op running in the current thread
-      if (first_op == -1) {
+      // only keep one sync op running in the current thread
+      if (first_op == -1 &&
+          vec_instruction_[next_id].KernelType() == OpFuncType::kQueueSync) {
         first_op = next_id;
         continue;
       }
...
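With this change the worker thread keeps at most one ready kQueueSync instruction to run inline; every other ready instruction, including all async ones, is handed to the work queues. A condensed sketch of that selection under stand-in types:

#include <cstdint>
#include <vector>

enum class OpFuncType { kQueueSync, kQueueAsync };

// Stand-in scheduler fragment: return the single sync op to run on the
// current thread (or -1), pushing every other ready op onto *to_enqueue.
int64_t PickDirectRunOp(const std::vector<OpFuncType>& kernel_types,
                        const std::vector<int64_t>& ready_ops,
                        std::vector<int64_t>* to_enqueue) {
  int64_t first_op = -1;
  for (int64_t id : ready_ops) {
    if (first_op == -1 && kernel_types[id] == OpFuncType::kQueueSync) {
      first_op = id;  // run inline after dispatching the rest
      continue;
    }
    to_enqueue->push_back(id);  // goes to a work-queue thread
  }
  return first_op;
}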
@@ -935,11 +936,11 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
   try {
     interpreter::WaitEvent(instr_node, place_);

-    RunInstruction(instr_node);
-
-    CheckGC(instr_node);
-
-    interpreter::LogDeviceMemoryStats(place_);
+    if (!instr_node.IsArtificial()) {
+      RunInstruction(instr_node);
+      CheckGC(instr_node);
+      interpreter::LogDeviceMemoryStats(place_);
+    }

     interpreter::RecordEvent(instr_node, place_);
   } catch (platform::EnforceNotMet& ex) {
...
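depend instructions are now marked artificial: the executor still waits on and records their events, so the ordering edges they encode are preserved, but no kernel runs and no GC is triggered for them. A minimal sketch of that control flow, with stub event APIs standing in for Paddle's:

#include <functional>

struct Instr {
  bool artificial;            // true for scheduling-only ops such as "depend"
  std::function<void()> run;  // the actual kernel, if any
};

void WaitEvent(const Instr&) {}    // stubs for the real event/GC machinery
void RecordEvent(const Instr&) {}
void CheckGC(const Instr&) {}

// Mirrors the patched control flow: dependency events still fire on both
// sides, but artificial instructions skip the kernel and GC entirely.
void RunInstructionAsync(const Instr& instr) {
  WaitEvent(instr);
  if (!instr.artificial) {
    instr.run();
    CheckGC(instr);
  }
  RecordEvent(instr);
}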
@@ -675,7 +675,8 @@ Instruction::Instruction(size_t id,
                          OpFuncNode&& op_func_node,
                          const platform::DeviceContext& dev_ctx,
                          const Priority priority)
-    : id_(id),
+    : is_artificial_(op_func_node.operator_base_->Type() == "depend"),
+      id_(id),
       op_func_node_(op_func_node),
       dev_ctx_(dev_ctx),
       priority_(priority) {
...
@@ -32,7 +32,7 @@ namespace framework {

 using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;

-constexpr int kEmptyVarIndex = 0;
+constexpr const char* kCoalesceTensor = "coalesce_tensor";

 // stream types
 constexpr const char* kCustomStream = "CustromStream";
@@ -40,6 +40,8 @@ constexpr const char* kDefaultStream = "DefaultStream";
 constexpr const char* kD2HStream = "D2HStream";
 constexpr const char* kH2DStream = "H2DStream";

+constexpr int kEmptyVarIndex = 0;
+
 enum class Priority { kLowest, kNormal };

 class InterpretercoreInferShapeContext : public InferShapeContext {
@@ -305,6 +307,8 @@ class Instruction {
               const platform::DeviceContext& dev_ctx,
               const Priority priority);

+  bool IsArtificial() const { return is_artificial_; }
+
   size_t Id() const;

   const std::map<std::string, std::vector<int>>& Inputs() const;
@@ -368,6 +372,9 @@ class Instruction {
   Priority GetPriority() const { return priority_; }

  private:
+  bool is_artificial_;  // An artificial instruction is only used to assist
+                        // scheduling and does not need to be executed.
   size_t id_;
   OpFuncNode op_func_node_;
   const platform::DeviceContext& dev_ctx_;  // not owned
...
@@ -239,11 +239,15 @@ platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
  */
 bool StreamAnalyzer::IsDirectRun(Instruction& cur_instr,
                                  const Instruction& next_instr) {
-  if (&cur_instr.DeviceContext() == &next_instr.DeviceContext()) return true;
+  if (cur_instr.KernelType() == next_instr.KernelType() &&
+      (&cur_instr.DeviceContext() == &next_instr.DeviceContext())) {
+    return true;
+  }

   // xpu&ipu memcpy kernel is synchronous.
-  if (platform::is_ipu_place(place_) || platform::is_xpu_place(place_))
-    return true;
+  if (platform::is_ipu_place(place_) || platform::is_xpu_place(place_)) {
+    return true;
+  }

   // npu d2h kernel is asynchronous.
   if (platform::is_npu_place(place_) || platform::is_custom_place(place_)) {
...
@@ -84,12 +84,17 @@ y = opB(x)
   }
 };

+DECLARE_NO_NEED_BUFFER_VARS_INFERER(DependNoNeedBufferVarsInferer, "X", "Dep");
+
 }  // namespace operators
 }  // namespace paddle

+namespace ops = paddle::operators;
+
 REGISTER_OPERATOR(
     depend,
-    paddle::operators::DependOp,
+    ops::DependOp,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
-    paddle::operators::DependOpProtoMaker);
+    ops::DependOpProtoMaker,
+    ops::DependNoNeedBufferVarsInferer);
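DECLARE_NO_NEED_BUFFER_VARS_INFERER declares that depend only reads the metadata of its X and Dep inputs, never their data buffers, so those buffers may be reused or collected without waiting for the depend op. A conceptual sketch of what such a declaration buys the memory planner, with hypothetical simplified types:

#include <set>
#include <string>

// Hypothetical simplified view of a pending op's buffer requirements.
struct PendingOp {
  std::string type;
  std::set<std::string> no_need_buffer_slots;  // e.g. {"X", "Dep"} for depend
};

// The memory planner may release a variable's buffer even while ops that
// declared the slot as no-need-buffer are still waiting to run.
bool BufferFreeableBefore(const PendingOp& op, const std::string& slot) {
  return op.no_need_buffer_slots.count(slot) > 0;
}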