From 0837a2ccad11cdc42b498ee02374ebceff0a177d Mon Sep 17 00:00:00 2001 From: jianghaicheng Date: Wed, 19 Jan 2022 18:22:32 +0800 Subject: [PATCH] ipu python interface p1 (#38096) * ipu_commit_tests p1 * resolve comments * resolve comments * resolve comments * resolve comments * resolve comments * resolve comments * resolve comments * update lint and ipustrategy introduction * update ipu_config * update __init__ of static * update doc * update doc 2 * update doc 3 * update doc 4 * update doc 5 * update doc 5 * update doc 6 * update lint * update lint 2 * update ipustrategy * add IpuStrategy to all * update ipustrategy * update ipu_shard_guard * update ipu_shard_guard 2 Co-authored-by: yaozhixin <522190855@qq.com> --- paddle/fluid/pybind/pybind.cc | 50 +-- python/paddle/fluid/compiler.py | 419 +++++++++++++++++- python/paddle/fluid/executor.py | 3 + python/paddle/fluid/framework.py | 85 +++- .../fluid/tests/unittests/CMakeLists.txt | 4 + python/paddle/static/__init__.py | 6 + 6 files changed, 525 insertions(+), 42 deletions(-) diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc index 63f1e817137..47f97944b2d 100644 --- a/paddle/fluid/pybind/pybind.cc +++ b/paddle/fluid/pybind/pybind.cc @@ -3575,17 +3575,14 @@ All parameter, weight, gradient are variables in Paddle. .def("set_scope", &platform::ipu::IpuBackend::SetScope) .def("set_ipu_strategy", &platform::ipu::IpuBackend::SetIpuStrategy); - py::class_(m, "IpuStrategy") - .def(py::init()) + py::class_ ipu_strategy(m, "IpuStrategy"); + ipu_strategy.def(py::init()) .def_property( "num_ipus", [](const platform::ipu::IpuStrategy &self) { return self.num_ipus; }, [](platform::ipu::IpuStrategy &self, int num_ipus) { self.num_ipus = num_ipus; - }, - R"DOC( - Int type, set the number ipu we need. Default 1. - )DOC") + }) .def_property( "accumulationFactor", [](const platform::ipu::IpuStrategy &self) { @@ -3593,31 +3590,21 @@ All parameter, weight, gradient are variables in Paddle. }, [](platform::ipu::IpuStrategy &self, int accumulationFactor) { self.popart_options_.accumulationFactor = accumulationFactor; - }, - R"DOC( - Specify the number of micro-batches to accumulate before - applying the varUpdate. Default 1. - )DOC") + }) .def_property("batches_per_step", [](const platform::ipu::IpuStrategy &self) { return self.batches_per_step; }, [](platform::ipu::IpuStrategy &self, int batches_per_step) { self.batches_per_step = batches_per_step; - }, - R"DOC( - Int type, set batches_per_step. Default 1. - )DOC") + }) .def_property("is_training", [](const platform::ipu::IpuStrategy &self) { return self.is_training; }, [](platform::ipu::IpuStrategy &self, bool is_training) { self.is_training = is_training; - }, - R"DOC( - Bool type, True for training, False inference. Default True. - )DOC") + }) .def_property( "enable_pipelining", [](const platform::ipu::IpuStrategy &self) { @@ -3625,10 +3612,7 @@ All parameter, weight, gradient are variables in Paddle. }, [](platform::ipu::IpuStrategy &self, bool enable_pipelining) { self.popart_options_.enablePipelining = enable_pipelining; - }, - R"DOC( - Bool type, True enable pipeline, otherwise disable. Default False. - )DOC") + }) .def_property( "enable_manual_shard", [](const platform::ipu::IpuStrategy &self) { @@ -3643,40 +3627,28 @@ All parameter, weight, gradient are variables in Paddle. self.popart_options_.virtualGraphMode = platform::ipu::VirtualGraphMode::Off; } - }, - R"DOC( - Bool type, True enable model sharding, otherwise disable. Default " - "False. 
-          )DOC")
+          })
       .def_property("need_avg_shard",
                     [](const platform::ipu::IpuStrategy &self) {
                       return self.need_avg_shard;
                     },
                     [](platform::ipu::IpuStrategy &self, bool need_avg_shard) {
                       self.need_avg_shard = need_avg_shard;
-                    },
-                    R"DOC(
-             Bool type, True enable avg shard, otherwise disable. Default False.
-          )DOC")
+                    })
       .def_property("batch_size",
                     [](const platform::ipu::IpuStrategy &self) {
                       return self.batch_size;
                     },
                     [](platform::ipu::IpuStrategy &self, int batch_size) {
                       self.batch_size = batch_size;
-                    },
-                    R"DOC(
-             Int type, used to make batch size fixed. Default 1.
-          )DOC")
+                    })
       .def_property("enable_fp16",
                     [](const platform::ipu::IpuStrategy &self) {
                       return self.enable_fp16;
                     },
                     [](platform::ipu::IpuStrategy &self, bool enable_fp16) {
                       self.enable_fp16 = enable_fp16;
-                    },
-                    R"DOC(
-             Bool type, True enable float16 mode, otherwise disable. Default False.)DOC");
+                    });
 #endif
 
   BindFleetWrapper(&m);
diff --git a/python/paddle/fluid/compiler.py b/python/paddle/fluid/compiler.py
index 1fa86d0aeea..7e3dfde5d4f 100644
--- a/python/paddle/fluid/compiler.py
+++ b/python/paddle/fluid/compiler.py
@@ -22,7 +22,10 @@ from .framework import _get_paddle_place, _get_paddle_place_list
 from .framework import cuda_places, cpu_places, xpu_places
 from . import core
 
-__all__ = ['CompiledProgram', 'ExecutionStrategy', 'BuildStrategy']
+__all__ = [
+    'CompiledProgram', 'ExecutionStrategy', 'BuildStrategy',
+    'IpuCompiledProgram', 'IpuStrategy'
+]
 
 ExecutionStrategy = core.ParallelExecutor.ExecutionStrategy
 BuildStrategy = core.ParallelExecutor.BuildStrategy
@@ -493,3 +496,417 @@ class CompiledProgram(object):
             place_list = cpu_places()
         assert place_list, "No places for execution."
         return place_list
+
+
+class IpuStrategy(object):
+    """
+    Help users precisely control the graph building in :code:`paddle.static.IpuCompiledProgram` .
+
+    Args:
+        None.
+
+    Returns:
+        The IpuStrategy instance.
+
+    Examples:
+        .. code-block:: python
+
+            # required: ipu
+
+            import paddle
+            import paddle.static as static
+
+            paddle.enable_static()
+            ipu_strategy = static.IpuStrategy()
+    """
+
+    def __init__(self):
+        if core.is_compiled_with_ipu():
+            self._ipu_strategy = core.IpuStrategy()
+        else:
+            raise RuntimeError(
+                "Can not use IpuStrategy in non IPU compiled environment, please re-compile with WITH_IPU=ON."
+            )
+
+    def SetGraphConfig(self,
+                       num_ipus=1,
+                       is_training=True,
+                       batch_size=1,
+                       enable_manual_shard=False,
+                       need_avg_shard=False):
+        """
+        Set graph configuration to the IpuStrategy instance.
+
+        Args:
+            num_ipus (int, optional): Number of IPU devices. Default 1, which means only 1 IPU is used.
+            is_training (bool, optional): True for a training graph, False for an inference graph. Default True, which means training mode.
+            batch_size (int, optional): The batch size in the graph. Used to make the graph batch size fixed
+                if it is dynamic. Default 1, which means the batch size is set to 1 when it is dynamic.
+            enable_manual_shard (bool, optional): Enable graph sharding or not. enable_manual_shard can only be set to True when num_ipus > 1.
+                Default False, which means disabled.
+            need_avg_shard (bool, optional): Enable automatic graph sharding or not. need_avg_shard can only be set to True when num_ipus > 1 and enable_manual_shard=True.
+                Default False, which means disabled.
+
+        Returns:
+            None.
+
+        Examples:
+            ..
code-block:: python + + # required: ipu + + import paddle + import paddle.static as static + + paddle.enable_static() + ipu_strategy = static.IpuStrategy() + ipu_strategy.SetGraphConfig(num_ipus=1, + is_training=True, + batch_size=1, + enable_manual_shard=False, + need_avg_shard=False) + """ + + self._ipu_strategy.num_ipus = num_ipus + self._ipu_strategy.is_training = is_training + self._ipu_strategy.batch_size = batch_size + self._ipu_strategy.enable_manual_shard = enable_manual_shard + if self._ipu_strategy.num_ipus == 1 and self._ipu_strategy.enable_manual_shard: + raise RuntimeError( + "Only if num_ipus > 1, enable_manual_shard is able to be set True." + ) + self._ipu_strategy.need_avg_shard = need_avg_shard + if self._ipu_strategy.enable_manual_shard != True and self._ipu_strategy.need_avg_shard: + raise RuntimeError( + "Only if enable_manual_shard=True, need_avg_shard is able to be set True." + ) + + def SetPipeliningConfig(self, + enable_pipelining=False, + batches_per_step=1, + accumulationFactor=1): + """ + Set pipelining configuration to the IpuStrategy instance. Used to optimize the throughput performance. + + Args: + enable_pipelining (bool, optional): Enable data pipelining between subgraphs. Only if enable_manual_shard=True, enable_pipelining is able to be set True. + Default False, which means disabled. + batches_per_step (int, optional): Set the batches per run in data pipelining mode. Only if enable_pipelining=True, batches_per_step is able to be set > 1. + Default 1, which means no data pipelining. + accumulationFactor (int, optional): Specify the number of micro-batches to accumulate + before applying the varUpdate. Default 1, which means disable the accumulation. + + Returns: + None. + + Examples: + .. code-block:: python + + # required: ipu + + import paddle + import paddle.static as static + + paddle.enable_static() + + ipu_strategy = static.IpuStrategy() + ipu_strategy.SetPipeliningConfig(enable_pipelining=False, + batches_per_step=1, + accumulationFactor=1) + """ + self._ipu_strategy.enable_pipelining = enable_pipelining + if self._ipu_strategy.enable_manual_shard != True and self._ipu_strategy.enable_pipelining: + raise RuntimeError( + "Only if enable_manual_shard=True, enable_pipelining is able to be set True." + ) + self._ipu_strategy.batches_per_step = batches_per_step + if self._ipu_strategy.enable_pipelining != True and self._ipu_strategy.batches_per_step > 1: + raise RuntimeError( + "Only if enable_pipelining=True, batches_per_step is able to be set > 1." + ) + self._ipu_strategy.accumulationFactor = accumulationFactor + + def SetHalfConfig(self, enable_fp16=False): + """ + Set half computation configuration to the IpuStrategy instance. Used to optimize the performance. + + Args: + enable_fp16 (bool, optional): Enable FLOAT16 mode and transform FLOAT32 to FLOAT16. Default False, which means disable FLOAT16 mode. + + Returns: + None. + + Examples: + .. code-block:: python + + # required: ipu + + import paddle + import paddle.static as static + + paddle.enable_static() + + ipu_strategy = static.IpuStrategy() + ipu_strategy.SetHalfConfig(enable_fp16=False) + """ + + self._ipu_strategy.enable_fp16 = enable_fp16 + + @property + def num_ipus(self): + """ + Get the number of IPU devices from IpuStrategy instance. + """ + return self._ipu_strategy.num_ipus + + @property + def is_training(self): + """ + Get the boolean of training or inference from IpuStrategy instance. 
+        """
+        return self._ipu_strategy.is_training
+
+    @property
+    def batch_size(self):
+        """
+        Get the batch_size used when the graph has a dynamic batch size, from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.batch_size
+
+    @property
+    def enable_manual_shard(self):
+        """
+        Get whether manual graph sharding is enabled from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.enable_manual_shard
+
+    @property
+    def need_avg_shard(self):
+        """
+        Get whether automatic (average) graph sharding is enabled from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.need_avg_shard
+
+    @property
+    def enable_pipelining(self):
+        """
+        Get whether pipelining is enabled from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.enable_pipelining
+
+    @property
+    def batches_per_step(self):
+        """
+        Get the number of batches per run in pipelining mode from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.batches_per_step
+
+    @property
+    def accumulationFactor(self):
+        """
+        Get the number of micro-batches to accumulate before applying the varUpdate from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.accumulationFactor
+
+    @property
+    def enable_fp16(self):
+        """
+        Get whether float16 mode is enabled from the IpuStrategy instance.
+        """
+        return self._ipu_strategy.enable_fp16
+
+
+class IpuCompiledProgram(object):
+    """
+    The IpuCompiledProgram is used to transform a program into an IPU-targeted program,
+    performing steps such as forward-graph extraction, computation-graph transformation
+    and removal of useless scale Ops.
+
+    Args:
+        program(Program, optional): This parameter represents the :code:`Program`
+            to be executed. Default is None, which means the program will be set to
+            the default program :code:`paddle.static.default_main_program()` .
+        scope(Scope, optional): The scope used to run this program, you can switch
+            it to a different scope. Default is None, which means use the global
+            scope :code:`paddle.static.global_scope()` .
+        ipu_strategy(IpuStrategy, optional): This argument is used to build the program with the
+            specified options, such as half computation, training or inference session, the number of IPUs, etc.
+            Default is None, which means build the program based on the default `ipu_strategy`.
+
+    Returns:
+        IpuCompiledProgram
+
+    Example:
+        ..
code-block:: python
+
+            # required: ipu
+
+            import paddle
+            import paddle.static as static
+
+            paddle.enable_static()
+
+            a = static.data(name='data', shape=[None, 1], dtype='int32')
+            b = a + 1
+            main_prog = static.default_main_program()
+
+            ipu_strategy = static.IpuStrategy()
+            ipu_strategy.SetGraphConfig(num_ipus=1, is_training=True, batch_size=1)
+            ipu_strategy.SetPipeliningConfig(enable_pipelining=False, batches_per_step=1, accumulationFactor=1)
+            ipu_strategy.SetHalfConfig(enable_fp16=False)
+
+            ipu_compiled_program = static.IpuCompiledProgram(
+                main_prog,
+                ipu_strategy=ipu_strategy)
+    """
+
+    def __init__(self, program=None, scope=None, ipu_strategy=None):
+        if not core.is_compiled_with_ipu():
+            raise ValueError(
+                "Can not use this function since PaddlePaddle is not compiled with IPU"
+            )
+
+        if program is None:
+            program = default_main_program()
+
+        if not isinstance(program, framework.Program):
+            raise TypeError(
+                "The type of program is wrong, expected Program, but got %s" %
+                type(program))
+        # import here to avoid confusion
+        import paddle
+
+        self._program = program
+        self._compiled = False
+
+        if scope is not None:
+            self._scope = scope
+        else:
+            self._scope = paddle.static.global_scope()
+
+        if ipu_strategy is not None:
+            self._ipu_strategy = ipu_strategy._ipu_strategy
+        else:
+            self._ipu_strategy = core.IpuStrategy()
+
+        self._backend = core.IpuBackend()
+        self._backend.set_scope(self._scope)
+        self._backend.set_ipu_strategy(self._ipu_strategy)
+        self._graph_passes = [
+            "optimizer_extract_pass", "optimizer_state_align_pass",
+            "forward_graph_extract_pass", "infer_shape_pass", "avg_shard_pass",
+            "popart_canonicalization_pass"
+        ]
+        global ipu_compiler_ref
+        ipu_compiler_ref = self
+
+    def compile(self, feed_list, fetch_list):
+        """
+        This interface is used to compile the input Program into a program
+        that runs the model on the IPU.
+
+        Args:
+            feed_list(list): This parameter represents the input Tensors of the model.
+
+            fetch_list(list): This parameter represents the Tensors that need to be returned
+                after running the model.
+
+        Returns:
+            Program
+
+        Example:
+            ..
code-block:: python + + # required: ipu + + import paddle + import paddle.static as static + + paddle.enable_static() + + a = static.data(name='data', shape=[None, 1], dtype='int32') + b = a + 1 + main_prog = static.default_main_program() + + ipu_strategy = static.IpuStrategy() + ipu_strategy.SetGraphConfig(num_ipus=1, is_training=True, batch_size=1) + ipu_strategy.SetPipeliningConfig(enable_pipelining=False, batches_per_step=1, accumulationFactor=1) + ipu_strategy.SetHalfConfig(enable_fp16=False) + + program = static.IpuCompiledProgram( + main_prog, + ipu_strategy=ipu_strategy).compile([a.name], [b.name]) + """ + # feed and fetch doesn't have corresponding popart op, so we rm both here + global_block = self._program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed" or op.type == "fetch": + need_to_remove_op_index.append(i) + + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + + for var in ['feed', 'fetch']: + if global_block.has_var(var): + global_block._remove_var(var) + + self._program.desc.flush() + self._graph = core.Graph(self._program.desc) + + for pass_name in self._graph_passes: + graph_pass = core.get_pass(pass_name) + if pass_name == "infer_shape_pass": + graph_pass.set("feed_list", feed_list) + graph_pass.apply(self._graph) + + ipu_inplace_pass = core.get_pass("ipu_inplace_pass") + ipu_inplace_pass.set("feed_list", feed_list) + ipu_inplace_pass.set("fetch_list", fetch_list) + ipu_inplace_pass.apply(self._graph) + + ipu_graph_builder_pass = core.get_pass("ipu_graph_builder_pass") + ipu_graph_builder_pass.set("feed_list", feed_list) + ipu_graph_builder_pass.set("fetch_list", fetch_list) + ipu_graph_builder_pass.apply(self._graph) + + ipu_runtime_replacer_pass = core.get_pass("ipu_runtime_replacer_pass") + ipu_runtime_replacer_pass.set("feed_list", feed_list) + ipu_runtime_replacer_pass.set("fetch_list", fetch_list) + ipu_runtime_replacer_pass.apply(self._graph) + + convert_pass = core.get_pass('graph_to_program_pass') + desc = core.ProgramDesc() + convert_pass.set_not_owned('program', desc) + convert_pass.apply(self._graph) + program = framework.Program._construct_from_desc(desc) + + if hasattr(self._program, 'lr_sheduler'): + # how to share var between two different block ? + lr_var_name = self._program.lr_sheduler._var_name + + program.lr_sheduler = self._program.lr_sheduler + # Program.clone will clone lr_sheduler, so i set lr_var as + # lr_sheduler attribute + global_block = self._program.global_block() + program.lr_sheduler.lr_var = global_block.vars[lr_var_name] + + # with popart, we need to support batches_per_step, what means + # the shape of feed_var and feed_tensor(maybe numpy array) will + # mismatch, so we set need_check_feed to False. Thus we can avoid + # modify logic of run. 
+        program_global_block = program.global_block()
+        for feed_name in feed_list:
+            feed_var = program_global_block.var(feed_name)
+            feed_var.desc.set_need_check_feed(False)
+
+        if not hasattr(program, 'org_program'):
+            program.org_program = self._program
+
+        return program
+
+    def clean(self):
+        self._backend.clear()
+
+    def __del__(self):
+        self.clean()
diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py
index d67d4944c69..5ae1403f632 100644
--- a/python/paddle/fluid/executor.py
+++ b/python/paddle/fluid/executor.py
@@ -1581,6 +1581,9 @@ class Executor(object):
             lr_sheduler = program.lr_sheduler
             lr_value = lr_sheduler()
             lr_var = program.global_block().vars[lr_sheduler._var_name]
+            if core.is_compiled_with_ipu():
+                if hasattr(program.lr_sheduler, 'lr_var'):
+                    lr_var = program.lr_sheduler.lr_var
             data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype))
             tensor = core.get_variable_tensor(scope, lr_sheduler._var_name)
             tensor.set(data, self.place)
diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py
index a26e322cbd9..5ee7b04248e 100644
--- a/python/paddle/fluid/framework.py
+++ b/python/paddle/fluid/framework.py
@@ -48,6 +48,7 @@ __all__ = [
     'default_main_program',
     'program_guard',
     'name_scope',
+    'ipu_shard_guard',
     'cuda_places',
     'cpu_places',
     'xpu_places',
@@ -104,6 +105,65 @@ def _test_eager_guard(tracer=None):
         _C_ops.switch_to_core_ops()
 
+
+global_ipu_index = None
+global_ipu_stage = None
+ipu_index_attr_name = 'ipu_index'
+ipu_stage_attr_name = 'ipu_stage'
+
+
+@signature_safe_contextmanager
+def ipu_shard_guard(index=None, stage=None):
+    """
+    Used to shard the graph on IPUs. It specifies which IPU each Op runs on in the sharding and
+    which stage it belongs to in the pipelining.
+
+    Args:
+        index(int, optional): Specify which IPU the Tensor is computed on (such as 0, 1, 2, 3).
+            The default value is None, which means the Op only runs on IPU 0.
+        stage(int, optional): Specify the computation order of the sharded model (such as 0, 1, 2, 3).
+            The sharded model is computed from the smallest stage to the largest. The default value is None,
+            which means there is no pipelined computation order and Ops run in graph order.
+
+    **Note**:
+    Only if enable_manual_shard=True can 'index' be set to a value other than None. Please refer
+    to :code:`paddle.static.IpuStrategy` .
+    Only if enable_pipelining=True can 'stage' be set to a value other than None. Please refer
+    to :code:`paddle.static.IpuStrategy` .
+    An index may match no stage or one stage. A stage may only match a new index or an index that
+    has already been used.
+
+    Examples:
+        ..
code-block:: python + + # required: ipu + + import paddle + paddle.enable_static() + a = paddle.static.data(name='data', shape=[None, 1], dtype='int32') + with paddle.static.ipu_shard_guard(index=0, stage=0): + b = a + 1 + with paddle.static.ipu_shard_guard(index=1, stage=1): + c = b + 1 + with paddle.static.ipu_shard_guard(index=0, stage=2): + d = c + 1 + """ + if not core.is_compiled_with_ipu(): + raise ValueError( + "Can not use this function since PaddlePaddle is not compiled with IPU" + ) + + global global_ipu_index + global global_ipu_stage + prev_ipu_index = global_ipu_index + prev_ipu_stage = global_ipu_stage + global_ipu_index = index + global_ipu_stage = stage + try: + yield + finally: + global_ipu_index = prev_ipu_index + global_ipu_stage = prev_ipu_stage + + def require_version(min_version, max_version=None): """ Check if the installed version of PaddlePaddle is in [min_version, max_version], @@ -2573,6 +2633,15 @@ class Operator(object): attr_val = op_attrs[attr_name] self._update_desc_attr(attr_name, attr_val) + # proto.attrs doesn't include ipu_index + if core.is_compiled_with_ipu(): + if global_ipu_index is not None: + self._update_desc_attr(ipu_index_attr_name, + global_ipu_index) + if global_ipu_stage is not None: + self._update_desc_attr(ipu_stage_attr_name, + global_ipu_stage) + self.desc.check_attrs() if self._has_kernel(type): self.desc.infer_var_type(self.block.desc) @@ -6845,7 +6914,7 @@ def _get_paddle_place(place): return place if isinstance(place, (core.Place, core.XPUPlace, core.CPUPlace, core.CUDAPinnedPlace, core.CUDAPlace, core.NPUPlace, - core.MLUPlace)): + core.IPUPlace, core.MLUPlace)): return place if not isinstance(place, str): @@ -6900,6 +6969,18 @@ def _get_paddle_place(place): device_id = int(device_id) return core.NPUPlace(device_id) + # IPU + avaliable_ipu_place = re.match(r'ipu:\d+', place) + if avaliable_ipu_place: + if not core.is_compiled_with_ipu(): + raise ValueError( + "The device should not be {}, since PaddlePaddle is " \ + "not compiled with IPU".format(avaliable_ipu_place)) + place_info_list = place.split(':', 1) + device_id = place_info_list[1] + device_id = int(device_id) + return core.IPUPlace(device_id) + # MLU avaliable_mlu_place = re.match(r'mlu:\d+', place) if avaliable_mlu_place: @@ -6913,7 +6994,7 @@ def _get_paddle_place(place): return core.MLUPlace(device_id) raise ValueError( - "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace, MLUPlace and NPUPlace, but received {}.". + "Paddle supports CPUPlace, CUDAPlace,CUDAPinnedPlace, XPUPlace, IPUPlace, MLUPlace and NPUPlace, but received {}.". 
format(place)) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index c0c13866ccd..915af18a570 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -805,6 +805,10 @@ if (WITH_MKLDNN) add_subdirectory(mkldnn) endif() +if (WITH_IPU) + add_subdirectory(ipu) +endif() + if (WITH_MLU) add_subdirectory(mlu) endif() diff --git a/python/paddle/static/__init__.py b/python/paddle/static/__init__.py index f18b77997a5..bce77380d1f 100644 --- a/python/paddle/static/__init__.py +++ b/python/paddle/static/__init__.py @@ -35,6 +35,8 @@ from ..fluid.backward import append_backward # noqa: F401 from ..fluid.backward import gradients # noqa: F401 from ..fluid.compiler import BuildStrategy # noqa: F401 from ..fluid.compiler import CompiledProgram # noqa: F401 +from ..fluid.compiler import IpuCompiledProgram # noqa: F401 +from ..fluid.compiler import IpuStrategy # noqa: F401 from ..fluid.compiler import ExecutionStrategy # noqa: F401 from ..fluid.framework import default_main_program # noqa: F401 from ..fluid.framework import default_startup_program # noqa: F401 @@ -48,6 +50,7 @@ from ..fluid.framework import xpu_places # noqa: F401 from ..fluid.framework import mlu_places # noqa: F401 from ..fluid.framework import npu_places # noqa: F401 from ..fluid.framework import Variable # noqa: F401 +from ..fluid.framework import ipu_shard_guard # noqa: F401 from ..fluid.layers.control_flow import Print # noqa: F401 from ..fluid.layers.nn import py_func # noqa: F401 from ..fluid.parallel_executor import ParallelExecutor # noqa: F401 @@ -74,6 +77,9 @@ __all__ = [ #noqa 'scope_guard', 'BuildStrategy', 'CompiledProgram', + 'ipu_shard_guard', + 'IpuCompiledProgram', + 'IpuStrategy', 'Print', 'py_func', 'ExecutionStrategy', -- GitLab
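
A minimal end-to-end sketch of how the pieces introduced in this patch fit together: ipu_shard_guard places Ops on IPUs and pipeline stages, IpuStrategy turns on the matching sharding and pipelining options, and IpuCompiledProgram compiles the program for the IPU. The shapes and fetch targets are illustrative only, and running the compiled program through paddle.static.Executor is assumed to follow the usual static-graph workflow; that step is not shown in this patch.

    # required: ipu
    import paddle
    import paddle.static as static

    paddle.enable_static()

    # Put each add Op on its own IPU and pipeline stage.
    a = static.data(name='data', shape=[None, 1], dtype='int32')
    with static.ipu_shard_guard(index=0, stage=0):
        b = a + 1
    with static.ipu_shard_guard(index=1, stage=1):
        c = b + 1
    main_prog = static.default_main_program()

    # enable_manual_shard requires num_ipus > 1, and enable_pipelining
    # requires enable_manual_shard=True, so the options are enabled together.
    ipu_strategy = static.IpuStrategy()
    ipu_strategy.SetGraphConfig(num_ipus=2, is_training=False, batch_size=1,
                                enable_manual_shard=True)
    ipu_strategy.SetPipeliningConfig(enable_pipelining=True, batches_per_step=2)
    ipu_strategy.SetHalfConfig(enable_fp16=False)

    # Compile for the IPU; the result is a Program that can be fed and fetched
    # like the original one (assumed to run via paddle.static.Executor on IPU hardware).
    ipu_prog = static.IpuCompiledProgram(
        main_prog, ipu_strategy=ipu_strategy).compile([a.name], [c.name])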