提交 1e042228 编写于 作者: Q Qiao Longfei

add test_dist_ctr_with_l2_decay.py

上级 25d44d40
...@@ -18,6 +18,7 @@ if(NOT WITH_DISTRIBUTE) ...@@ -18,6 +18,7 @@ if(NOT WITH_DISTRIBUTE)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist)
LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec) LIST(REMOVE_ITEM TEST_OPS test_dist_word2vec)
LIST(REMOVE_ITEM TEST_OPS test_dist_ctr) LIST(REMOVE_ITEM TEST_OPS test_dist_ctr)
LIST(REMOVE_ITEM TEST_OPS test_dist_ctr_with_l2_decay)
LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow) LIST(REMOVE_ITEM TEST_OPS test_dist_simnet_bow)
LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_batch_merge) LIST(REMOVE_ITEM TEST_OPS test_dist_mnist_batch_merge)
LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification) LIST(REMOVE_ITEM TEST_OPS test_dist_text_classification)
...@@ -100,7 +101,7 @@ if(WITH_DISTRIBUTE) ...@@ -100,7 +101,7 @@ if(WITH_DISTRIBUTE)
# FIXME(typhoonzero): add these tests back # FIXME(typhoonzero): add these tests back
# py_test_modules(test_dist_transformer MODULES test_dist_transformer) # py_test_modules(test_dist_transformer MODULES test_dist_transformer)
# set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000) # set_tests_properties(test_dist_transformer PROPERTIES TIMEOUT 1000)
set_tests_properties(test_dist_ctr test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE) set_tests_properties(test_dist_ctr test_dist_ctr_with_l2_decay test_dist_mnist test_dist_mnist_batch_merge test_dist_save_load test_dist_se_resnext test_dist_simnet_bow test_dist_text_classification test_dist_train test_dist_word2vec PROPERTIES RUN_SERIAL TRUE)
endif(NOT APPLE) endif(NOT APPLE)
py_test_modules(test_dist_transpiler MODULES test_dist_transpiler) py_test_modules(test_dist_transpiler MODULES test_dist_transpiler)
endif() endif()
......
...@@ -30,11 +30,7 @@ fluid.default_main_program().random_seed = 1 ...@@ -30,11 +30,7 @@ fluid.default_main_program().random_seed = 1
class TestDistCTR2x2(TestDistRunnerBase): class TestDistCTR2x2(TestDistRunnerBase):
def config(self):
self.use_l2_decay = False
def get_model(self, batch_size=2): def get_model(self, batch_size=2):
self.config()
dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta() dnn_input_dim, lr_input_dim = dist_ctr_reader.load_data_meta()
""" network definition """ """ network definition """
...@@ -103,7 +99,8 @@ class TestDistCTR2x2(TestDistRunnerBase): ...@@ -103,7 +99,8 @@ class TestDistCTR2x2(TestDistRunnerBase):
inference_program = paddle.fluid.default_main_program().clone() inference_program = paddle.fluid.default_main_program().clone()
regularization = None regularization = None
if self.use_l2_decay: use_l2_decay = bool(os.getenv(['USE_L2_DECAY'], 0))
if use_l2_decay:
regularization = fluid.regularizer.L2DecayRegularizer( regularization = fluid.regularizer.L2DecayRegularizer(
regularization_coeff=1e-3) regularization_coeff=1e-3)
......
...@@ -18,7 +18,6 @@ import unittest ...@@ -18,7 +18,6 @@ import unittest
from test_dist_base import TestDistBase from test_dist_base import TestDistBase
# FIXME(tangwei): sum op can not handle when inputs is empty.
class TestDistCTR2x2(TestDistBase): class TestDistCTR2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
...@@ -28,15 +27,5 @@ class TestDistCTR2x2(TestDistBase): ...@@ -28,15 +27,5 @@ class TestDistCTR2x2(TestDistBase):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
class TestDistCTR2x2WithL2Decay(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
self.check_with_place(
"dist_ctr_with_l2_decay.py", delta=1e-7, check_error_log=False)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -11,17 +11,26 @@ ...@@ -11,17 +11,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import dist_ctr import os
from test_dist_base import runtime_main import unittest
from test_dist_base import TestDistBase
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
class TestDistCTRWithL2Decay(dist_ctr.TestDistCTR2x2): def test_dist_ctr(self):
def config(self): need_envs = {"USE_L2_DECAY": "1"}
self.use_l2_decay = True self.check_with_place(
"dist_ctr.py",
delta=1e-7,
check_error_log=False,
need_envs=need_envs)
if __name__ == "__main__": if __name__ == "__main__":
runtime_main(TestDistCTRWithL2Decay) unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册