// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/eager/legacy/tensor_helper.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/platform/place.h" namespace egr { void InitializeVariable(paddle::framework::Variable *var, paddle::framework::proto::VarType::Type var_type) { if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::FEED_MINIBATCH) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::FETCH_LIST) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::STEP_SCOPES) { var->GetMutable>(); } else if (var_type == paddle::framework::proto::VarType::LOD_RANK_TABLE) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::STRINGS) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::VOCAB) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::PLACE_LIST) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::READER) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::RAW) { // GetMutable will be called in operator } else { PADDLE_THROW(paddle::platform::errors::Unavailable( "paddle::framework::Variable type %d is not in " "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " "LOD_RANK_TABLE, PLACE_LIST, READER, RAW].", var_type)); } } void CopyVariable(const paddle::framework::Variable &src_var, paddle::framework::Variable *dst_var) { // only support cpu now auto cpu_place = paddle::platform::CPUPlace(); if (src_var.IsType()) { auto *tmp_grad_tensor = dst_var->GetMutable(); auto &src_tensor = src_var.Get(); tmp_grad_tensor->set_lod(src_tensor.lod()); paddle::framework::TensorCopy(src_tensor, cpu_place, tmp_grad_tensor); } else if (src_var.IsType()) { auto &src_slr = src_var.Get(); auto *tmp_grad_slr = dst_var->GetMutable(); tmp_grad_slr->set_rows(src_slr.rows()); tmp_grad_slr->set_height(src_slr.height()); auto &src_t = src_slr.value(); auto *dst_t = tmp_grad_slr->mutable_value(); paddle::framework::TensorCopy(src_t, cpu_place, dst_t); } else { PADDLE_THROW(paddle::platform::errors::Unavailable( "Unknown variable type to copy.")); } } paddle::framework::proto::VarType::Type GetDtypeFromVar( const paddle::framework::Variable &var) { if (var.IsType()) { return var.Get().type(); } else if (var.IsType()) { return var.Get().value().type(); } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Variable type is %s, expect LoDTensor or SelectedRows.", paddle::framework::ToTypeName(var.Type()))); } } const paddle::platform::Place &GetPlaceFromVar( const paddle::framework::Variable &var) { if (var.IsType()) { return var.Get().place(); } else if (var.IsType()) { return var.Get().place(); } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Variable type is %s, expect LoDTensor or SelectedRows.", paddle::framework::ToTypeName(var.Type()))); } } } // namespace egr