提交 1e042228 编写于 作者: Q Qiao Longfei

add test_dist_ctr_with_l2_decay.py

上级 25d44d40
......@@ -18,6 +18,7 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist)
LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec)
LIST(REMOVE_ITEM TEST_OPS test_dist_ctr)
LIST(REMOVE_ITEM TEST_OPS test_dist_ctr_with_l2_decay)
LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_batch_merge)
LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
......@@ -100,7 +101,7 @@ if(WITH_DISTRIBUTE)
# FIXME(typhoonzero): add these tests back
# py_test_modules(test_dist_transformer MODULES test_dist_transformer)
# set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
set_tests_properties(test_dist_ctr test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE)
set_tests_properties(test_dist_ctr test_dist_ctr_with_l2_decay test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE)
endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif()
......
......@@ -30,11 +30,7 @@ fluid.default_main_program().random_seed = 1
class TestDistCTR2x2(TestDistRunnerBase):
def config(self):
self.use_l2_decay = False
def get_model(self, batch_size=2):
self.config()
dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta()
""" network definition """
......@@ -103,7 +99,8 @@ class TestDistCTR2x2(TestDistRunnerBase):
inference_program = paddle.fluid.default_main_program().clone()
regularization = None
if self.use_l2_decay:
use_l2_decay = bool(os.getenv(['USE_L2_DECAY'], 0))
if use_l2_decay:
regularization = fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-3)
......
......@@ -18,7 +18,6 @@ import unittest
from test_dist_base import TestDistBase
# FIXME(tangwei): sum op can not handle when inputs is empty.
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
......@@ -28,15 +27,5 @@ class TestDistCTR2x2(TestDistBase):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
class TestDistCTR2x2WithL2Decay(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
self.check_with_place(
"dist_ctr_with_l2_decay.py", delta=1e-7, check_error_log=False)
if __name__ == "__main__":
unittest.main()
......@@ -11,17 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import dist_ctr
from test_dist_base import runtime_main
import os
import unittest
from test_dist_base import TestDistBase
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
class TestDistCTRWithL2Decay(dist_ctr.TestDistCTR2x2):
def config(self):
self.use_l2_decay = True
def test_dist_ctr(self):
need_envs = {"USE_L2_DECAY": "1"}
self.check_with_place(
"dist_ctr.py",
delta=1e-7,
check_error_log=False,
need_envs=need_envs)
if __name__ == "__main__":
runtime_main(TestDistCTRWithL2Decay)
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册