#pragma once #include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/lod_tensor.h" namespace paddle { namespace custom_trainer { namespace feed { class ScopeHelper { public: //直接取var template static const T& var(paddle::framework::Scope* scope, const std::string& name) { return scope->Var(name)->Get(); } template static T* mutable_var(paddle::framework::Scope* scope, const std::string& name) { return scope->Var(name)->GetMutable(); } template static T* resize_variable(paddle::framework::Scope* scope, const std::string& name, const paddle::framework::DDim& dim) { auto* tensor = scope->Var(name)->GetMutable(); tensor->Resize(dim); return tensor; } static paddle::framework::LoDTensor* resize_lod_tensor( paddle::framework::Scope* scope, const std::string& name, const paddle::framework::DDim& dim) { return resize_variable(scope, name, dim); } template static void fill_value(paddle::framework::Scope* scope, paddle::platform::Place place, const std::string& name, T& value) { auto* tensor = resize_variable(scope, name, { 1 }); T* data = tensor->mutable_data(place); *data = value; return; } template static T* get_value(paddle::framework::Scope* scope, paddle::platform::Place place, const std::string& name) { auto* tensor = scope->Var(name)->GetMutable(); return tensor->mutable_data(place); } }; } // namespace feed } // namespace custom_trainer } // namespace paddle