未验证 提交 b0941102 编写于 作者: W wanghuancoder 提交者: GitHub

fix some bug, test=develop (#36888)

上级 093c4ec5
...@@ -241,13 +241,14 @@ void InterpreterCore::BuildInplace() { ...@@ -241,13 +241,14 @@ void InterpreterCore::BuildInplace() {
auto& outputs = instr.Outputs(); auto& outputs = instr.Outputs();
for (auto& pair : in_to_outs) { for (auto& pair : in_to_outs) {
auto iter = inputs.find(pair.first); auto iter = inputs.find(pair.first);
if (iter != inputs.end()) { if (iter != inputs.end() && !iter->second.empty()) {
if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) { if (BuildInplaceCheckVarIsOnlyInput(iter->second[0])) {
auto iterout = outputs.find(pair.second); auto iterout = outputs.find(pair.second);
if (iterout != outputs.end()) { if (iterout != outputs.end() && !iterout->second.empty()) {
auto invar = global_scope_->Var(iter->second[0]); auto invar = global_scope_->Var(iter->second[0]);
auto outvar = global_scope_->Var(iterout->second[0]); auto outvar = global_scope_->Var(iterout->second[0]);
if (invar && outvar) { if (invar && outvar && invar->IsType<LoDTensor>() &&
outvar->IsType<LoDTensor>()) {
instr.AddInplace(invar, outvar); instr.AddInplace(invar, outvar);
VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type() VLOG(3) << "inplace " << vec_instruction_[i].OpBase()->Type()
<< " " << global_scope_->GetNameById(iter->second[0]) << " " << global_scope_->GetNameById(iter->second[0])
......
...@@ -142,8 +142,8 @@ void build_variable_scope(const framework::ProgramDesc& pdesc, ...@@ -142,8 +142,8 @@ void build_variable_scope(const framework::ProgramDesc& pdesc,
if (nullptr == var_scope->FindVar(var_name)) { if (nullptr == var_scope->FindVar(var_name)) {
var_scope->AddVar(var_desc->Name(), var_desc); var_scope->AddVar(var_desc->Name(), var_desc);
} else { } else {
auto* var_desc = var_scope->VarDesc(var_name); auto* var_desc_tmp = var_scope->VarDesc(var_name);
if (nullptr == var_desc) { if (nullptr == var_desc_tmp) {
VLOG(3) << "update var:" << var_name << " desc from nullptr into " VLOG(3) << "update var:" << var_name << " desc from nullptr into "
<< var_desc; << var_desc;
var_scope->VarMetaInfo(var_name).vardesc_ = var_desc; var_scope->VarMetaInfo(var_name).vardesc_ = var_desc;
...@@ -206,9 +206,22 @@ void apply_device_guard(const OperatorBase* op_base, ...@@ -206,9 +206,22 @@ void apply_device_guard(const OperatorBase* op_base,
VLOG(3) << "Switch into CPUPlace by device_guard."; VLOG(3) << "Switch into CPUPlace by device_guard.";
expected_kernel_key->place_ = platform::CPUPlace(); expected_kernel_key->place_ = platform::CPUPlace();
} else if (op_device.find("gpu") != std::string::npos && } else if (op_device.find("gpu") != std::string::npos &&
platform::is_gpu_place(place)) { (platform::is_gpu_place(place) ||
VLOG(3) << "Switch into " << place << " by device_guard."; platform::is_npu_place(place))) {
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time.
if (op_base->SupportGPU()) {
expected_kernel_key->place_ = place; expected_kernel_key->place_ = place;
} else if (op_base->SupportNPU()) {
expected_kernel_key->place_ = place;
} else {
expected_kernel_key->place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
<< "Op(" << op_base->Type()
<< ") has no CUDA implementation. It will be assigned to CPUPlace.";
}
VLOG(3) << "Switch into " << expected_kernel_key->place_
<< " by device_guard.";
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", op_device)); platform::errors::Fatal("Unsupported current place %s", op_device));
......
...@@ -474,6 +474,15 @@ struct VariableMetaInfo { ...@@ -474,6 +474,15 @@ struct VariableMetaInfo {
// TODO(zhiqiu): Maybe we need to add rwlock for VariableScope? // TODO(zhiqiu): Maybe we need to add rwlock for VariableScope?
class VariableScope : public ScopeBase { class VariableScope : public ScopeBase {
public: public:
VariableScope() {
// for @EMPTY@ variable
var_list_.push_back(nullptr);
name2id_[kEmptyVarName] = 0;
VariableMetaInfo info;
info.var_ref_count_ = 0;
info.vardesc_ = nullptr;
vec_meta_info_.push_back(info);
}
Variable* FindVar(const std::string& name) const { Variable* FindVar(const std::string& name) const {
auto it = name2id_.find(name); auto it = name2id_.find(name);
if (it != name2id_.end()) { if (it != name2id_.end()) {
......
...@@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel { ...@@ -77,12 +77,35 @@ class FetchV2Op : public framework::OperatorWithKernel {
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor, const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
if (!tensor.IsInitialized()) {
return expected_kernel_type;
}
return framework::OpKernelType(expected_kernel_type.data_type_, return framework::OpKernelType(expected_kernel_type.data_type_,
tensor.place(), tensor.layout()); tensor.place(), tensor.layout());
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto *fetch_var = ctx.InputVar("X");
if (fetch_var == nullptr) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
if (fetch_var->IsType<framework::LoDTensor>()) {
auto &src_item = fetch_var->Get<framework::LoDTensor>();
if (!src_item.IsInitialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
} else {
auto &src_item = fetch_var->Get<framework::LoDTensorArray>();
if (src_item.empty() || !src_item[0].IsInitialized()) {
return framework::OpKernelType(framework::proto::VarType::FP32,
platform::CPUPlace());
}
}
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace()); platform::CPUPlace());
...@@ -127,6 +150,9 @@ class FetchV2Kernel { ...@@ -127,6 +150,9 @@ class FetchV2Kernel {
if (fetch_var->IsType<framework::LoDTensor>()) { if (fetch_var->IsType<framework::LoDTensor>()) {
auto &src_item = fetch_var->Get<framework::LoDTensor>(); auto &src_item = fetch_var->Get<framework::LoDTensor>();
if (!src_item.IsInitialized()) {
return;
}
auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col))); auto *dst_item = &(BOOST_GET(framework::LoDTensor, fetch_list->at(col)));
bool check_place = platform::is_cpu_place(src_item.place()) || bool check_place = platform::is_cpu_place(src_item.place()) ||
platform::is_cuda_pinned_place(src_item.place()); platform::is_cuda_pinned_place(src_item.place());
...@@ -173,9 +199,7 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -173,9 +199,7 @@ class FetchV2OpProtoMaker : public framework::OpProtoAndCheckerMaker {
.SetDefault(true); .SetDefault(true);
AddComment(R"DOC( AddComment(R"DOC(
FetchV2 Operator. FetchV2 Operator.
It should not be configured by users directly. It should not be configured by users directly.
)DOC"); )DOC");
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册