Unverified commit 5bac67d4, authored by Ruibiao Chen, committed by GitHub

Improve new executor static build (#51149)

* Improve new executor static build

* Skip GC for static build

* Skip infershape for static build

* Handle read_op

* Add fused_attention to OpsWithFluidKernelNeedMoveToPhi

* Fix argsort typos

* Add sequence_pool to OpsWithFluidKernelNeedMoveToPhi

* Fix skip share lod errors

* Fix errors for adam

* Fix errors for eigvals, memcpy and fake_quantize

* Add static_build.cc

* Add black list

* Fix CI errors

* Fix CI errors

* Fix CI errors

* Fix TensorArray

* Fix TensorArray

* Add update_loss_scaling to OpsNeedSetOutputDtypeWhenRegisterPhiKernel

* Fix copy

* Fix errors

* Fix momentum

* Skip mkldnn

* Fix CI errors

* Fix c_sync_calc_stream_op

* Fix CINN

* Fix while op

* All CI passes; disable the FLAGS to merge the code, and enable it after more tests in the future

* Add UTs

* Fix typos

* Fix typos

* Add mkldnn UT

* Remove mkldnn test

* Fix typos

* Fix dist test

* Fix typos

* Fix CI errors

* Fix CI errors

* Add UTs

* Fix typos

* Fix typos

* Add sparse tests

* ToComplexType -> ToComplex

* Add test_matmul_op_static_build to disable_win_inference_test
Parent commit: 076bc5d6
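The gist of this change, as visible in the diff below: when FLAGS_new_executor_static_build is on, the interpreter builds its op list without actually running kernels. Each op's outputs are instead "fake-initialized" (their dtype/place/layout are filled from the kernel registration or InferMeta, backed by a zero-size allocation), and infershape and GC are skipped during that build pass. What follows is a minimal standalone sketch of that idea only; it is not Paddle code, and every name in it is illustrative.

#include <iostream>
#include <string>
#include <vector>

// Toy stand-ins for phi::DataType / phi::Place.
enum class DataType { UNDEFINED, FP32, INT64 };
enum class Place { CPU, GPU };

// An output that only carries metadata until the first real run.
struct FakeTensor {
  DataType dtype = DataType::UNDEFINED;
  Place place = Place::CPU;
  bool fake_initialized = false;
};

// What each op declares at registration time (or via InferMeta).
struct OpDecl {
  std::string type;
  DataType out_dtype;
  Place out_place;
};

// The "static build" pass: no kernels run, only output metadata is filled,
// so later decisions (data transform, scheduling) can be made ahead of time.
void FakeInitializeOutputs(const std::vector<OpDecl>& program,
                           std::vector<FakeTensor>* outs) {
  outs->assign(program.size(), FakeTensor{});
  for (size_t i = 0; i < program.size(); ++i) {
    (*outs)[i].dtype = program[i].out_dtype;
    (*outs)[i].place = program[i].out_place;
    (*outs)[i].fake_initialized = true;  // zero-size allocation in Paddle
  }
}

int main() {
  std::vector<OpDecl> program = {{"matmul", DataType::FP32, Place::GPU},
                                 {"argmax", DataType::INT64, Place::GPU}};
  std::vector<FakeTensor> outs;
  FakeInitializeOutputs(program, &outs);
  for (size_t i = 0; i < program.size(); ++i) {
    std::cout << program[i].type << ": dtype known before the first run = "
              << static_cast<int>(outs[i].dtype) << std::endl;
  }
  return 0;
}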
set(INTERPRETER_SRCS data_transfer.cc dependency_builder.cc execution_config.cc
interpreter_util.cc stream_analyzer.cc)
interpreter_util.cc static_build.cc stream_analyzer.cc)
set(INTERPRETER_DEPS
buffered_reader
device_context
global_utils
op_registry
phi_tensor_utils
scope
framework_proto
data_feed_proto
......
......@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/phi/core/kernel_context.h"
#include "paddle/phi/core/kernel_factory.h"
......@@ -37,7 +38,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope,
bool is_fetch_v2,
bool skip_run) {
bool static_build) {
bool is_transferred = false;
auto* src_var_name = &var_name;
......@@ -52,7 +53,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
is_fetch_v2);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
op, *src_var_name, *new_var_name, op_func_nodes, static_build);
}
// update src_var_name
src_var_name = new_var_name;
......@@ -70,7 +71,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
op, *src_var_name, *new_var_name, op_func_nodes, static_build);
}
// update src_var_name
src_var_name = new_var_name;
......@@ -87,7 +88,7 @@ bool DataTranferHelper::apply(const phi::KernelKey& kernel_type_for_var,
*src_var_name, new_var_name, src_place, dst_place, var_scope_, scope_);
if (op) {
RunAndConstructOpFuncNode(
op, *src_var_name, *new_var_name, op_func_nodes, skip_run);
op, *src_var_name, *new_var_name, op_func_nodes, static_build);
}
is_transferred = true;
}
......@@ -98,7 +99,7 @@ void DataTranferHelper::RunAndConstructShareNode(
const std::string& src_var_name,
const std::string& dst_var_name,
std::vector<OpFuncNode>* op_func_nodes,
bool skip_run) {
bool static_build) {
VariableNameMap in_name_map = {{"X", {src_var_name}}};
VariableNameMap out_name_map = {{"Out", {dst_var_name}}};
AttributeMap attr_map;
......@@ -112,7 +113,7 @@ void DataTranferHelper::RunAndConstructShareNode(
"Insert %s with %s -> %s.", op_type, src_var_name, dst_var_name);
RunAndConstructOpFuncNode(
op, src_var_name, dst_var_name, op_func_nodes, skip_run);
op, src_var_name, dst_var_name, op_func_nodes, static_build);
}
void DataTranferHelper::RunAndConstructOpFuncNode(
......@@ -120,15 +121,18 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
const std::string& var_name,
const std::string& new_var_name,
std::vector<OpFuncNode>* new_op_func_nodes,
bool skip_run) {
bool static_build) {
auto& op_type = op->Type();
// 1. Construct RuntimeContext
RuntimeContext runtime_context({}, {});
runtime_context.inputs["X"] = {scope_->FindVar(var_name)};
runtime_context.outputs["Out"] = {scope_->Var(new_var_name)};
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op.get()->Info().infer_shape_(&infer_shape_ctx);
if (!static_build) {
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
op->Info().infer_shape_(&infer_shape_ctx);
}
// 2. choose kernel
......@@ -203,8 +207,9 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
} else {
new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
if (skip_run) {
if (static_build) {
FakeInitializeOutputsForFunctionKernel(
*op,
*(new_op_func_node.phi_kernel_),
*(op_with_kernel->PhiKernelSignature()),
runtime_context,
......@@ -449,7 +454,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
OpFuncNode* op_func_node,
std::vector<OpFuncNode>* new_op_func_nodes,
bool use_local_scope,
bool skip_run) {
bool static_build) {
Scope* local_scope = use_local_scope ? var_scope->GetMutableLocalScope()
: var_scope->GetMutableScope();
......@@ -546,7 +551,11 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
op_base->Type() == "fetch_v2");
if (op) {
data_transfer_helper.RunAndConstructOpFuncNode(
op, var_name, new_var_name, new_op_func_nodes, skip_run);
op,
var_name,
new_var_name,
new_op_func_nodes,
static_build);
}
is_transferred = true;
} else {
......@@ -611,7 +620,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
new_op_func_nodes,
use_local_scope,
op_base->Type() == "fetch_v2",
skip_run);
static_build);
}
if (is_transferred) {
......@@ -741,7 +750,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope,
bool skip_run) {
bool static_build) {
DataTranferHelper data_transfer_helper(place, var_scope, local_scope);
for (auto& var_name_item : out_names) {
std::vector<Variable*>& vars = out_vars->at(var_name_item.first);
......@@ -817,9 +826,9 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
auto op = TransferDtype(
var_name, &new_var_name, src_type, dst_type, var_scope, local_scope);
data_transfer_helper.RunAndConstructOpFuncNode(
op, var_name, new_var_name, op_func_nodes, skip_run);
op, var_name, new_var_name, op_func_nodes, static_build);
data_transfer_helper.RunAndConstructShareNode(
new_var_name, var_name, op_func_nodes, skip_run);
new_var_name, var_name, op_func_nodes, static_build);
}
}
}
......
......@@ -61,10 +61,9 @@ const std::string StringizeDownstreamMap(
const std::map<size_t, std::set<size_t>>& DependencyBuilder::Build(
const std::vector<Instruction>& instructions) {
PADDLE_ENFORCE_EQ(
is_build_,
false,
phi::errors::AlreadyExists("The op dependency has been built"));
if (is_build_) {
return op_downstream_map_;
}
instructions_ = &instructions;
op_num_ = instructions_->size();
......
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
......@@ -48,34 +49,6 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
// These ops need to set the output dtype when registering their phi kernels, but they don't
static std::set<std::string> OpsNeedSetOutputDtypeWhenRegisterPhiKernel = {
"abs",
"adam",
"adamw",
"any_raw",
"eig_grad",
"eigh",
"lamb",
"layer_norm",
"layer_norm_grad",
"less_equal",
"less_than",
"merged_adam",
"sync_batch_norm_grad",
"unique",
"unique_consecutive_flattened_tensor",
"unique_raw"};
// These Ops can use InferMeta to infer the output dtype
static std::set<std::string> OpsWithAvailablePhiInferMeta = {
"abs", "adam", "adamw", "layer_norm", "layer_norm_grad", "merged_adam"};
// Cannot statically analyze these ops' output dtype or backend because their
// kernels have not been moved to PHI yet.
static std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
"fused_batch_norm_act", "fused_batch_norm_act_grad"};
// NOTE(Ruibiao): SingleStreamGuard makes some multi-stream ops (e.g.,
// c_allreduce_sum) run in a single stream. It is dedicated to BuildOpFuncList,
// which runs kernels without stream synchronization.
......@@ -145,48 +118,6 @@ void AsyncWorkQueue::AddTask(const OpFuncType& op_func_type,
queue_group_->AddTask(op_func_type == OpFuncType::kGpuAsync, std::move(fn));
}
bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
// has_fluid_kernel = (kernelCode >> 3) & 1
// has_structured_kernel = (kernelCode >> 2) & 1
// need_move_to_phi = (kernelCode >> 1) & 1
// need_set_dtype = kernelCode & 1
using KernelCode = int8_t;
std::set<std::pair<std::string, KernelCode>> invalid_ops;
for (auto& op : block.AllOps()) {
auto op_type = op->Type();
bool has_fluid_kernel = OperatorWithKernel::AllOpKernels().count(op_type);
bool has_structured_kernel =
phi::KernelFactory::Instance().HasStructuredKernel(op_type);
bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel) &&
OpsWithFluidKernelNeedMoveToPhi.count(op_type);
bool need_set_dtype =
!has_fluid_kernel && !has_structured_kernel &&
OpsNeedSetOutputDtypeWhenRegisterPhiKernel.count(op_type) &&
!OpsWithAvailablePhiInferMeta.count(op_type);
KernelCode kernel_code = (has_fluid_kernel << 3) +
(has_structured_kernel << 2) +
(need_move_to_phi << 1) + need_set_dtype;
if (need_move_to_phi || need_set_dtype) {
invalid_ops.insert(std::make_pair(op_type, kernel_code));
}
}
if (!invalid_ops.empty()) {
std::stringstream ss;
ss << "The following OPs are unable to static build:\n";
for (auto& item : invalid_ops) {
ss << item.first << " [has_fluid_kernel = " << (item.second >> 3 & 1)
<< ", has_structed_kerenl = " << (item.second >> 2 & 1)
<< ", need_move_to_phi = " << (item.second >> 1 & 1)
<< ", need_set_dtype = " << (item.second & 1) << "]\n";
}
VLOG(0) << ss.str();
}
return invalid_ops.empty();
}
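// Worked example of the encoding above (added for clarity, not commit code):
// an op with a fluid kernel, no structured kernel, that needs to move to phi
// and has no dtype problem encodes as
//   kernel_code = (1 << 3) + (0 << 2) + (1 << 1) + 0 = 0b1010 = 10,
// and the VLOG in BlockCanBeStaticBuilt decodes it with the same shifts:
//   (10 >> 3) & 1 == 1  // has_fluid_kernel
//   (10 >> 2) & 1 == 0  // has_structured_kernel
//   (10 >> 1) & 1 == 1  // need_move_to_phi
//   (10 >> 0) & 1 == 0  // need_set_dtype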
bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = {
"send",
......@@ -492,17 +423,25 @@ void ApplyDeviceGuard(const OperatorBase* op_base,
}
void HandleOperatorBase(const platform::Place& place,
const VariableScope* var_scope,
std::shared_ptr<OperatorBase> op_base,
std::shared_ptr<OperatorBase> op,
OpFuncNode* op_func_node,
Scope* local_scope) {
Scope* scope,
bool static_build) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
// inputs and outputs are prepared; set the other attributes.
op_func_node->operator_base_ = op_base;
op_func_node->operator_base_ = op;
op_func_node->type_ = AnalyseOpFuncType(*op_func_node, place);
op_func_node->kernel_func_ = nullptr;
op_base->Run(*local_scope, place); // Run without data transformer.
if (static_build) {
if (OperatorBasesMustRunInStaticBuild.count(op->Type())) {
op->Run(*scope, place);
}
FakeInitializeOutputsForOperatorBase(*op, place, scope);
} else {
op->Run(*scope, place); // Run without data transformer.
}
op_func_node->dev_ctx_ = dev_ctx;
}
......@@ -636,7 +575,7 @@ void BuildOpFuncList(const platform::Place& place,
VLOG(4) << "HandleOperatorBase";
// op is not an OperatorWithKernel, so directly run OperatorBase::Run()
HandleOperatorBase(
place, var_scope, ops[i], &op_func_node, local_scope);
place, ops[i], &op_func_node, local_scope, static_build);
vec_func_list->emplace_back(op_func_node);
} else {
VLOG(4) << "OP is not null";
......@@ -754,15 +693,18 @@ void BuildOpFuncList(const platform::Place& place,
use_local_scope,
static_build);
VLOG(4) << "apply data transform done. ";
// step 4. infershape, see OperatorWithKernel::RunImpl in operator.cc
// for why.
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
// step 4. infershape
if (!static_build) {
VLOG(4) << "infer shape";
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
// see kAllKernelsMustComputeRuntimeShape in operator.h for why
if (!(op->HasAttr(kAllKernelsMustComputeRuntimeShape) &&
op->Attr<bool>(kAllKernelsMustComputeRuntimeShape))) {
RuntimeInferShapeContext infer_shape_ctx(*op, runtime_context);
// TODO(Aurelius84): In case of control flow ops, they are NOT
// inherited from OperatorWithKernel.
op_with_kernel->Info().infer_shape_(&infer_shape_ctx);
}
}
// step 5. run kernel
......@@ -772,6 +714,7 @@ void BuildOpFuncList(const platform::Place& place,
VLOG(6) << op_type << " run function kernel";
if (static_build) {
FakeInitializeOutputsForFunctionKernel(
*op,
*(op_func_node.phi_kernel_),
*(op_with_kernel->PhiKernelSignature()),
runtime_context,
......@@ -826,7 +769,27 @@ void BuildOpFuncList(const platform::Place& place,
auto* original_tensor =
GetMutableLoDTensorOrSelectedRowsValueFromVar(
local_scope->FindVar(var_scope->GetNameById(p.second)));
original_tensor->ShareDataWith(*transformed_tensor);
// avoid overwriting valid data
if (static_build && original_tensor->initialized()) {
const phi::Place& target_place = transformed_tensor->place();
platform::DeviceContext* dev_ctx_for_copy;
if (target_place.GetType() != AllocationType::CPU) {
dev_ctx_for_copy = pool.Get(target_place);
} else {
dev_ctx_for_copy = pool.Get(original_tensor->place());
}
phi::Copy(*dev_ctx_for_copy,
*original_tensor,
target_place,
/*blocking=*/true,
original_tensor);
original_tensor->set_type(transformed_tensor->dtype());
original_tensor->set_layout(transformed_tensor->layout());
} else {
original_tensor->ShareDataWith(*transformed_tensor);
}
VLOG(4) << "Transfer inplace variable back form "
<< var_scope->GetNameById(p.first) << " to "
<< var_scope->GetNameById(p.second);
......@@ -866,32 +829,35 @@ void BuildOpFuncList(const platform::Place& place,
VLOG(4) << "End run " << place << " "
<< op_func_node.operator_base_->DebugStringEx(local_scope);
// gc---------------------------------------------
auto iter = unused_var_map.find(op);
if (iter == unused_var_map.end()) {
interpreter::LogDeviceMemoryStats(place);
continue;
}
auto& delete_vars = iter->second;
std::deque<std::shared_ptr<memory::Allocation>>* garbages =
new std::deque<std::shared_ptr<memory::Allocation>>();
for (auto& var_name : delete_vars) {
auto* var = local_scope->FindVar(var_name);
if (var == nullptr || skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
if (!static_build) {
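// NOTE (added for clarity; the rationale is an assumption, not from the
// commit): when static_build is true this whole GC block is skipped -- the
// kernels above were not actually run, and outputs only hold fake-initialized
// zero-size allocations, so there is nothing real to free yet.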
// gc---------------------------------------------
auto iter = unused_var_map.find(op);
if (iter == unused_var_map.end()) {
interpreter::LogDeviceMemoryStats(place);
continue;
}
VLOG(6) << "Erase variable " << var_name;
if (var->IsType<phi::DenseTensor>()) {
garbages->emplace_back(
var->GetMutable<phi::DenseTensor>()->MoveMemoryHolder());
auto& delete_vars = iter->second;
std::deque<std::shared_ptr<memory::Allocation>>* garbages =
new std::deque<std::shared_ptr<memory::Allocation>>();
for (auto& var_name : delete_vars) {
auto* var = local_scope->FindVar(var_name);
if (var == nullptr ||
skip_gc_vars.find(var_name) != skip_gc_vars.end()) {
continue;
}
VLOG(6) << "Erase variable " << var_name;
if (var->IsType<phi::DenseTensor>()) {
garbages->emplace_back(
var->GetMutable<phi::DenseTensor>()->MoveMemoryHolder());
}
}
}
delete garbages; // free mem
delete garbages; // free mem
interpreter::LogDeviceMemoryStats(place);
interpreter::LogDeviceMemoryStats(place);
}
}
}
......@@ -942,160 +908,6 @@ void BuildVariableScope(const framework::BlockDesc& block,
}
}
phi::TensorBase* GetTensorFormVar(framework::Variable* var) {
if (var) {
if (var->template IsType<phi::DenseTensor>()) {
return var->template GetMutable<phi::DenseTensor>();
} else if (var->template IsType<phi::SelectedRows>()) {
return var->template GetMutable<phi::SelectedRows>();
} else if (var->template IsType<phi::SparseCooTensor>()) {
return var->template GetMutable<phi::SparseCooTensor>();
} else if (var->template IsType<framework::LoDTensorArray>()) {
return var->template GetMutable<framework::LoDTensorArray>();
} else if (var->template IsType<framework::Strings>()) {
return var->template GetMutable<framework::Strings>();
} else if (var->template IsType<paddle::framework::RawTensor>()) {
return var->template GetMutable<paddle::framework::RawTensor>();
} else if (!var->IsInitialized()) {
// The following is for RAW type of var
return var->template GetMutable<paddle::framework::RawTensor>();
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type when get tensor.",
framework::ToTypeName(var->Type())));
}
} else {
VLOG(4) << "Var is nullptr";
return nullptr;
}
}
void FakeInitializeTensor(const platform::DeviceContext& dev_ctx,
const phi::DataType& dtype,
const phi::Place& place,
phi::TensorBase* tensor) {
PADDLE_ENFORCE_NOT_NULL(
tensor,
phi::errors::InvalidArgument(
"The tensor to fake intialize should not be null."));
if (place == phi::CPUPlace()) {
dev_ctx.HostAlloc(tensor,
dtype,
/*requested_size=*/0,
/*fake_alloc=*/true);
} else {
PADDLE_ENFORCE_EQ(
place,
dev_ctx.GetPlace(),
phi::errors::Unavailable("The place %s for fack alloc is not equal to "
"the place %s of DeviceContext.",
place,
dev_ctx.GetPlace()));
dev_ctx.Alloc(tensor,
dtype,
/*requested_size=*/0,
/*pinned=*/false,
/*fake_alloc=*/true);
}
}
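// NOTE (added for clarity, based on the DeviceContext changes later in this
// diff): with fake_alloc=true and requested_size=0, Alloc/HostAlloc route to
// the zero-size allocator, so the tensor ends up with dtype and place
// metadata but no real backing memory.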
void FakeInitializeOutputsForFunctionKernel(
const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_sig,
const RuntimeContext& ctx,
const platform::DeviceContext& dev_ctx) {
std::string op_name = std::string(kernel_sig.name);
if (OpsNeedSetOutputDtypeWhenRegisterPhiKernel.count(op_name)) {
PADDLE_ENFORCE_GT(
OpsWithAvailablePhiInferMeta.count(op_name),
0,
phi::errors::Unavailable(
"Cannot static build for op %s because it did not set output dtype "
"in phi kernel register. Please set its output dtype and remove it "
"from OpsNeedSetOutputDtypeWhenRegisterPhiKernel set, or add it to "
" OpsWithAvailablePhiInferMeta set if its InferMeta is available.",
op_name));
}
auto output_names = kernel_sig.output_names;
auto output_defs = phi_kernel.args_def().output_defs();
PADDLE_ENFORCE_EQ(output_names.size(),
output_defs.size(),
platform::errors::InvalidArgument(
"The size of outputs_args names (%d) must be equal to "
"the size of kernel output_defs (%d).",
output_names.size(),
output_defs.size()));
size_t start_idx = 0;
for (size_t i = 0; i < output_names.size(); ++i) {
auto it = ctx.outputs.find(output_names[i]);
// Deal with the case that some outputs are not found or are NULL when
// running the kernel. For example, the outputs of matmul_grad are dx and dy;
// sometimes dx or dy may be NULL.
if (it == ctx.outputs.end() || it->second.empty()) {
VLOG(4) << "Output " << output_names[i] << " not found";
++start_idx;
continue;
}
auto& outs_vector = it->second;
for (size_t offset = 0; offset < outs_vector.size(); ++offset) {
phi::TensorBase* out_tensor = GetTensorFormVar(outs_vector[offset]);
if (out_tensor && !out_tensor->initialized()) {
phi::TensorArgDef& tensor_arg_def = output_defs[start_idx + offset];
phi::DataType dtype = tensor_arg_def.dtype;
phi::Place place = tensor_arg_def.backend == phi::Backend::CUSTOM
? dev_ctx.GetPlace()
: phi::TransToPhiPlace(tensor_arg_def.backend);
if (dtype == DataType::UNDEFINED ||
OpsNeedSetOutputDtypeWhenRegisterPhiKernel.count(
std::string(kernel_sig.name))) {
VLOG(4) << "Get dtype result from InferMeta";
dtype = out_tensor->dtype(); // dtype from InferMeta
}
VLOG(4) << output_names[i] << " fake alloc with type " << dtype
<< " on place " << place << " " << out_tensor;
FakeInitializeTensor(dev_ctx, dtype, place, out_tensor);
}
}
start_idx += outs_vector.size();
}
}
void FakeInitializeOutputsForStructureKernel(
const framework::OpKernelType& op_kernel_type,
ExecutionContext* execution_context) {
const std::string& op_type = execution_context->Type();
if (op_type == "fetch_v2") {
return;
}
const VariableNameMap& outputs = execution_context->GetOp().Outputs();
for (auto& item : outputs) {
const std::string& parameter_name = item.first;
auto multi_output_var = execution_context->MultiOutputVar(parameter_name);
for (Variable* var : multi_output_var) {
phi::TensorBase* out_tensor = GetTensorFormVar(var);
if (out_tensor && !out_tensor->initialized()) {
phi::DataType dtype =
phi::TransToPhiDataType(op_kernel_type.data_type_);
phi::Place place = execution_context->GetPlace();
VLOG(4) << parameter_name << " fake alloc with type " << dtype
<< " on place " << place << " " << out_tensor;
FakeInitializeTensor(
execution_context->device_context(), dtype, place, out_tensor);
}
}
}
}
void LogDeviceMemoryStats(const platform::Place& place) {
if (FLAGS_new_executor_log_memory_stats && platform::is_gpu_place(place)) {
VLOG(0) << "memory_allocated: "
......
......@@ -65,8 +65,6 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_;
};
bool BlockCanBeStaticBuilt(const framework::BlockDesc& block);
bool IsCommunicationOp(const std::string& op_name);
bool IsCommunicationOp(const Instruction& instr);
......@@ -99,16 +97,6 @@ void BuildVariableScope(const framework::BlockDesc& block,
const ExecutionConfig& execution_config,
VariableScope* var_scope);
void FakeInitializeOutputsForFunctionKernel(
const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_sig,
const RuntimeContext& ctx,
const platform::DeviceContext& dev_ctx);
void FakeInitializeOutputsForStructureKernel(
const framework::OpKernelType& op_kernel_type,
ExecutionContext* execution_context);
void LogDeviceMemoryStats(const platform::Place& place);
void SetDeviceCommContext(framework::OperatorBase* operator_base,
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
extern std::set<std::string> OperatorBasesMustRunInStaticBuild;
namespace paddle {
namespace framework {
namespace interpreter {
bool BlockCanBeStaticBuilt(const framework::BlockDesc& block);
void FakeInitializeOutputsForOperatorBase(const OperatorBase& op,
const platform::Place& place,
Scope* scope);
void FakeInitializeOutputsForFunctionKernel(
const framework::OperatorBase& op,
const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_sig,
const RuntimeContext& ctx,
const platform::DeviceContext& dev_ctx);
void FakeInitializeOutputsForStructureKernel(
const framework::OpKernelType& op_kernel_type,
ExecutionContext* execution_context);
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -21,6 +21,7 @@
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/static_build.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/os_info.h"
......@@ -112,6 +113,8 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
static_build_ = FLAGS_new_executor_static_build &&
!FLAGS_new_executor_use_cuda_graph &&
!execution_config.used_for_control_flow_op &&
interpreter::BlockCanBeStaticBuilt(block);
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
......@@ -281,12 +284,12 @@ paddle::framework::FetchList InterpreterCore::Run(
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
Convert(&op_func_nodes);
is_build_ = true;
UpdateSyncOpNum();
if (static_build_) {
VLOG(4) << "RUN impl";
RunImpl();
}
is_build_ = true;
} else {
RunImpl();
}
......@@ -597,7 +600,7 @@ void InterpreterCore::BuildOperatorDependences() {
// analyze the dependencies between ops, add next_instr_list to each instr,
// and set the dependecy_count_
size_t instr_num = vec_instruction_.size();
dependecy_count_.resize(instr_num);
dependecy_count_ = std::vector<size_t>(instr_num, 0);
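// NOTE (added for clarity; the rationale is an assumption): assigning a fresh
// zero-filled vector instead of resize() resets the counts when dependences
// are rebuilt; resize() would keep stale values in the existing elements.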
auto downstream_map = dependency_builder_.Build(vec_instruction_);
for (size_t instr_id = 0; instr_id < instr_num; ++instr_id) {
......@@ -657,6 +660,7 @@ void InterpreterCore::Convert(
auto& vec_meta_info = var_scope_.MutableVecMetaInfo();
auto nodes = *op_func_nodes;
auto op_nums = nodes.size();
vec_instruction_.clear();
vec_instruction_.reserve(op_nums);
for (size_t op_idx = 0; op_idx < op_nums; ++op_idx) {
auto& op_func_node = nodes[op_idx];
......@@ -825,8 +829,6 @@ void InterpreterCore::Convert(
BuildAndCacheInstructionCtx(&vec_instruction_[i]);
}
BuildSkipShareLoDInfo();
bool inplaced = false;
for (const Instruction& inst : vec_instruction_) {
if (inst.OpBase()->Type() == "share_buffer" ||
......@@ -867,6 +869,10 @@ void InterpreterCore::BuildSkipShareLoDInfo() {
}
}
}
if (can_skip_lod) {
VLOG(8) << "skip share lod for: " << vec_instruction_[i].OpBase()->Type()
<< " (" << i << ")";
}
vec_instruction_[i].InnerInferShapeContext()->SetSkipLoD(can_skip_lod);
}
}
......@@ -1060,6 +1066,7 @@ void InterpreterCore::ExecuteInstructionList(
// EOF is not a fatal error.
if (exception_holder_.Type() != "EOF") {
async_work_queue_->Cancel();
async_work_queue_.reset();
}
VLOG(4) << "Cancel ok";
PADDLE_ENFORCE_EQ(
......@@ -1297,11 +1304,12 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
// convert vec func_list to graph
Convert(&op_func_nodes);
UpdateSyncOpNum();
is_build_ = true;
if (static_build_) {
VLOG(4) << "RUN impl";
RunImpl();
}
BuildSkipShareLoDInfo();
is_build_ = true;
}
// NOTE: Because feed_tensor will be GC after
// paddle::framework::BuildOpFuncList, so we should
......
......@@ -2753,13 +2753,6 @@ void OperatorWithKernel::ParseInputDataType(
t = &(var->Get<phi::SelectedRows>().value());
} else if (var->IsType<phi::SparseCooTensor>()) {
const phi::SparseCooTensor* sp_t = &(var->Get<phi::SparseCooTensor>());
PADDLE_ENFORCE_EQ(
sp_t->initialized(),
true,
platform::errors::InvalidArgument("The %s Op's Input Variable `%s` "
"contains uninitialized Tensor.",
Type(),
name));
*data_type = paddle::framework::TransToProtoVarType(sp_t->dtype());
return;
} else if (var->IsType<LoDTensorArray>()) {
......
......@@ -37,8 +37,23 @@ REGISTER_OP_WITHOUT_GRADIENT(c_sync_calc_stream,
ops::CSyncCalcStreamOp,
ops::CSyncCalcStreamOpMaker);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
REGISTER_OP_NPU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
REGISTER_OP_MLU_KERNEL(c_sync_calc_stream, ops::CSyncCalcStreamKernel<float>);
REGISTER_OP_CUDA_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float>,
ops::CSyncCalcStreamKernel<double>,
ops::CSyncCalcStreamKernel<int>,
ops::CSyncCalcStreamKernel<int64_t>,
ops::CSyncCalcStreamKernel<paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float>,
ops::CSyncCalcStreamKernel<double>,
ops::CSyncCalcStreamKernel<int>,
ops::CSyncCalcStreamKernel<int64_t>,
ops::CSyncCalcStreamKernel<paddle::platform::float16>);
REGISTER_OP_MLU_KERNEL(c_sync_calc_stream,
ops::CSyncCalcStreamKernel<float>,
ops::CSyncCalcStreamKernel<double>,
ops::CSyncCalcStreamKernel<int>,
ops::CSyncCalcStreamKernel<int64_t>,
ops::CSyncCalcStreamKernel<paddle::platform::float16>);
......@@ -30,7 +30,8 @@ class CSyncCalcStreamOp : public framework::OperatorWithKernel {
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
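// (Added for clarity) The expected kernel key now follows the dtype of
// input "X" instead of a hard-coded FP32, which is why the kernel is
// registered above for float/double/int/int64_t/float16.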
return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.GetPlace());
}
};
......
......@@ -41,7 +41,7 @@ class ConcatOp : public framework::OperatorWithKernel {
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
if (input->IsInitialized()) {
input_data_type = framework::TransToProtoVarType(input->dtype());
flag = 1;
break;
......
......@@ -95,7 +95,7 @@ class PartialConcatOp : public framework::OperatorWithKernel {
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
if (input->IsInitialized()) {
input_data_type = framework::TransToProtoVarType(input->dtype());
flag = 1;
break;
......
......@@ -97,7 +97,7 @@ class PartialSumOp : public framework::OperatorWithKernel {
auto input_data_type = framework::proto::VarType::Type(0);
bool flag = 0;
for (auto *input : inputs) {
if (input->IsInitialized() && input->numel() > 0) {
if (input->IsInitialized()) {
input_data_type = framework::TransToProtoVarType(input->dtype());
flag = 1;
break;
......
......@@ -55,6 +55,8 @@ class BufferedReader : public framework::DecoratedReader {
~BufferedReader() override;
platform::Place GetPlace() const { return place_; }
private:
void ReadTillBufferFullAsync();
......
......@@ -54,7 +54,7 @@ class SumOp : public framework::OperatorWithKernel {
x_vars_name[idx]));
auto tensor =
framework::GetLoDTensorOrSelectedRowsValueFromVar(*x_vars[idx]);
if (tensor->numel() <= 0 || (!tensor->IsInitialized())) {
if (!tensor->IsInitialized()) {
continue;
}
if (dtype == -1) {
......
......@@ -57,9 +57,9 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
}
auto place =
in_tensor->IsInitialized() ? in_tensor->place() : platform::CPUPlace();
// dtype is not important
return phi::KernelKey(framework::proto::VarType::FP32, place);
phi::DataType dtype = in_tensor->IsInitialized() ? in_tensor->dtype()
: phi::DataType::FLOAT32;
return phi::KernelKey(phi::TransToProtoVarType(dtype), place);
}
phi::KernelKey GetKernelTypeForVar(
......
......@@ -146,15 +146,26 @@ struct DeviceContext::Impl {
// NOTE(paddle-dev): In case the tensor already holds an allocation and
// is going to allocate on a new place, we will clear its holder
// first and then re-alloc it.
if (tensor->initialized() && tensor->place() != place) {
ClearHolder(tensor);
if (phi::DenseTensor::classof(tensor)) {
// NOTE(Ruibiao): A tensor holding a zero-size allocation is not regarded as
// `initialized`. Fix other tensor classes when needed.
if (static_cast<phi::DenseTensor*>(tensor)->Holder() &&
tensor->place() != place) {
ClearHolder(tensor);
}
} else {
if (tensor->initialized() && tensor->place() != place) {
ClearHolder(tensor);
}
}
auto* allocator =
(tensor->numel() == 0 || fake_alloc) && requested_size == 0
(fake_alloc || tensor->numel() == 0) && requested_size == 0
? zero_allocator_
: (pinned ? pinned_allocator_ : device_allocator_);
#ifdef PADDLE_WITH_CUDA
bool must_cuda_graph_allocator = (tensor->numel() != 0) && !pinned;
bool must_cuda_graph_allocator =
(!fake_alloc && tensor->numel() != 0) && !pinned;
if (must_cuda_graph_allocator &&
place.GetType() == phi::AllocationType::GPU &&
phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
......@@ -189,11 +200,22 @@ struct DeviceContext::Impl {
if (dtype == DataType::UNDEFINED) {
dtype = tensor->dtype();
}
if (tensor->initialized() && tensor->place() != CPUPlace()) {
ClearHolder(tensor);
if (phi::DenseTensor::classof(tensor)) {
// NOTE(Ruibiao): A tensor holding a zero-size allocation is not regarded as
// `initialized`. Fix other tensor classes when needed.
if (static_cast<phi::DenseTensor*>(tensor)->Holder() &&
tensor->place() != CPUPlace()) {
ClearHolder(tensor);
}
} else {
if (tensor->initialized() && tensor->place() != CPUPlace()) {
ClearHolder(tensor);
}
}
auto* allocator =
(tensor->numel() == 0 || fake_alloc) && requested_size == 0
(fake_alloc || tensor->numel() == 0) && requested_size == 0
? host_zero_allocator_
: host_allocator_;
return tensor->AllocateFrom(
......@@ -246,8 +268,6 @@ struct DeviceContext::Impl {
private:
void ClearHolder(TensorBase* tensor) const {
if (!tensor->initialized()) return;
if (DenseTensor::classof(tensor)) {
static_cast<DenseTensor*>(tensor)->clear();
} else if (SelectedRows::classof(tensor)) {
......
......@@ -139,10 +139,14 @@ class SelectedRows : public TensorBase,
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return impl_->dtype(); }
void set_type(const DataType dtype) { impl_->set_type(dtype); }
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return impl_->layout(); }
void set_layout(const DataLayout layout) { impl_->set_layout(layout); }
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override { return impl_->place(); };
......
......@@ -159,10 +159,14 @@ class SelectedRowsImpl {
/// \return The data type of the tensor.
DataType dtype() const noexcept { return value_->dtype(); }
void set_type(const DataType dtype) { value_->set_type(dtype); }
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept { return value_->layout(); }
void set_layout(const DataLayout layout) { value_->set_layout(layout); }
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const { return value_->place(); }
......
......@@ -104,11 +104,14 @@ class SparseCooTensor : public TensorBase,
/// \brief Returns the data type of the tensor.
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return meta_.dtype; }
void set_type(const DataType dtype) { meta_.dtype = dtype; }
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return meta_.layout; }
void set_layout(const DataLayout layout) { meta_.layout = layout; }
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override { return non_zero_elements_.place(); }
......
......@@ -110,10 +110,14 @@ class SparseCsrTensor : public TensorBase,
/// \return The data type of the tensor.
DataType dtype() const noexcept override { return meta_.dtype; }
void set_type(const DataType dtype) { meta_.dtype = dtype; }
/// \brief Returns the data layout of the tensor.
/// \return The data layout of the tensor.
DataLayout layout() const noexcept override { return meta_.layout; }
void set_layout(const DataLayout layout) { meta_.layout = layout; }
/// \brief Returns the data place of the tensor.
/// \return The data place of the tensor.
const Place& place() const override { return non_zero_elements_.place(); }
......
......@@ -23,8 +23,12 @@ TensorArray::TensorArray(const std::vector<DenseTensor>& vec) {
/// \brief Test whether the tensors' storage in TensorArray is allocated.
/// \return Whether all tensors in TensorArray are allocated.
bool TensorArray::initialized() const {
if (tensors_.empty()) {
return false;
}
for (auto tensor : tensors_) {
if (!tensor.IsInitialized()) {
if (!tensor.initialized()) {
return false;
}
}
......@@ -42,18 +46,69 @@ const DDim& TensorArray::dims() const {
}
const Place& TensorArray::place() const {
PADDLE_THROW(errors::Unavailable("place() can't be used in TensorArray"));
return tensors_[0].place();
PADDLE_ENFORCE_NE(
tensors_.size(), 0, errors::Unavailable("TensorArray is not assigned."));
const Place& place = tensors_[0].place();
for (size_t i = 1; i < tensors_.size(); ++i) {
PADDLE_ENFORCE_EQ(
tensors_[i].place(),
place,
errors::Unavailable(
"The Place of all tensors in TensorArray must be consistent. The "
"current place is %s, but the previous place is %s.",
tensors_[i].place(),
place));
}
return place;
}
DataType TensorArray::dtype() const {
PADDLE_THROW(errors::Unavailable("dtype() can't be used in TensorArray"));
return DataType::UNDEFINED;
PADDLE_ENFORCE_NE(
tensors_.size(), 0, errors::Unavailable("TensorArray is not assigned."));
const DataType dtype = tensors_[0].dtype();
for (size_t i = 1; i < tensors_.size(); ++i) {
PADDLE_ENFORCE_EQ(
tensors_[i].dtype(),
dtype,
errors::Unavailable(
"The DataType of all tensors in TensorArray must be consistent. "
"The current dtype is %s, but the previous dtype is %s.",
tensors_[i].dtype(),
dtype));
}
return dtype;
}
void TensorArray::set_type(const DataType dtype) {
for (size_t i = 0; i < tensors_.size(); ++i) {
tensors_[i].set_type(dtype);
}
}
DataLayout TensorArray::layout() const {
PADDLE_THROW(errors::Unavailable("layout() can't be used in TensorArray"));
return DataLayout::UNDEFINED;
PADDLE_ENFORCE_NE(
tensors_.size(), 0, errors::Unavailable("TensorArray is not assigned."));
const DataLayout layout = tensors_[0].layout();
for (size_t i = 1; i < tensors_.size(); ++i) {
PADDLE_ENFORCE_EQ(
tensors_[i].layout(),
layout,
errors::Unavailable(
"The DataLayout of all tensors in TensorArray must be consistent. "
"The current layout is %s, but the previous layout is %s.",
tensors_[i].layout(),
layout));
}
return layout;
}
void TensorArray::set_layout(DataLayout layout) {
for (size_t i = 0; i < tensors_.size(); ++i) {
tensors_[i].set_layout(layout);
}
}
bool TensorArray::valid() const {
......
......@@ -63,12 +63,14 @@ class TensorArray : public TensorBase,
/// \brief This overridden function is not used in TensorArray.
const Place& place() const override;
/// \brief This overridden function is not used in TensorArray.
DataType dtype() const override;
/// \brief This overridden function is not used in TensorArray.
void set_type(const DataType dtype);
DataLayout layout() const override;
void set_layout(const DataLayout layout);
/// \brief This overridden function is not used in TensorArray.
bool valid() const override;
......
......@@ -316,6 +316,16 @@ void Copy(const Context& dev_ctx,
dst->set_dims(src.dims());
}
template <typename Context>
void Copy(const Context& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst) {
// NOTE(Ruibiao): Implement Copy() for TensorArray when needed.
PADDLE_THROW(errors::Unimplemented("Copy for TensorArray is unimplemented."));
}
template void Copy(const CPUContext& dev_ctx,
const DenseTensor& src,
Place dst_place,
......@@ -363,6 +373,18 @@ template void Copy(const DeviceContext& dev_ctx,
bool blocking,
SparseCsrTensor* dst);
template void Copy(const CPUContext& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
template void Copy(const DeviceContext& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template void Copy(const GPUContext& dev_ctx,
const DenseTensor& src,
......@@ -384,6 +406,11 @@ template void Copy(const GPUContext& dev_ctx,
Place dst_place,
bool blocking,
SparseCsrTensor* dst);
template void Copy(const GPUContext& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
#endif
#ifdef PADDLE_WITH_XPU
......@@ -392,6 +419,11 @@ template void Copy(const XPUContext& dev_ctx,
Place dst_place,
bool blocking,
DenseTensor* dst);
template void Copy(const XPUContext& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......@@ -400,6 +432,11 @@ template void Copy(const CustomContext& dev_ctx,
Place dst_place,
bool blocking,
DenseTensor* dst);
template void Copy(const CustomContext& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
#endif
#ifdef PADDLE_WITH_MKLDNN
......@@ -408,6 +445,11 @@ template void Copy(const OneDNNContext& dev_ctx,
Place dst_place,
bool blocking,
DenseTensor* dst);
template void Copy(const OneDNNContext& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
#endif
template <typename T>
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include "paddle/phi/core/selected_rows.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/core/sparse_csr_tensor.h"
#include "paddle/phi/core/tensor_array.h"
#include "paddle/phi/core/tensor_meta.h"
namespace phi {
......@@ -109,6 +110,13 @@ void Copy(const Context& dev_ctx,
bool blocking,
SparseCsrTensor* dst);
template <typename Context>
void Copy(const Context& dev_ctx,
const TensorArray& src,
Place dst_place,
bool blocking,
TensorArray* dst);
template <typename T>
void TensorFromVector(const std::vector<T>& src,
const phi::DeviceContext& ctx,
......
......@@ -79,7 +79,7 @@ void GetMemSizeAndDtype(const std::vector<const DenseTensor *> &lod_tensors,
size_of_dtype
: static_cast<size_t>(size);
const void *ptr =
lod_tensors[i]->IsInitialized() ? lod_tensors[i]->data() : nullptr;
lod_tensors[i]->initialized() ? lod_tensors[i]->data() : nullptr;
VLOG(4) << size << " " << len;
ss << "input(" << i << "-th tensor) dim:(" << lod_tensors[i]->dims() << ") "
<< " addres:" << ptr << " len: " << len << ", ";
......@@ -127,7 +127,7 @@ void CoalesceTensorKernel(const Context &dev_ctx,
output[i],
errors::InvalidArgument("The %d-th output tensor cannot be nullptr.",
i));
if (!input[i]->IsInitialized()) {
if (!input[i]->initialized()) {
has_not_init_in_vars = true;
}
}
......@@ -142,7 +142,7 @@ void CoalesceTensorKernel(const Context &dev_ctx,
for (size_t i = 0; i < input.size(); ++i) {
phi::DDim dims(concated_shapes.data() + accumulated_ranks,
concated_ranks[i]);
if (!input[i]->IsInitialized()) {
if (!input[i]->initialized()) {
PADDLE_ENFORCE_EQ(
input[i],
output[i],
......@@ -220,7 +220,7 @@ void CoalesceTensorKernel(const Context &dev_ctx,
auto sub_tensor = fused_output->Slice(static_cast<int64_t>(offset),
static_cast<int64_t>(offset + len));
// some vars may not be persistable, or persistable vars may not be initialized
if (output[i]->IsInitialized()) {
if (output[i]->initialized()) {
phi::Copy(dev_ctx, *output[i], dev_ctx.GetPlace(), false, &sub_tensor);
}
offset += use_align
......@@ -270,7 +270,9 @@ PD_REGISTER_KERNEL(coalesce_tensor,
phi::CoalesceTensorKernel,
int,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(coalesce_tensor,
......@@ -282,6 +284,7 @@ PD_REGISTER_KERNEL(coalesce_tensor,
float,
double) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#endif
......@@ -295,5 +298,6 @@ PD_REGISTER_KERNEL(coalesce_tensor,
float,
double) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
}
#endif
......@@ -46,4 +46,6 @@ PD_REGISTER_KERNEL(abs,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -97,6 +97,6 @@ PD_REGISTER_KERNEL(
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -201,7 +201,9 @@ PD_REGISTER_KERNEL(argmin,
int32_t,
int64_t,
int16_t,
uint8_t) {}
uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
PD_REGISTER_KERNEL(argmax,
CPU,
......@@ -212,4 +214,6 @@ PD_REGISTER_KERNEL(argmax,
int32_t,
int64_t,
int16_t,
uint8_t) {}
uint8_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -15,8 +15,11 @@
#include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_complex_impl.h"
PD_REGISTER_KERNEL(
as_complex, CPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {}
as_complex, CPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -53,4 +53,8 @@ PD_REGISTER_KERNEL(average_accumulates,
ALL_LAYOUT,
phi::AverageAccumulatesKernel,
float,
double) {}
double) {
kernel->OutputAt(3).SetDataType(phi::DataType::INT64);
kernel->OutputAt(4).SetDataType(phi::DataType::INT64);
kernel->OutputAt(5).SetDataType(phi::DataType::INT64);
}
......@@ -105,6 +105,7 @@ PD_REGISTER_KERNEL(eig,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
const phi::DataType& out_dtype = phi::dtype::ToComplex(kernel_key.dtype());
kernel->OutputAt(0).SetDataType(out_dtype);
kernel->OutputAt(1).SetDataType(out_dtype);
}
......@@ -258,5 +258,5 @@ PD_REGISTER_KERNEL(eigvals,
double,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/fft_grad_kernel.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_grad_kernel_impl.h"
......@@ -23,10 +24,14 @@ PD_REGISTER_KERNEL(fft_c2c_grad,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
fft_c2r_grad, CPU, ALL_LAYOUT, phi::FFTC2RGradKernel, float, double) {}
fft_c2r_grad, CPU, ALL_LAYOUT, phi::FFTC2RGradKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(fft_r2c_grad,
CPU,
ALL_LAYOUT,
phi::FFTR2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/fft_kernel.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_kernel_impl.h"
......@@ -28,8 +29,8 @@ PD_REGISTER_KERNEL(fft_c2r,
phi::FFTC2RKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(fft_r2c, CPU, ALL_LAYOUT, phi::FFTR2CKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -141,4 +141,7 @@ void LayerNormKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
layer_norm, CPU, ALL_LAYOUT, phi::LayerNormKernel, float, double) {}
layer_norm, CPU, ALL_LAYOUT, phi::LayerNormKernel, float, double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -54,4 +54,6 @@ PD_REGISTER_KERNEL(sum_grad,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -76,4 +76,6 @@ PD_REGISTER_KERNEL(abs,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -144,6 +144,6 @@ PD_REGISTER_KERNEL(accuracy,
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->InputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(2).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
}
......@@ -15,8 +15,11 @@
#include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/as_complex_impl.h"
PD_REGISTER_KERNEL(
as_complex, GPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {}
as_complex, GPU, ALL_LAYOUT, phi::AsComplexKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -97,4 +97,8 @@ PD_REGISTER_KERNEL(average_accumulates,
ALL_LAYOUT,
phi::AverageAccumulatesKernel,
float,
double) {}
double) {
kernel->OutputAt(3).SetDataType(phi::DataType::INT64);
kernel->OutputAt(4).SetDataType(phi::DataType::INT64);
kernel->OutputAt(5).SetDataType(phi::DataType::INT64);
}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/fft_grad_kernel.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_grad_kernel_impl.h"
......@@ -23,10 +24,14 @@ PD_REGISTER_KERNEL(fft_c2c_grad,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(
fft_c2r_grad, GPU, ALL_LAYOUT, phi::FFTC2RGradKernel, float, double) {}
fft_c2r_grad, GPU, ALL_LAYOUT, phi::FFTC2RGradKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(fft_r2c_grad,
GPU,
ALL_LAYOUT,
phi::FFTR2CGradKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/phi/kernels/fft_kernel.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/fft_kernel_impl.h"
......@@ -28,8 +29,8 @@ PD_REGISTER_KERNEL(fft_c2r,
phi::FFTC2RKernel,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
PD_REGISTER_KERNEL(fft_r2c, GPU, ALL_LAYOUT, phi::FFTR2CKernel, float, double) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
......@@ -117,7 +117,12 @@ PD_REGISTER_KERNEL(layer_norm_grad,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(layer_norm_grad,
GPU,
......@@ -126,7 +131,12 @@ PD_REGISTER_KERNEL(layer_norm_grad,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#else
PD_REGISTER_KERNEL(layer_norm_grad,
GPU,
......@@ -134,5 +144,10 @@ PD_REGISTER_KERNEL(layer_norm_grad,
phi::LayerNormGradKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
#endif
......@@ -673,7 +673,10 @@ PD_REGISTER_KERNEL(layer_norm,
ALL_LAYOUT,
phi::LayerNormKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
#elif CUDNN_VERSION_MIN(8, 1, 0)
PD_REGISTER_KERNEL(layer_norm,
GPU,
......@@ -682,7 +685,10 @@ PD_REGISTER_KERNEL(layer_norm,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
#else
PD_REGISTER_KERNEL(layer_norm,
GPU,
......@@ -690,5 +696,8 @@ PD_REGISTER_KERNEL(layer_norm,
phi::LayerNormKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
#endif
......@@ -22,4 +22,9 @@ PD_REGISTER_KERNEL(merged_momentum,
phi::MergedMomentumKernel,
phi::dtype::float16,
float,
double) {}
double) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
......@@ -25,8 +25,10 @@ PD_REGISTER_KERNEL(momentum,
float,
double,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
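// (Added for clarity; assumption): for a float16 kernel key, outputs 1 and 2
// are the velocity and master-parameter copies kept in FP32 by
// multi-precision training, hence the FLOAT32 override above.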
PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad,
......@@ -36,6 +38,8 @@ PD_REGISTER_KERNEL(momentum_dense_param_sparse_grad,
float,
double,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
......@@ -70,4 +70,6 @@ PD_REGISTER_KERNEL(sum_grad,
int,
int64_t,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -188,7 +188,9 @@ PD_REGISTER_KERNEL(sgd,
phi::dtype::float16,
float,
double) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
}
}
PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad,
......
......@@ -146,13 +146,17 @@ PD_REGISTER_GENERAL_KERNEL(memcpy_d2h,
CPU,
ALL_LAYOUT,
phi::MemcpyD2HKernel<phi::CPUContext>,
ALL_DTYPE) {}
ALL_DTYPE) {
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io,
CPU,
ALL_LAYOUT,
phi::MemcpyD2HMultiIOKernel<phi::CPUContext>,
ALL_DTYPE) {}
ALL_DTYPE) {
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}
PD_REGISTER_GENERAL_KERNEL(
memcpy, CPU, ALL_LAYOUT, phi::MemcpyKernel<phi::CPUContext>, ALL_DTYPE) {
......@@ -170,13 +174,17 @@ PD_REGISTER_GENERAL_KERNEL(memcpy_d2h,
GPU,
ALL_LAYOUT,
phi::MemcpyD2HKernel<phi::GPUContext>,
ALL_DTYPE) {}
ALL_DTYPE) {
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io,
GPU,
ALL_LAYOUT,
phi::MemcpyD2HMultiIOKernel<phi::GPUContext>,
ALL_DTYPE) {}
ALL_DTYPE) {
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}
PD_REGISTER_GENERAL_KERNEL(
memcpy, GPU, ALL_LAYOUT, phi::MemcpyKernel<phi::GPUContext>, ALL_DTYPE) {
......@@ -196,12 +204,16 @@ PD_REGISTER_GENERAL_KERNEL(memcpy_d2h,
XPU,
ALL_LAYOUT,
phi::MemcpyD2HKernel<phi::XPUContext>,
ALL_DTYPE) {}
ALL_DTYPE) {
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}
PD_REGISTER_GENERAL_KERNEL(memcpy_d2h_multi_io,
XPU,
ALL_LAYOUT,
phi::MemcpyD2HMultiIOKernel<phi::XPUContext>,
ALL_DTYPE) {}
ALL_DTYPE) {
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}
#endif
......@@ -42,4 +42,5 @@ void SumGradKernel(const Context& dev_ctx,
PD_REGISTER_KERNEL(
sum_grad, OneDNN, ONEDNN, phi::SumGradKernel, float, phi::dtype::bfloat16) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -129,4 +129,9 @@ PD_REGISTER_KERNEL(layer_norm_grad,
ALL_LAYOUT,
phi::LayerNormGradKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
}
}
......@@ -69,7 +69,4 @@ PD_REGISTER_KERNEL(momentum,
ALL_LAYOUT,
phi::MomentumDenseKernel,
float,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
kernel->OutputAt(2).SetDataType(phi::DataType::UNDEFINED);
}
phi::dtype::float16) {}
......@@ -71,4 +71,5 @@ void ReduceSumGradKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(sum_grad, XPU, ALL_LAYOUT, phi::ReduceSumGradKernel, float) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
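The kernel-registration hunks above all rely on the same annotation hooks: outputs whose dtype can only be decided at run time are marked UNDEFINED (and switched to FLOAT32 when an FP16 run keeps a master-precision copy), while D2H memcpy outputs are pinned to the CPU backend. A minimal consolidated sketch of those hooks follows; my_kernel and phi::MyKernel are hypothetical names used only for illustration and are not part of this PR.
// Hypothetical registration (illustration only); the hooks below are the ones
// used by the real registrations in the hunks above.
PD_REGISTER_KERNEL(my_kernel, GPU, ALL_LAYOUT, phi::MyKernel, float, phi::dtype::float16) {
  // The dtype of output 1 depends on the runtime kernel dtype, so it cannot be
  // fixed at registration time.
  kernel->OutputAt(1).SetDataType(phi::DataType::UNDEFINED);
  if (kernel_key.dtype() == phi::DataType::FLOAT16) {
    // FP16 runs keep a FLOAT32 master copy for this output.
    kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
  }
  // An output that always lands on the host (e.g. a D2H copy) is pinned to CPU.
  kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
}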
......@@ -913,6 +913,11 @@ if(WITH_DISTRIBUTE)
set_tests_properties(test_dist_fleet_raw_program_optimizer_fuse_allreduce
PROPERTIES TIMEOUT 60)
set_tests_properties(test_dist_dygraph_apis PROPERTIES TIMEOUT 120)
  # NOTE(Ruibiao): Remove it after static build is enabled by default.
set_tests_properties(
test_dist_mnist_fp16_allreduce test_dist_mnist_pg
PROPERTIES ENVIRONMENT FLAGS_new_executor_static_build=true)
endif()
# setting timeout value as 15S
......@@ -1229,3 +1234,52 @@ set_tests_properties(
set_tests_properties(
test_cuda_graph_static_mode_error
PROPERTIES ENVIRONMENT "FLAGS_CUDA_GRAPH_USE_STANDALONE_EXECUTOR=1")
# These UTs temporarily test static build for standalone_executor; they will be removed after static build is enabled by default.
set(STATIC_BUILD_TESTS
test_adagrad_op
test_adamw_op
test_arg_min_max_op
test_bincount_op
test_decoupled_py_reader
test_fake_quantize_op
test_fetch_lod_tensor_array
test_imperative_optimizer
test_lamb_op
test_layer_norm_op
test_lookup_table_bf16_op
test_lookup_table_v2_op
test_matmul_op
test_matmul_v2_op
test_merged_adam_op
test_momentum_op
test_nce
test_paddle_save_load_binary
test_reduce_op
test_segment_ops
test_sparse_momentum_op
test_sgd_op_bf16
test_softmax_mask_fuse_upper_triangle_op
test_sparse_conv_op
test_sparse_norm_op
test_sparse_pooling_op
test_tensor_array_to_tensor
test_while_op
test_one_hot_v2_op)
foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
py_test_modules(
${STATIC_BUILD_TEST}_static_build MODULES ${STATIC_BUILD_TEST} ENVS
FLAGS_new_executor_static_build=true)
endforeach()
set_tests_properties(test_decoupled_py_reader_static_build PROPERTIES TIMEOUT
120)
set_tests_properties(test_imperative_optimizer_static_build PROPERTIES TIMEOUT
250)
set_tests_properties(test_matmul_op_static_build PROPERTIES TIMEOUT 120)
set_tests_properties(test_matmul_v2_op_static_build PROPERTIES TIMEOUT 120)
set_tests_properties(test_layer_norm_op_static_build PROPERTIES TIMEOUT 1500)
set_tests_properties(test_paddle_save_load_binary_static_build
PROPERTIES TIMEOUT 120)
set_tests_properties(test_reduce_op_static_build PROPERTIES TIMEOUT 500)
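For a quick local check, the same flag that py_test_modules passes through ENVS above can simply be set on the command line. A minimal sketch, assuming the test module (here test_momentum_op, taken from the list above) can be run directly with Python:
# Assumed local invocation; mirrors FLAGS_new_executor_static_build=true set by the CMake rules above.
FLAGS_new_executor_static_build=true python test_momentum_op.py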
......@@ -24,6 +24,18 @@ py_test_modules(
test_standalone_executor_stats MODULES test_standalone_executor ENVS
FLAGS_host_trace_level=10 FLAGS_static_executor_perfstat_filepath=./perfstat)
# These UTs temporarily test static build for standalone_executor; they will be removed after static build is enabled by default.
set(STATIC_BUILD_TESTS
test_standalone_controlflow test_standalone_cuda_graph_multi_stream
test_standalone_custom_stream test_standalone_executor
test_standalone_multiply_write)
foreach(STATIC_BUILD_TEST ${STATIC_BUILD_TESTS})
py_test_modules(
${STATIC_BUILD_TEST}_static_build MODULES ${STATIC_BUILD_TEST} ENVS
FLAGS_new_executor_static_build=true)
endforeach()
set_tests_properties(test_standalone_cross_step_overlap PROPERTIES TIMEOUT 30)
set_tests_properties(test_standalone_executor_aot_choose_kernel
PROPERTIES TIMEOUT 60)
......@@ -1705,6 +1705,7 @@ class TestDistBase(unittest.TestCase):
"http_proxy": "",
"NCCL_P2P_DISABLE": "1",
"NCCL_SHM_DISABLE": "1",
"FLAGS_new_executor_static_build": "1",
}
if check_error_log:
......
......@@ -69,13 +69,13 @@ class TestSoftmaxMaskFuseOp1(OpTest):
def test_check_output(self):
try:
self.check_output_with_place(core.CPUPlace())
except NotImplementedError:
except (NotImplementedError, RuntimeError):
pass
def test_check_grad(self):
try:
self.check_grad_with_place(core.CPUPlace(), ["X"], "Out")
except NotImplementedError:
except (NotImplementedError, RuntimeError):
pass
......
......@@ -49,6 +49,9 @@ if(WITH_TESTING)
py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
py_test(test_custom_concat SRCS test_custom_concat.py)
set_tests_properties(
test_custom_concat PROPERTIES ENVIRONMENT
FLAGS_new_executor_static_build=true)
py_test(test_custom_conj SRCS test_custom_conj.py)
py_test(test_custom_linear SRCS test_custom_linear.py)
py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py)
......
......@@ -14,6 +14,14 @@ endif()
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach()
# NOTE(Ruibiao): Remove it after static build is enabled by default.
if(WITH_MKLDNN AND NOT WIN32)
py_test_modules(
test_dequantize_mkldnn_op_static_build MODULES test_dequantize_mkldnn_op
ENVS FLAGS_new_executor_static_build=true)
endif()
set_tests_properties(test_concat_mkldnn_op PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv3d_mkldnn_op PROPERTIES TIMEOUT 120)
if(WITH_MKLDNN AND NOT WIN32)
......
......@@ -306,6 +306,12 @@ if [ "${HAS_MODIFIED_SETUP}" != "" ] || ([ "${HAS_MODIFIED_SETUP_IN}" != "" ] &&
check_approval 1 risemeup1 zhangbo9674
fi
HAS_MODIFIED_STATIC_BUILD=`git diff --name-only upstream/$BRANCH | grep "new_executor/interpreter/static_build.cc" || true`
if [ "${HAS_MODIFIED_STATIC_BUILD}" != "" ] && [ "${GIT_PR_ID}" != ""]; then
echo_line="You must have one RD (From00 or zhiqiu) approval for file changes in new_executor/interpreter/static_build.cc.\n"
check_approval 1 From00 zhiqiu
fi
ALL_PADDLE_ENFORCE=`git diff -U0 upstream/$BRANCH |grep "^+" |grep -zoE "PADDLE_ENFORCE\(.[^,\);]+.[^;]*\);\s" || true`
if [ "${ALL_PADDLE_ENFORCE}" != "" ] && [ "${GIT_PR_ID}" != "" ]; then
echo_line="PADDLE_ENFORCE is not recommended. Please use PADDLE_ENFORCE_EQ/NE/GT/GE/LT/LE or PADDLE_ENFORCE_NOT_NULL or PADDLE_ENFORCE_GPU_SUCCESS instead, see [ https://github.com/PaddlePaddle/Paddle/wiki/PADDLE_ENFORCE-Rewriting-Specification ] for details.\nYou must have one RD (chenwhql (Recommend), luotao1 (Recommend) or lanxianghit) approval for the usage (either add or delete) of PADDLE_ENFORCE.\n${ALL_PADDLE_ENFORCE}\n"
......
......@@ -20,6 +20,7 @@ disable_wingpu_test="^test_model$|\
^test_add_reader_dependency$|\
^test_add_reader_dependency_for_interpretercore$|\
^test_decoupled_py_reader$|\
^test_decoupled_py_reader_static_build$|\
^test_generator_dataloader$|\
^test_parallel_dygraph_sync_batch_norm$|\
^test_py_reader_using_executor$|\
......@@ -103,6 +104,7 @@ disable_win_inference_test="^trt_quant_int8_yolov3_r50_test$|\
^test_conv3d_transpose_part2_op$|\
^test_deform_conv2d$|\
^test_matmul_op$|\
^test_matmul_op_static_build$|\
^test_basic_api_transformation$|\
^test_deformable_conv_op$|\
^test_variable$|\
......@@ -153,6 +155,7 @@ disable_win_inference_test="^trt_quant_int8_yolov3_r50_test$|\
^test_add_reader_dependency_for_interpretercore$|\
^test_compat$|\
^test_decoupled_py_reader$|\
^test_decoupled_py_reader_static_build$|\
^test_generator_dataloader$|\
^test_py_reader_using_executor$|\
^test_dataloader_keep_order$|\
......@@ -223,6 +226,7 @@ long_time_test="^test_gru_op$|\
^test_imperative_lod_tensor_to_selected_rows$|\
^test_imperative_selected_rows_to_lod_tensor$|\
^test_layer_norm_op$|\
^test_layer_norm_op_static_build$|\
^test_multiclass_nms_op$|\
^test_nearest_interp_v2_op$|\
^test_nn_grad$|\
......