未验证 提交 1230a3f4 编写于 作者: L Leo Chen 提交者: GitHub

[new-exec] remove variable scope, stage2 (#43936)

* remove class ScopeBase

* reopen test
上级 ef1c8759
...@@ -27,7 +27,7 @@ namespace framework { ...@@ -27,7 +27,7 @@ namespace framework {
namespace details { namespace details {
// assert false when meets NAN or inf // assert false when meets NAN or inf
void CheckVarHasNanOrInf(const std::string& op_type, void CheckVarHasNanOrInf(const std::string& op_type,
const framework::ScopeBase& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const platform::Place& place); const platform::Place& place);
...@@ -37,7 +37,7 @@ void CheckVarHasNanOrInf(const std::string& op_type, ...@@ -37,7 +37,7 @@ void CheckVarHasNanOrInf(const std::string& op_type,
const platform::Place& place); const platform::Place& place);
void CheckOpHasNanOrInf(const framework::OperatorBase& op, void CheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::ScopeBase& scope, const framework::Scope& scope,
const platform::Place& place); const platform::Place& place);
template <typename VarType> template <typename VarType>
...@@ -56,7 +56,7 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type, ...@@ -56,7 +56,7 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op, void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
const framework::ScopeBase& scope, const framework::Scope& scope,
const platform::Place& place); const platform::Place& place);
#endif #endif
......
...@@ -450,7 +450,7 @@ void CheckVarHasNanOrInf(const std::string& op_type, ...@@ -450,7 +450,7 @@ void CheckVarHasNanOrInf(const std::string& op_type,
} }
void CheckVarHasNanOrInf(const std::string& op_type, void CheckVarHasNanOrInf(const std::string& op_type,
const framework::ScopeBase& scope, const framework::Scope& scope,
const std::string& var_name, const std::string& var_name,
const platform::Place& place) { const platform::Place& place) {
auto* var = scope.FindVar(var_name); auto* var = scope.FindVar(var_name);
...@@ -486,7 +486,7 @@ static phi::DenseTensor& npu_float_status() { ...@@ -486,7 +486,7 @@ static phi::DenseTensor& npu_float_status() {
} }
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op, void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
const framework::ScopeBase& scope, const framework::Scope& scope,
const platform::Place& place) { const platform::Place& place) {
if (!platform::is_npu_place(place)) return; if (!platform::is_npu_place(place)) return;
...@@ -555,7 +555,7 @@ void PrintNpuVarInfo(const std::string& op_type, ...@@ -555,7 +555,7 @@ void PrintNpuVarInfo(const std::string& op_type,
} }
void PrintNPUOpValueInfo(const framework::OperatorBase& op, void PrintNPUOpValueInfo(const framework::OperatorBase& op,
const framework::ScopeBase& scope, const framework::Scope& scope,
const platform::Place& place) { const platform::Place& place) {
LOG(WARNING) << "There are `nan` or `inf` in operator (" << op.Type() LOG(WARNING) << "There are `nan` or `inf` in operator (" << op.Type()
<< "), here we print some tensor value info of this op."; << "), here we print some tensor value info of this op.";
...@@ -573,7 +573,7 @@ void PrintNPUOpValueInfo(const framework::OperatorBase& op, ...@@ -573,7 +573,7 @@ void PrintNPUOpValueInfo(const framework::OperatorBase& op,
} }
static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op, static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::ScopeBase& scope, const framework::Scope& scope,
const platform::Place& place) { const platform::Place& place) {
if (!platform::is_npu_place(place)) return; if (!platform::is_npu_place(place)) return;
...@@ -609,7 +609,7 @@ static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op, ...@@ -609,7 +609,7 @@ static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
#endif #endif
void CheckOpHasNanOrInf(const framework::OperatorBase& op, void CheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::ScopeBase& exec_scope, const framework::Scope& exec_scope,
const platform::Place& place) { const platform::Place& place) {
std::call_once(white_list_init_flag, InitWhiteListFormEnv); std::call_once(white_list_init_flag, InitWhiteListFormEnv);
......
...@@ -140,29 +140,31 @@ ProgramDesc GetLmMainProgram() { ...@@ -140,29 +140,31 @@ ProgramDesc GetLmMainProgram() {
return main_prog; return main_prog;
} }
// TEST(StandaloneExecutor, run) { TEST(StandaloneExecutor, run) {
// auto place = platform::CUDAPlace(0); auto place = platform::CUDAPlace(0);
// ProgramDesc test_prog = load_from_file("lm_startup_program"); ProgramDesc startup_prog = load_from_file("lm_startup_program");
// ProgramDesc main_prog = GetLmMainProgram(); ProgramDesc main_prog = GetLmMainProgram();
// Scope scope;
// StandaloneExecutor exec(place, test_prog, main_prog, &scope);
// exec.Run({}, {}, {});
// auto start = std::chrono::steady_clock::now();
// for (size_t i = 0; i < 10; ++i) { Scope scope;
// if (i % 200 == 0) { StandaloneExecutor startup_exec(place, startup_prog);
// std::cout << i << std::endl; startup_exec.Run(&scope, {}, {});
// } StandaloneExecutor exec(place, main_prog);
exec.Run(&scope, {}, {});
auto start = std::chrono::steady_clock::now();
for (size_t i = 0; i < 10; ++i) {
if (i % 200 == 0) {
std::cout << i << std::endl;
}
// exec.Run({}, {}, {}); exec.Run(&scope, {}, {});
// } }
// auto end = std::chrono::steady_clock::now(); auto end = std::chrono::steady_clock::now();
// std::chrono::duration<double> diff = end - start; std::chrono::duration<double> diff = end - start;
// std::cout << "time cost " << diff.count() << std::endl; std::cout << "time cost " << diff.count() << std::endl;
// } }
TEST(InterpreterCore, skip_gc_vars) { TEST(InterpreterCore, skip_gc_vars) {
auto place = platform::CUDAPlace(0); auto place = platform::CUDAPlace(0);
......
...@@ -73,7 +73,7 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = { ...@@ -73,7 +73,7 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain), std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
}; };
static DDim GetDimsDebug(const ScopeBase& scope, static DDim GetDimsDebug(const Scope& scope,
const std::string& name, const std::string& name,
bool get_actual_dim = false) { bool get_actual_dim = false) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
...@@ -97,13 +97,13 @@ static DDim GetDimsDebug(const ScopeBase& scope, ...@@ -97,13 +97,13 @@ static DDim GetDimsDebug(const ScopeBase& scope,
} }
} }
static bool VarInited(const ScopeBase& scope, const std::string& name) { static bool VarInited(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) return false; if (var == nullptr) return false;
return var->IsInitialized(); return var->IsInitialized();
} }
static std::string GetDtype(const ScopeBase& scope, const std::string& name) { static std::string GetDtype(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
return ""; return "";
...@@ -129,7 +129,7 @@ static std::string GetDtype(const ScopeBase& scope, const std::string& name) { ...@@ -129,7 +129,7 @@ static std::string GetDtype(const ScopeBase& scope, const std::string& name) {
} }
} }
static std::string GetPlace(const ScopeBase& scope, const std::string& name) { static std::string GetPlace(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
return ""; return "";
...@@ -158,7 +158,7 @@ static std::string GetPlace(const ScopeBase& scope, const std::string& name) { ...@@ -158,7 +158,7 @@ static std::string GetPlace(const ScopeBase& scope, const std::string& name) {
} }
} }
static int GetRowSize(const ScopeBase& scope, const std::string& name) { static int GetRowSize(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
if (var == nullptr) { if (var == nullptr) {
return -1; return -1;
...@@ -171,7 +171,7 @@ static int GetRowSize(const ScopeBase& scope, const std::string& name) { ...@@ -171,7 +171,7 @@ static int GetRowSize(const ScopeBase& scope, const std::string& name) {
return -1; return -1;
} }
static LoD GetLoDDebug(const ScopeBase& scope, const std::string& name) { static LoD GetLoDDebug(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name); Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}}); auto default_lod = LoD({{}});
...@@ -349,7 +349,7 @@ const std::vector<std::string>& OperatorBase::Outputs( ...@@ -349,7 +349,7 @@ const std::vector<std::string>& OperatorBase::Outputs(
return it->second; return it->second;
} }
std::string OperatorBase::DebugStringEx(const ScopeBase* scope) const { std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss; std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{"; ss << "Op(" << type_ << "), inputs:{";
......
...@@ -168,7 +168,7 @@ class OperatorBase { ...@@ -168,7 +168,7 @@ class OperatorBase {
virtual void Stop() {} virtual void Stop() {}
/// if scope is not null, also show dimensions of arguments /// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(const ScopeBase* scope) const; virtual std::string DebugStringEx(const Scope* scope) const;
std::string DebugString() const { return DebugStringEx(nullptr); } std::string DebugString() const { return DebugStringEx(nullptr); }
virtual bool SupportGPU() const { return false; } virtual bool SupportGPU() const { return false; }
......
...@@ -38,17 +38,6 @@ class Variable; ...@@ -38,17 +38,6 @@ class Variable;
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// TODO(zhiqiu): add more function in base class
class ScopeBase {
public:
/// Find a variable in the scope or any of its ancestors. Returns
/// nullptr if cannot find.
/// Caller doesn't own the returned Variable.
virtual Variable* FindVar(const std::string& name) const = 0;
virtual ~ScopeBase() {}
};
/** /**
* @brief Scope that manage all variables. * @brief Scope that manage all variables.
* *
...@@ -57,7 +46,7 @@ class ScopeBase { ...@@ -57,7 +46,7 @@ class ScopeBase {
* One net can run in different scopes and update different variable in the * One net can run in different scopes and update different variable in the
* scope. * scope.
*/ */
class Scope : public ScopeBase { class Scope {
public: public:
Scope() {} Scope() {}
~Scope(); ~Scope();
......
...@@ -79,7 +79,7 @@ class BKCLCommunicator; ...@@ -79,7 +79,7 @@ class BKCLCommunicator;
namespace framework { namespace framework {
class LoDRankTable; class LoDRankTable;
class ScopeBase; class Scope;
class ReaderHolder; class ReaderHolder;
class Scope; class Scope;
} // namespace framework } // namespace framework
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册