Unverified commit 1230a3f4, authored by Leo Chen, committed by GitHub

[new-exec] remove variable scope, stage2 (#43936)

* remove class ScopeBase

* reopen test
Parent ef1c8759
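The pattern applied throughout this diff: call sites that previously accepted the abstract framework::ScopeBase interface now take the concrete framework::Scope directly, and Scope itself no longer derives from a base class. Below is a minimal sketch of the before/after API shape, with simplified declarations only; the namespaces and the platform::Place parameters of the real Paddle headers are omitted.

// Before (simplified): callers were written against the abstract interface.
//   class ScopeBase {
//    public:
//     virtual Variable* FindVar(const std::string& name) const = 0;
//     virtual ~ScopeBase() {}
//   };
//   void CheckVarHasNanOrInf(const std::string& op_type,
//                            const ScopeBase& scope,
//                            const std::string& var_name);

// After (simplified): the concrete Scope is used everywhere.
#include <string>

class Variable;

class Scope {
 public:
  // Previously an override of ScopeBase's pure virtual; now an ordinary member.
  Variable* FindVar(const std::string& name) const;
};

void CheckVarHasNanOrInf(const std::string& op_type,
                         const Scope& scope,
                         const std::string& var_name);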
@@ -27,7 +27,7 @@ namespace framework {
namespace details {
// assert false when meets NAN or inf
void CheckVarHasNanOrInf(const std::string& op_type,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const std::string& var_name,
const platform::Place& place);
@@ -37,7 +37,7 @@ void CheckVarHasNanOrInf(const std::string& op_type,
const platform::Place& place);
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const platform::Place& place);
template <typename VarType>
@@ -56,7 +56,7 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
#ifdef PADDLE_WITH_ASCEND_CL
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const platform::Place& place);
#endif
......
@@ -450,7 +450,7 @@ void CheckVarHasNanOrInf(const std::string& op_type,
}
void CheckVarHasNanOrInf(const std::string& op_type,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const std::string& var_name,
const platform::Place& place) {
auto* var = scope.FindVar(var_name);
@@ -486,7 +486,7 @@ static phi::DenseTensor& npu_float_status() {
}
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const platform::Place& place) {
if (!platform::is_npu_place(place)) return;
@@ -555,7 +555,7 @@ void PrintNpuVarInfo(const std::string& op_type,
}
void PrintNPUOpValueInfo(const framework::OperatorBase& op,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const platform::Place& place) {
LOG(WARNING) << "There are `nan` or `inf` in operator (" << op.Type()
<< "), here we print some tensor value info of this op.";
@@ -573,7 +573,7 @@ void PrintNPUOpValueInfo(const framework::OperatorBase& op,
}
static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
-const framework::ScopeBase& scope,
+const framework::Scope& scope,
const platform::Place& place) {
if (!platform::is_npu_place(place)) return;
@@ -609,7 +609,7 @@ static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
#endif
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
-const framework::ScopeBase& exec_scope,
+const framework::Scope& exec_scope,
const platform::Place& place) {
std::call_once(white_list_init_flag, InitWhiteListFormEnv);
......
@@ -140,29 +140,31 @@ ProgramDesc GetLmMainProgram() {
return main_prog;
}
-// TEST(StandaloneExecutor, run) {
-// auto place = platform::CUDAPlace(0);
-// ProgramDesc test_prog = load_from_file("lm_startup_program");
-// ProgramDesc main_prog = GetLmMainProgram();
-// Scope scope;
-// StandaloneExecutor exec(place, test_prog, main_prog, &scope);
-// exec.Run({}, {}, {});
-// auto start = std::chrono::steady_clock::now();
+TEST(StandaloneExecutor, run) {
+auto place = platform::CUDAPlace(0);
+ProgramDesc startup_prog = load_from_file("lm_startup_program");
+ProgramDesc main_prog = GetLmMainProgram();
-// for (size_t i = 0; i < 10; ++i) {
-// if (i % 200 == 0) {
-// std::cout << i << std::endl;
-// }
+Scope scope;
+StandaloneExecutor startup_exec(place, startup_prog);
+startup_exec.Run(&scope, {}, {});
+StandaloneExecutor exec(place, main_prog);
+exec.Run(&scope, {}, {});
+auto start = std::chrono::steady_clock::now();
+for (size_t i = 0; i < 10; ++i) {
+if (i % 200 == 0) {
+std::cout << i << std::endl;
+}
-// exec.Run({}, {}, {});
-// }
+exec.Run(&scope, {}, {});
+}
-// auto end = std::chrono::steady_clock::now();
-// std::chrono::duration<double> diff = end - start;
+auto end = std::chrono::steady_clock::now();
+std::chrono::duration<double> diff = end - start;
-// std::cout << "time cost " << diff.count() << std::endl;
-// }
+std::cout << "time cost " << diff.count() << std::endl;
+}
TEST(InterpreterCore, skip_gc_vars) {
auto place = platform::CUDAPlace(0);
......
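Worth noting from the re-enabled test above: the commented-out version bound the scope at construction time (StandaloneExecutor exec(place, test_prog, main_prog, &scope); exec.Run({}, {}, {});), while the new version creates one executor per program and hands the caller-owned scope to each Run call. The toy program below illustrates that calling pattern in a self-contained way; Scope, Variable, and StandaloneExecutor here are hypothetical stand-ins, not Paddle's real classes.

#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>

// Toy stand-ins used only to illustrate the calling pattern shown in the test.
struct Variable { std::string value; };

struct Scope {
  std::unordered_map<std::string, Variable> vars;
  Variable* FindVar(const std::string& name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

// New shape: one executor per program; the caller-owned scope is passed to Run.
struct StandaloneExecutor {
  explicit StandaloneExecutor(std::string program) : program_(std::move(program)) {}
  void Run(Scope* scope) {
    // A real executor would interpret program_ against *scope; here we only record it.
    scope->vars[program_] = Variable{"ran " + program_};
  }
  std::string program_;
};

int main() {
  Scope scope;  // owned by the caller and shared across both executors
  StandaloneExecutor startup_exec("startup_program");
  startup_exec.Run(&scope);
  StandaloneExecutor main_exec("main_program");
  main_exec.Run(&scope);
  std::cout << scope.FindVar("main_program")->value << std::endl;  // prints "ran main_program"
  return 0;
}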
@@ -73,7 +73,7 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};
-static DDim GetDimsDebug(const ScopeBase& scope,
+static DDim GetDimsDebug(const Scope& scope,
const std::string& name,
bool get_actual_dim = false) {
Variable* var = scope.FindVar(name);
@@ -97,13 +97,13 @@ static DDim GetDimsDebug(const ScopeBase& scope,
}
}
-static bool VarInited(const ScopeBase& scope, const std::string& name) {
+static bool VarInited(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) return false;
return var->IsInitialized();
}
-static std::string GetDtype(const ScopeBase& scope, const std::string& name) {
+static std::string GetDtype(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return "";
@@ -129,7 +129,7 @@ static std::string GetDtype(const ScopeBase& scope, const std::string& name) {
}
}
-static std::string GetPlace(const ScopeBase& scope, const std::string& name) {
+static std::string GetPlace(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return "";
@@ -158,7 +158,7 @@ static std::string GetPlace(const ScopeBase& scope, const std::string& name) {
}
}
-static int GetRowSize(const ScopeBase& scope, const std::string& name) {
+static int GetRowSize(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
if (var == nullptr) {
return -1;
@@ -171,7 +171,7 @@ static int GetRowSize(const ScopeBase& scope, const std::string& name) {
return -1;
}
-static LoD GetLoDDebug(const ScopeBase& scope, const std::string& name) {
+static LoD GetLoDDebug(const Scope& scope, const std::string& name) {
Variable* var = scope.FindVar(name);
auto default_lod = LoD({{}});
@@ -349,7 +349,7 @@ const std::vector<std::string>& OperatorBase::Outputs(
return it->second;
}
-std::string OperatorBase::DebugStringEx(const ScopeBase* scope) const {
+std::string OperatorBase::DebugStringEx(const Scope* scope) const {
std::stringstream ss;
ss << "Op(" << type_ << "), inputs:{";
......
@@ -168,7 +168,7 @@ class OperatorBase {
virtual void Stop() {}
/// if scope is not null, also show dimensions of arguments
-virtual std::string DebugStringEx(const ScopeBase* scope) const;
+virtual std::string DebugStringEx(const Scope* scope) const;
std::string DebugString() const { return DebugStringEx(nullptr); }
virtual bool SupportGPU() const { return false; }
......
@@ -38,17 +38,6 @@ class Variable;
namespace paddle {
namespace framework {
-// TODO(zhiqiu): add more function in base class
-class ScopeBase {
-public:
-/// Find a variable in the scope or any of its ancestors. Returns
-/// nullptr if cannot find.
-/// Caller doesn't own the returned Variable.
-virtual Variable* FindVar(const std::string& name) const = 0;
-virtual ~ScopeBase() {}
-};
/**
* @brief Scope that manage all variables.
*
@@ -57,7 +46,7 @@ class ScopeBase {
* One net can run in different scopes and update different variable in the
* scope.
*/
-class Scope : public ScopeBase {
+class Scope {
public:
Scope() {}
~Scope();
......
@@ -79,7 +79,7 @@ class BKCLCommunicator;
namespace framework {
class LoDRankTable;
-class ScopeBase;
+class Scope;
class ReaderHolder;
class Scope;
} // namespace framework
......