未验证 提交 186bbebf 编写于 作者: C Chengmo 提交者: GitHub

【Paddle.Fleet】【Cherry-Pick】fix grad_clip & gaussian_random & dataset & profiler (#31945)

* Remove PE special profiler (#30886)

* remove pe special profiler

* add profiler info

* add truncated gaussian random (#30922)

add truncated gaussian random

* 【Paddle.Fleet】fix dataset zip py3 bug (#31441)

* fix zip py3 bug

* 【Paddle.Fleet】Fix one ps gradient clip  (#31664)

* fix one ps gradient clip
上级 8140485a
...@@ -385,6 +385,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches, ...@@ -385,6 +385,7 @@ void Communicator::SendGlobalStep(const CommContext &ctx, int batches,
if (batches == 0) { if (batches == 0) {
return; return;
} }
platform::RecordEvent record_event("Communicator->SendGlobalStep");
auto &table_id = ctx.table_id; auto &table_id = ctx.table_id;
size_t request_call_num = _worker_ptr->get_server_nums(); size_t request_call_num = _worker_ptr->get_server_nums();
...@@ -788,6 +789,7 @@ void SyncCommunicator::BarrierRecv() { ...@@ -788,6 +789,7 @@ void SyncCommunicator::BarrierRecv() {
void GeoCommunicator::Send(const std::vector<std::string> &var_names, void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const framework::Scope &scope) { const framework::Scope &scope) {
platform::RecordEvent record_event("GeoCommunicator->Send");
waiting_ = false; waiting_ = false;
auto before_send = GetCurrentUS(); auto before_send = GetCurrentUS();
auto table_name = var_names[0]; auto table_name = var_names[0];
...@@ -1024,6 +1026,7 @@ void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) { ...@@ -1024,6 +1026,7 @@ void GeoCommunicator::InitSparse(const std::string &var_name, int table_id) {
std::vector<int64_t> GeoCommunicator::MergeSparseIds( std::vector<int64_t> GeoCommunicator::MergeSparseIds(
const std::string &send_varname) { const std::string &send_varname) {
platform::RecordEvent record_event("GeoCommunicator->MergeSparseIds");
size_t merge_num = 0, wait_times = 0; size_t merge_num = 0, wait_times = 0;
std::unordered_set<int64_t> sparse_ids; std::unordered_set<int64_t> sparse_ids;
while (merge_num < static_cast<size_t>(max_merge_var_num_)) { while (merge_num < static_cast<size_t>(max_merge_var_num_)) {
......
...@@ -28,6 +28,8 @@ void CommonDenseTable::create_initializer(const std::string& attr, ...@@ -28,6 +28,8 @@ void CommonDenseTable::create_initializer(const std::string& attr,
initializers_[name] = new FillConstantInitializer(slices); initializers_[name] = new FillConstantInitializer(slices);
} else if (slices[0] == "uniform_random") { } else if (slices[0] == "uniform_random") {
initializers_[name] = new UniformInitializer(slices); initializers_[name] = new UniformInitializer(slices);
} else if (slices[0] == "truncated_gaussian_random") {
initializers_[name] = new TruncatedGaussianInitializer(slices);
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::InvalidArgument("%s can not be supported", name)); platform::errors::InvalidArgument("%s can not be supported", name));
......
...@@ -17,12 +17,15 @@ ...@@ -17,12 +17,15 @@
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <random>
#include <string> #include <string>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/operators/truncated_gaussian_random_op.h"
namespace paddle { namespace paddle {
namespace distributed { namespace distributed {
...@@ -108,6 +111,40 @@ class GaussianInitializer : public Initializer { ...@@ -108,6 +111,40 @@ class GaussianInitializer : public Initializer {
std::normal_distribution<float> dist_; std::normal_distribution<float> dist_;
}; };
// Initializer that draws values from a truncated normal distribution
// N(mean, std^2), using the transform provided by
// paddle::operators::TruncatedNormal on a uniform sample.
class TruncatedGaussianInitializer : public Initializer {
 public:
  // attrs layout: [name, seed, mean, std] — parsed positionally, so the
  // caller must supply all four entries.
  explicit TruncatedGaussianInitializer(const std::vector<std::string> &attrs) {
    name_ = attrs[0];
    seed_ = static_cast<unsigned int>(std::stoi(attrs[1]));
    mean_ = std::stof(attrs[2]);
    std_ = std::stof(attrs[3]);
    // Assign the *member* distribution. The previous code declared a local
    // variable named `dist_` here, shadowing the member, so the member kept
    // its default [0, 1) range. The lower bound must be the smallest
    // positive float so that dist_ never yields exactly 0, which the
    // truncated-normal transform would map to -infinity.
    dist_ = std::uniform_real_distribution<float>(
        std::numeric_limits<float>::min(), 1.0);
    random_engine_ = framework::GetCPURandomEngine(seed_);
  }

  // Returns one truncated-gaussian sample.
  float GetValue() override {
    paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
    return truncated_normal(dist_(*random_engine_));
  }

  // Fills `numel` consecutive floats at `value` with independent samples.
  void GetValue(float *value, int numel) {
    paddle::operators::TruncatedNormal<float> truncated_normal(mean_, std_);
    for (int x = 0; x < numel; ++x) {
      value[x] = truncated_normal(dist_(*random_engine_));
    }
  }

 private:
  float std_;
  float mean_;
  std::shared_ptr<std::mt19937_64> random_engine_;
  std::uniform_real_distribution<float> dist_;
};
class FillConstantInitializer : public Initializer { class FillConstantInitializer : public Initializer {
public: public:
explicit FillConstantInitializer(const std::vector<std::string> &attrs) { explicit FillConstantInitializer(const std::vector<std::string> &attrs) {
......
...@@ -125,6 +125,9 @@ class ValueBlock { ...@@ -125,6 +125,9 @@ class ValueBlock {
} else if (slices[0] == "uniform_random") { } else if (slices[0] == "uniform_random") {
initializers_.emplace_back( initializers_.emplace_back(
std::make_shared<UniformInitializer>(slices)); std::make_shared<UniformInitializer>(slices));
} else if (slices[0] == "truncated_gaussian_random") {
initializers_.emplace_back(
std::make_shared<TruncatedGaussianInitializer>(slices));
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s can not be supported", attr)); "%s can not be supported", attr));
......
...@@ -1121,8 +1121,6 @@ void ParallelExecutor::BCastParamsToDevices( ...@@ -1121,8 +1121,6 @@ void ParallelExecutor::BCastParamsToDevices(
FetchResultType ParallelExecutor::Run( FetchResultType ParallelExecutor::Run(
const std::vector<std::string> &fetch_tensors, bool return_merged) { const std::vector<std::string> &fetch_tensors, bool return_merged) {
VLOG(3) << "enter ParallelExecutor Run"; VLOG(3) << "enter ParallelExecutor Run";
platform::RecordEvent parallel_executor_event(
"ParallelExecutor::Run", paddle::platform::EventRole::kSpecial);
#ifdef WITH_GPERFTOOLS #ifdef WITH_GPERFTOOLS
if (gProfileStarted) { if (gProfileStarted) {
ProfilerFlush(); ProfilerFlush();
......
...@@ -32,11 +32,11 @@ class DataGenerator(object): ...@@ -32,11 +32,11 @@ class DataGenerator(object):
''' '''
Set batch size of current DataGenerator Set batch size of current DataGenerator
This is necessary only if a user wants to define generator_batch This is necessary only if a user wants to define generator_batch
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -52,7 +52,7 @@ class DataGenerator(object): ...@@ -52,7 +52,7 @@ class DataGenerator(object):
yield ("words", s[1].extend([s[1][0]])) yield ("words", s[1].extend([s[1][0]]))
mydata = MyData() mydata = MyData()
mydata.set_batch(128) mydata.set_batch(128)
''' '''
self.batch_size_ = batch_size self.batch_size_ = batch_size
...@@ -63,7 +63,7 @@ class DataGenerator(object): ...@@ -63,7 +63,7 @@ class DataGenerator(object):
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -100,9 +100,9 @@ class DataGenerator(object): ...@@ -100,9 +100,9 @@ class DataGenerator(object):
generated. generated.
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -161,7 +161,7 @@ class DataGenerator(object): ...@@ -161,7 +161,7 @@ class DataGenerator(object):
The data format is list or tuple: The data format is list or tuple:
[(name, [feasign, ...]), ...] [(name, [feasign, ...]), ...]
or ((name, [feasign, ...]), ...) or ((name, [feasign, ...]), ...)
For example: For example:
[("words", [1926, 08, 17]), ("label", [1])] [("words", [1926, 08, 17]), ("label", [1])]
or (("words", [1926, 08, 17]), ("label", [1])) or (("words", [1926, 08, 17]), ("label", [1]))
...@@ -174,7 +174,7 @@ class DataGenerator(object): ...@@ -174,7 +174,7 @@ class DataGenerator(object):
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -206,7 +206,7 @@ class DataGenerator(object): ...@@ -206,7 +206,7 @@ class DataGenerator(object):
Example: Example:
.. code-block:: python .. code-block:: python
import paddle.distributed.fleet.data_generator as dg import paddle.distributed.fleet.data_generator as dg
class MyData(dg.DataGenerator): class MyData(dg.DataGenerator):
...@@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator): ...@@ -259,6 +259,9 @@ class MultiSlotStringDataGenerator(DataGenerator):
Returns: Returns:
Return a string data that can be read directly by the MultiSlotDataFeed. Return a string data that can be read directly by the MultiSlotDataFeed.
''' '''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple): if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError( raise ValueError(
"the output of process() must be in list or tuple type" "the output of process() must be in list or tuple type"
...@@ -289,7 +292,7 @@ class MultiSlotDataGenerator(DataGenerator): ...@@ -289,7 +292,7 @@ class MultiSlotDataGenerator(DataGenerator):
>>> [ids_num id1 id2 ...] ... >>> [ids_num id1 id2 ...] ...
The proto_info will be in this format: The proto_info will be in this format:
>>> [(name, type), ...] >>> [(name, type), ...]
For example, if the input is like this: For example, if the input is like this:
>>> [("words", [1926, 08, 17]), ("label", [1])] >>> [("words", [1926, 08, 17]), ("label", [1])]
>>> or (("words", [1926, 08, 17]), ("label", [1])) >>> or (("words", [1926, 08, 17]), ("label", [1]))
...@@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator): ...@@ -304,6 +307,9 @@ class MultiSlotDataGenerator(DataGenerator):
Returns: Returns:
Return a string data that can be read directly by the MultiSlotDataFeed. Return a string data that can be read directly by the MultiSlotDataFeed.
''' '''
if sys.version > '3' and isinstance(line, zip):
line = list(line)
if not isinstance(line, list) and not isinstance(line, tuple): if not isinstance(line, list) and not isinstance(line, tuple):
raise ValueError( raise ValueError(
"the output of process() must be in list or tuple type" "the output of process() must be in list or tuple type"
......
...@@ -150,7 +150,8 @@ class CommonAccessor: ...@@ -150,7 +150,8 @@ class CommonAccessor:
oop = None oop = None
for op in optimizer_ops: for op in optimizer_ops:
if op.input("Param")[0] == param_name: if ("Param" in op.input_names) and (
op.input("Param")[0] == param_name):
oop = op oop = op
break break
......
...@@ -31,7 +31,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundR ...@@ -31,7 +31,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.ps_dispatcher import RoundR
from paddle.fluid.transpiler.details.program_utils import delete_ops from paddle.fluid.transpiler.details.program_utils import delete_ops
OP_NAME_SCOPE = "op_namescope" OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP" CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@" STEP_COUNTER = "@PS_STEP_COUNTER@"
LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@" LEARNING_RATE_DECAY_COUNTER = "@LR_DECAY_COUNTER@"
......
...@@ -32,7 +32,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_ta ...@@ -32,7 +32,7 @@ from paddle.fluid.incubate.fleet.parameter_server.ir.public import get_sparse_ta
from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode from paddle.fluid.incubate.fleet.parameter_server.mode import DistributedMode
OP_NAME_SCOPE = "op_namescope" OP_NAME_SCOPE = "op_namescope"
CLIP_OP_NAME_SCOPE = "@CLIP" CLIP_OP_NAME_SCOPE = "gradient_clip"
STEP_COUNTER = "@PS_STEP_COUNTER@" STEP_COUNTER = "@PS_STEP_COUNTER@"
OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName() OP_ROLE_VAR_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleVarAttrName()
RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName() RPC_OP_ROLE_ATTR_NAME = core.op_proto_and_checker_maker.kOpRoleAttrName()
......
...@@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator): ...@@ -95,6 +95,32 @@ class MyMultiSlotDataGenerator_error_5(fleet.MultiSlotDataGenerator):
return data_iter return data_iter
class MyMultiSlotStringDataGenerator_zip(fleet.MultiSlotStringDataGenerator):
    # Yields each sample as a lazy zip object (rather than a list/tuple) to
    # exercise the py3 zip-handling path; one None is yielded mid-stream.
    def generate_sample(self, line):
        def data_iter():
            for idx in range(40):
                if idx == 1:
                    yield None
                slot_names = ["words", "label"]
                slot_values = [["1", "2", "3", "4"], ["0"]]
                yield zip(slot_names, slot_values)

        return data_iter
class MyMultiSlotDataGenerator_zip(fleet.MultiSlotDataGenerator):
    # Integer-feasign counterpart of the string generator above: samples are
    # emitted as lazy zip objects, with a single None yielded mid-stream.
    def generate_sample(self, line):
        def data_iter():
            for idx in range(40):
                if idx == 1:
                    yield None
                slot_names = ["words", "label"]
                slot_values = [[1, 2, 3, 4], [0]]
                yield zip(slot_names, slot_values)

        return data_iter
class TestMultiSlotDataGenerator(unittest.TestCase): class TestMultiSlotDataGenerator(unittest.TestCase):
def test_MultiSlotDataGenerator_basic(self): def test_MultiSlotDataGenerator_basic(self):
my_ms_dg = MyMultiSlotDataGenerator() my_ms_dg = MyMultiSlotDataGenerator()
...@@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase): ...@@ -149,5 +175,19 @@ class TestMultiSlotDataGenerator_error_5(unittest.TestCase):
my_ms_dg.run_from_memory() my_ms_dg.run_from_memory()
class TestMultiSlotStringDataGeneratorZip(unittest.TestCase):
    # Smoke test: a MultiSlotStringDataGenerator whose generate_sample yields
    # zip objects must run end-to-end without raising (py3 zip is a lazy
    # iterator, not a list — presumably the generator converts it; confirm
    # against the data_generator implementation).
    def test_MultiSlotStringDataGenerator_zip(self):
        my_ms_dg = MyMultiSlotStringDataGenerator_zip()
        my_ms_dg.set_batch(1)
        my_ms_dg.run_from_memory()
class TestMultiSlotDataGeneratorZip(unittest.TestCase):
    # Smoke test for the integer-slot generator: zip-object samples must be
    # accepted by MultiSlotDataGenerator without raising when run in memory.
    def test_MultiSlotDataGenerator_zip(self):
        my_ms_dg = MyMultiSlotDataGenerator_zip()
        my_ms_dg.set_batch(1)
        my_ms_dg.run_from_memory()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -18,6 +18,7 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distribu ...@@ -18,6 +18,7 @@ from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distribu
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
import paddle.distributed.fleet.base.role_maker as role_maker import paddle.distributed.fleet.base.role_maker as role_maker
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle
""" """
high level unit test for distribute fleet. high level unit test for distribute fleet.
""" """
...@@ -112,23 +113,21 @@ class FleetDistRunnerBase(object): ...@@ -112,23 +113,21 @@ class FleetDistRunnerBase(object):
def build_optimizer(self, avg_cost, strategy): def build_optimizer(self, avg_cost, strategy):
use_grad_clip = int(os.getenv('GRAD_CLIP', 0)) use_grad_clip = int(os.getenv('GRAD_CLIP', 0))
grad_clip = None
if use_grad_clip: if use_grad_clip:
# 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm # 1: clip_by_value; 2: clip_by_norm; 3:clip_by_global_norm
if use_grad_clip == 1: if use_grad_clip == 1:
fluid.clip.set_gradient_clip( grad_clip = paddle.nn.ClipGradByValue(min=-5.0, max=5.0)
clip=fluid.clip.GradientClipByValue(2.0))
elif use_grad_clip == 2: elif use_grad_clip == 2:
fluid.clip.set_gradient_clip( grad_clip = paddle.nn.ClipGradByNorm(2.0)
clip=fluid.clip.GradientClipByNorm(2.0))
elif use_grad_clip == 3: elif use_grad_clip == 3:
fluid.clip.set_gradient_clip( grad_clip = paddle.nn.ClipGradByGlobalNorm(2.0)
clip=fluid.clip.GradientClipByGlobalNorm(2.0))
use_decay = int(os.getenv("USE_DECAY", "0")) use_decay = int(os.getenv("USE_DECAY", "0"))
if use_decay: if use_decay:
scheduler = paddle.optimizer.lr.ExponentialDecay( scheduler = paddle.optimizer.lr.ExponentialDecay(
learning_rate=LEARNING_RATE, gamma=0.999, verbose=True) learning_rate=LEARNING_RATE, gamma=0.999, verbose=True)
optimizer = fluid.optimizer.SGD(scheduler) optimizer = fluid.optimizer.SGD(scheduler, grad_clip=grad_clip)
""" """
# learning rate decay method before 2.0 # learning rate decay method before 2.0
optimizer = fluid.optimizer.SGD( optimizer = fluid.optimizer.SGD(
...@@ -139,7 +138,7 @@ class FleetDistRunnerBase(object): ...@@ -139,7 +138,7 @@ class FleetDistRunnerBase(object):
staircase=True)) staircase=True))
""" """
else: else:
optimizer = fluid.optimizer.SGD(LEARNING_RATE) optimizer = fluid.optimizer.SGD(LEARNING_RATE, grad_clip=grad_clip)
optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy=strategy)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
......
...@@ -16,53 +16,66 @@ from __future__ import print_function ...@@ -16,53 +16,66 @@ from __future__ import print_function
import os import os
import unittest import unittest
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
from paddle.fluid.transpiler.distribute_transpiler import DistributeTranspilerConfig
from test_dist_fleet_base import TestFleetBase from test_dist_fleet_base import TestFleetBase
from dist_fleet_simnet_bow import train_network
@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged") class TestDistGeoClipByGlobalNorm(TestFleetBase):
class TestDistGeoClipByGlobalNormTranspiler(unittest.TestCase): def _setup_config(self):
def test_pserver(self): self._mode = "geo"
role = role_maker.UserDefinedRoleMaker( self._reader = "dataset"
current_id=0, self._geo_sgd_need_push_nums = 5
role=role_maker.Role.SERVER, self._grad_clip_mode = 3
worker_num=2,
server_endpoints=["127.0.0.1:36011", "127.0.0.1:36012"])
fleet.init(role) def check_with_place(self,
model_file,
delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
batch_size = 128 tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
is_sparse = True
is_distribute = False
strategy = DistributeTranspilerConfig() def test_dist_train(self):
strategy.sync_mode = False self.check_with_place(
strategy.geo_sgd_mode = True "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
strategy.geo_sgd_need_push_nums = 5
avg_cost, _, _, _ = train_network(batch_size, is_distribute, is_sparse) def _setup_config(self):
fluid.clip.set_gradient_clip( self._sync_mode = False
clip=fluid.clip.GradientClipByGlobalNorm(2.0)) self._grad_clip_mode = 2
optimizer = fluid.optimizer.SGD(0.1) def check_with_place(self,
optimizer = fleet.distributed_optimizer(optimizer, strategy) model_file,
optimizer.minimize(avg_cost) delta=1e-3,
check_error_log=False,
need_envs={}):
required_envs = {
"PATH": os.getenv("PATH", ""),
"PYTHONPATH": os.getenv("PYTHONPATH", ""),
"LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
"FLAGS_rpc_deadline": "5000", # 5sec to fail fast
"http_proxy": ""
}
required_envs.update(need_envs)
tr0_losses, tr1_losses = self._run_cluster(model_file, required_envs)
pserver_startup_program = fleet.startup_program def test_dist_train(self):
pserver_mian_program = fleet.main_program self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged") class TestDistASyncClipByValue(TestFleetBase):
class TestDistGeoClipByGlobalNorm(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "geo" self._mode = "async"
self._reader = "dataset" self._reader = "dataset"
self._geo_sgd_need_push_nums = 5 self._grad_clip_mode = 1
self._grad_clip_mode = 3
def check_with_place(self, def check_with_place(self,
model_file, model_file,
...@@ -84,8 +97,11 @@ class TestDistGeoClipByGlobalNorm(TestFleetBase): ...@@ -84,8 +97,11 @@ class TestDistGeoClipByGlobalNorm(TestFleetBase):
self.check_with_place( self.check_with_place(
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True) "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
class TestDistASyncClipByNorm(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._mode = "async"
self._reader = "dataset"
self._grad_clip_mode = 2 self._grad_clip_mode = 2
def check_with_place(self, def check_with_place(self,
...@@ -109,7 +125,6 @@ class TestDistGeoClipByGlobalNorm(TestFleetBase): ...@@ -109,7 +125,6 @@ class TestDistGeoClipByGlobalNorm(TestFleetBase):
"dist_fleet_ctr.py", delta=1e-5, check_error_log=True) "dist_fleet_ctr.py", delta=1e-5, check_error_log=True)
@unittest.skip(reason="Skip unstable ut, add it after PR 22957 merged")
class TestDistASyncClipByGlobalNorm(TestFleetBase): class TestDistASyncClipByGlobalNorm(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._mode = "async" self._mode = "async"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册