// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/imperative/var_helper.h" #include "paddle/fluid/eager/eager_tensor.h" #include "paddle/fluid/framework/convert_utils.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/framework/lod_rank_table.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor_array.h" #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/var_type_traits.h" #include "paddle/fluid/imperative/layer.h" #include "paddle/fluid/platform/place.h" #include "paddle/pten/core/selected_rows.h" namespace paddle { namespace imperative { /* GetVariableWrapper: two explicit specializations (`template <>`), each returning a const reference to a std::shared_ptr. The first returns var->SharedVar(); the second returns its argument unchanged. NOTE(review): in this copy of the file the line breaks are collapsed and the template argument lists (the text between angle brackets) are missing, so the concrete type of each specialization cannot be read from here - confirm against the declaration in var_helper.h. */ template <> const std::shared_ptr &GetVariableWrapper( const std::shared_ptr &var) { return var->SharedVar(); } template <> const std::shared_ptr &GetVariableWrapper( const std::shared_ptr &var) { return var; } /* InitializeVariable: dispatches on var_type and calls var->GetMutable() for LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, STEP_SCOPES, LOD_RANK_TABLE, LOD_TENSOR_ARRAY, STRINGS, VOCAB, PLACE_LIST and READER (the type argument of each GetMutable call is not visible in this copy). RAW is accepted and deliberately left untouched (see the inline comment in that branch). Any other var_type raises platform::errors::Unavailable through PADDLE_THROW. NOTE(review): the error text enumerates fewer types than the branches handle - STEP_SCOPES, LOD_TENSOR_ARRAY, STRINGS and VOCAB are missing from the list. */ void InitializeVariable(paddle::framework::Variable *var, paddle::framework::proto::VarType::Type var_type) { if (var_type == paddle::framework::proto::VarType::LOD_TENSOR) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::SELECTED_ROWS) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::FEED_MINIBATCH) { var->GetMutable(); } else if (var_type == 
paddle::framework::proto::VarType::FETCH_LIST) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::STEP_SCOPES) { var->GetMutable>(); } else if (var_type == paddle::framework::proto::VarType::LOD_RANK_TABLE) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::LOD_TENSOR_ARRAY) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::STRINGS) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::VOCAB) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::PLACE_LIST) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::READER) { var->GetMutable(); } else if (var_type == paddle::framework::proto::VarType::RAW) { // GetMutable will be called in operator } else { PADDLE_THROW(paddle::platform::errors::Unavailable( "paddle::framework::Variable type %d is not in " "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, " "LOD_RANK_TABLE, PLACE_LIST, READER, RAW].", var_type)); } } /* GetPlace: takes a copy of var->Var() into a local Variable and probes it with IsType twice; on a match it returns the place() of the value obtained from Get, otherwise it throws platform::errors::InvalidArgument whose message names the actual type and states that LoDTensor or SelectedRows is expected. Three explicit instantiations follow the definition. NOTE(review): the returned const reference is reached through the local Variable copy - confirm that the referenced Place stays alive after that copy is destroyed. */ template const paddle::platform::Place &GetPlace(const std::shared_ptr &var) { paddle::framework::Variable variable = var->Var(); if (variable.IsType()) { return variable.Get().place(); } else if (variable.IsType()) { return variable.Get().place(); } else { PADDLE_THROW(paddle::platform::errors::InvalidArgument( "Variable type is %s, expect LoDTensor or SelectedRows.", paddle::framework::ToTypeName(var->Var().Type()))); } } template const paddle::platform::Place &GetPlace( const std::shared_ptr &var); template const paddle::platform::Place &GetPlace( const std::shared_ptr &var); template const paddle::platform::Place &GetPlace( const std::shared_ptr &var); /* GetNameFromVar: the primary template returns var->Name(); the explicit specialization (its parameter is called tensor) returns tensor->name(). Explicit instantiations of the primary template follow. */ template const std::string &GetNameFromVar(std::shared_ptr var) { return var->Name(); } template <> const std::string &GetNameFromVar( std::shared_ptr tensor) { return tensor->name(); } template const std::string &GetNameFromVar( std::shared_ptr var); 
template const std::string &GetNameFromVar( std::shared_ptr var); /* SetType: the primary template forwards to var->SetType(type). The explicit specialization switches on type instead: for LOD_TENSOR and for SELECTED_ROWS it calls var->MutableVar()->GetMutable(), and for any other value it throws platform::errors::NotFound with the type name in the message. Two explicit instantiations follow. NOTE(review): in this copy of the file the line breaks are collapsed and the template argument lists (the text between angle brackets) are missing, so the specialized type and the GetMutable type arguments cannot be read from here - confirm against var_helper.h. */ template void SetType(std::shared_ptr var, framework::proto::VarType::Type type) { var->SetType(type); } template <> void SetType(std::shared_ptr var, framework::proto::VarType::Type type) { switch (type) { case paddle::framework::proto::VarType::LOD_TENSOR: { var->MutableVar()->GetMutable(); break; } case paddle::framework::proto::VarType::SELECTED_ROWS: { var->MutableVar()->GetMutable(); break; } default: { PADDLE_THROW(paddle::platform::errors::NotFound( "Cannot found var type: %s while running runtime InferVarType", paddle::framework::ToTypeName(type))); } } } template void SetType(std::shared_ptr var, framework::proto::VarType::Type type); template void SetType(std::shared_ptr var, framework::proto::VarType::Type type); /* GetType: the primary template returns var->Type(). The explicit specialization returns ToVarType(var->Var().Type()) when var->Var().IsInitialized() is true and falls back to VarType::LOD_TENSOR otherwise. Two explicit instantiations follow. */ template framework::proto::VarType::Type GetType(std::shared_ptr var) { return var->Type(); } template <> framework::proto::VarType::Type GetType( std::shared_ptr var) { if (var->Var().IsInitialized()) { return paddle::framework::ToVarType(var->Var().Type()); } else { return paddle::framework::proto::VarType::LOD_TENSOR; } } template framework::proto::VarType::Type GetType( std::shared_ptr var); template framework::proto::VarType::Type GetType( std::shared_ptr var); /* GetDataType: the primary template returns var->DataType(). The explicit specialization probes var->Var() with IsType twice: the first branch converts Get().value().type() and the second converts Get().type(), both through framework::TransToProtoVarType; if neither matches it throws platform::errors::PermissionDenied. According to that error message the two supported holders are pten::SelectedRows and framework::LoDTensor. Two explicit instantiations follow. */ template framework::proto::VarType::Type GetDataType(std::shared_ptr var) { return var->DataType(); } template <> framework::proto::VarType::Type GetDataType( std::shared_ptr var) { if (var->Var().IsType()) { return framework::TransToProtoVarType( var->Var().Get().value().type()); } else if (var->Var().IsType()) { return framework::TransToProtoVarType( var->Var().Get().type()); } else { PADDLE_THROW(paddle::platform::errors::PermissionDenied( "We only support pten::SelectedRows and framework::LoDTensor in " "eager mode, but we got %s here, please checkout your var type of " "tensor: %s", paddle::framework::ToTypeName(framework::ToVarType(var->Var().Type())), var->name())); 
} } template framework::proto::VarType::Type GetDataType( std::shared_ptr var); template framework::proto::VarType::Type GetDataType( std::shared_ptr var); /* CheckCachedKey: the primary template returns GetVariableWrapper(var)->hasCacheKey(key). The explicit specialization ignores both arguments and always returns false; its TODO marks cache support as future work. Two explicit instantiations follow. */ template bool CheckCachedKey(std::shared_ptr var, const paddle::framework::OpKernelType &key) { return GetVariableWrapper(var)->hasCacheKey(key); } template <> bool CheckCachedKey( std::shared_ptr tensor, const paddle::framework::OpKernelType &key) { // TODO(jiabin): Support this later // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key is // equal to self: " << key == key. return false; } template bool CheckCachedKey( std::shared_ptr var, const paddle::framework::OpKernelType &key); template bool CheckCachedKey( std::shared_ptr var, const paddle::framework::OpKernelType &key); /* GetCachedValue: the primary template returns GetVariableWrapper(var)->getCacheValue(key). The explicit specialization ignores its arguments and returns a new object created by std::make_shared from an empty string; the commented-out PADDLE_THROW shows an error was considered for this path. Two explicit instantiations follow. */ template std::shared_ptr GetCachedValue( std::shared_ptr var, const paddle::framework::OpKernelType &key) { return GetVariableWrapper(var)->getCacheValue(key); } template <> std::shared_ptr GetCachedValue( std::shared_ptr var, const paddle::framework::OpKernelType &key) { // TODO(jiabin): Support this later // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // reach this, support cache and remove this error check later, or this // should not be supported.")); // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key // is equal to self: " << key == key. 
return std::make_shared(""); } template std::shared_ptr GetCachedValue( std::shared_ptr var, const paddle::framework::OpKernelType &key); template std::shared_ptr GetCachedValue( std::shared_ptr var, const paddle::framework::OpKernelType &key); /* SetCachedValue: the primary template calls GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res)). The explicit specialization has an empty body apart from commented-out code, so it does nothing. Two explicit instantiations follow. */ template void SetCachedValue(std::shared_ptr var, const paddle::framework::OpKernelType &key, std::shared_ptr res) { GetVariableWrapper(var)->setCacheValue(key, GetVariableWrapper(res)); } template <> void SetCachedValue( std::shared_ptr tensor, const paddle::framework::OpKernelType &key, std::shared_ptr res) { // PADDLE_THROW(platform::errors::Fatal("In eager mode program should not // reach this, support cache and remove this error check later, or this // should not be supported.")); // VLOG(10) << "CheckCachedKey with tensor: " << tensor->name() << "and key // is equal to self: " << key == key << " and res name is:" << res->Name(). } template void SetCachedValue( std::shared_ptr var, const paddle::framework::OpKernelType &key, std::shared_ptr res); template void SetCachedValue( std::shared_ptr var, const paddle::framework::OpKernelType &key, std::shared_ptr res); } // namespace imperative } // namespace paddle