dist_strategy (paddle.fleet.distributed_strategy): used to determine the user-defined distributed strategy.
dist_context (paddle.fluid.DistributedContext): used to access the distributed_attr of vars & ops. Every Partitioner object may maintain its own DistributedContext member and partition the program based on that shard scenario.
rank_id (int): global rank id to which the partitioned distributed program belongs.
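A minimal construction sketch using the arguments above is given below; the import path of Partitioner and the way the DistributedContext is obtained are assumptions for illustration, not guarantees of this docstring.

Example (sketch):

    import paddle
    from paddle.distributed import fleet
    # assumed import path for illustration; adjust to the actual module layout
    from paddle.distributed.auto_parallel.partitioner import Partitioner

    # user-defined distributed strategy
    dist_strategy = fleet.DistributedStrategy()

    # dist_context is assumed to come from the auto parallel completion pass
    # (a paddle.fluid DistributedContext holding the shard annotations)
    dist_context = get_default_distributed_context()  # hypothetical helper

    # global rank of the current process
    rank_id = paddle.distributed.get_rank()

    partitioner = Partitioner(dist_strategy, dist_context, rank_id)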
Take serial forward programs with shard annotations and create new distributed forward programs based on the serial ones.
Instead of modifying the input programs in place, this function preserves the inputs and creates new programs for the output.
Besides replacing each serial op with its corresponding dist op, if the user has defined other strategies in fleet.distributed_strategy that need to transpile (modify) the forward network program, those forward program modifications should also be done within this
function in the auto parallel scenario. This facilitates distributed inference/evaluation, which needs to DECOUPLE strategy-specific forward transpilation from fleet.distributed_optimizer.minimize().
For now, the fleet.distributed_strategy options that need to transpile the forward program are the following (see the usage sketch after the Returns section):
1. (optimizer) sharding
Args:
main_program (paddle.fluid.framework.Program): serial main program with forward network only
startup_program (paddle.fluid.framework.Program): serial startup program with forward network only
Returns:
main_program (paddle.fluid.framework.Program): distributed main program with forward network only
startup_program (paddle.fluid.framework.Program): distributed startup program with forward network only
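Continuing the construction sketch above, a forward transpilation usage sketch follows; the method name transpile_forward and the serial program placeholders are assumptions for illustration.

Example (sketch):

    import paddle.static as static

    # serial programs holding the forward network only, already annotated
    # with shard information by the auto parallel completion pass
    serial_main_prog = static.Program()      # placeholder for the annotated main program
    serial_startup_prog = static.Program()   # placeholder for the annotated startup program

    # partition the forward network for this rank; the returned programs can be
    # used directly for distributed inference/evaluation, decoupled from
    # fleet.distributed_optimizer.minimize()
    dist_main_prog, dist_startup_prog = partitioner.transpile_forward(
        serial_main_prog, serial_startup_prog)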