Commit 7dbab103 authored by: S sandyhouse

update, test=develop

Parent d7c5e849
@@ -13,7 +13,6 @@ limitations under the License. */
 #include <float.h>
 #include "paddle/fluid/framework/device_worker.h"
 #include "paddle/fluid/framework/executor_gc_helper.h"
 #include "paddle/fluid/platform/device_context.h"
 namespace paddle {
...
@@ -736,6 +736,60 @@ class DistributedStrategy(object):
                           "sharding_configs")
         assign_configs_value(self.strategy.sharding_configs, configs)
+    @property
+    def model_parallel(self):
+        """
+        Indicating whether we are using model parallelism for distributed training.
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.model_parallel = True
+        """
+        return self.strategy.model_parallel
+
+    @model_parallel.setter
+    @is_strict_auto
+    def model_parallel(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.model_parallel = flag
+        else:
+            print("WARNING: model_parallel should have value of bool type")
+
+    @property
+    def model_parallel_configs(self):
+        """
+        Set model parallelism configurations.
+
+        **Notes**:
+            **Detailed arguments for model_parallel_configs**
+
+            **parallelism**: degree of model parallelism
+
+        Examples:
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.model_parallel = True
+            strategy.model_parallel_configs = {"parallelism": 12}
+        """
+        return get_msg_dict(self.strategy.model_parallel_configs)
+
+    @model_parallel_configs.setter
+    @is_strict_auto
+    def model_parallel_configs(self, configs):
+        check_configs_key(self.strategy.model_parallel_configs, configs,
+                          "model_parallel_configs")
+        assign_configs_value(self.strategy.model_parallel_configs, configs)
     @property
     def pipeline(self):
         """
...
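
For context, a minimal end-to-end sketch of how the new switch would be wired into the usual fleet workflow; the optimizer and the `parallelism` value are placeholders, and the `model_parallel` fields only exist with this patch applied:

    import paddle
    import paddle.distributed.fleet as fleet

    # Hypothetical usage sketch: enable the new model-parallel switch and
    # hand the strategy to a fleet distributed optimizer.
    strategy = fleet.DistributedStrategy()
    strategy.model_parallel = True
    strategy.model_parallel_configs = {"parallelism": 2}  # degree of model parallelism

    fleet.init(is_collective=True, strategy=strategy)
    optimizer = paddle.optimizer.SGD(learning_rate=0.001)
    optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)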
@@ -50,7 +50,8 @@ class AMPOptimizer(MetaOptimizerBase):
             self.inner_opt, amp_lists, config['init_loss_scaling'],
             config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'],
             config['incr_ratio'], config['decr_ratio'],
-            config['use_dynamic_loss_scaling'])
+            config['use_dynamic_loss_scaling'], config['use_pure_fp16'],
+            config['use_fp16_guard'])

         # if worker_num > 1, all cards will communication with each other,
         # add is_distributed to optimize amp, overlap communication and
@@ -113,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase):
             self.wrapped_opt.minimize(loss, startup_program,
                                       parameter_list, no_grad_set)
         return optimize_ops, params_grads
+
+    def amp_init(self,
+                 place,
+                 scope=None,
+                 test_program=None,
+                 use_fp16_test=False):
+        return self.wrapped_opt.amp_init(place, scope, test_program,
+                                         use_fp16_test)
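
A rough sketch of the pure-fp16 setup these extra arguments support; the amp_configs keys below mirror the `config` entries read above and should be treated as an assumption for this branch:

    import paddle
    import paddle.distributed.fleet as fleet

    # Assumed amp_configs keys, mirroring the config entries consumed above.
    strategy = fleet.DistributedStrategy()
    strategy.amp = True
    strategy.amp_configs = {
        "init_loss_scaling": 32768.0,
        "use_pure_fp16": True,
        "use_fp16_guard": True,
    }
    fleet.init(is_collective=True, strategy=strategy)
    optimizer = fleet.distributed_optimizer(
        paddle.optimizer.SGD(learning_rate=0.001), strategy=strategy)
    # After building the program and running the startup program, the new
    # passthrough would cast parameters to fp16 once, e.g.:
    #   optimizer.amp_init(place, test_program=test_prog, use_fp16_test=True)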
@@ -87,8 +87,7 @@ class ShardingOptimizer(MetaOptimizerBase):
         self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[
             "as_outer_parallelism"]
         self._inner_parallelism_size = int(
-            self.user_defined_strategy.sharding_configs[
-                "inner_parallelism_size"])
+            self.user_defined_strategy.sharding_configs["parallelism"])
         self.use_pipeline = self.user_defined_strategy.sharding_configs[
             "use_pipeline"]
...
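
Since the config key is renamed here, a short sketch of a sharding_configs dict matching the keys read in this optimizer; the full set of keys supported on this branch is an assumption:

    import paddle.distributed.fleet as fleet

    # Keys mirror the sharding_configs entries read above;
    # "parallelism" replaces the old "inner_parallelism_size" key.
    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {
        "as_outer_parallelism": True,
        "parallelism": 2,
        "use_pipeline": False,
    }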