提交 a82beb19 编写于 作者: L lixinqi

add ssp variable_proxy


Former-commit-id: ead9222f14b4395f211260f9f2feef1e83daf2db
上级 6924f279
......@@ -26,8 +26,8 @@ REGISTER_FUNCTION_CONFIG_DEF()
.ListInt64("ssp_partition_scope_ids", {}, "type: list[int64]. ssp partition scope symbol ids");
REGISTER_SCOPE_CONFIG_DEF()
.Int64("ssp_num_stages", -1, "total number of ssp stages")
.Int64("ssp_stage_id", -1, "current ssp stage id ");
.Int64("num_stages", -1, "total number of stages")
.Int64("stage_id", -1, "current stage id ");
} // namespace
......
......@@ -125,8 +125,8 @@ class AddSspVariableProxyPass final : public JobPass {
const Scope& scope = JUST(Global<vm::SymbolStorage<Scope>>::Get()->MaybeGet(scope_symbol_id));
int64_t buffer_size = 0;
{
int64_t num_stages = scope.Int64("ssp_num_stages");
int64_t stage_id = scope.Int64("ssp_stage_id");
int64_t num_stages = scope.Int64("num_stages");
int64_t stage_id = scope.Int64("stage_id");
CHECK_GT(num_stages, 0);
CHECK_GE(stage_id, 0);
CHECK_LT(stage_id, num_stages);
......
......@@ -139,7 +139,7 @@ class Test1dSspVariableProxy(flow.unittest.TestCase):
with flow.scope.placement(
"cpu", device_name
), flow.experimental.scope.config(
ssp_num_stages=buffer_size, ssp_stage_id=0
num_stages=buffer_size, stage_id=0
):
w = flow.get_variable(
"w",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册