未验证 提交 c8429d36 编写于 作者: S Steffy-zxf 提交者: GitHub

[cherry-pick 2.2]fix data parallel when VOCAB var in program (#37546)

* fix data parallel when VOCAB var in program

* fix ci coverage
上级 824c4ef9
......@@ -365,6 +365,9 @@ def sync_params_buffers(model,
if getattr(param, "no_sync", False):
continue
if param.type == core.VarDesc.VarType.VOCAB:
continue
model_vars.append(param.detach())
if len(model_vars) == 0:
return
......
......@@ -554,6 +554,7 @@ py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_stat
py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS
FLAGS_cudnn_deterministic=1)
set_tests_properties(test_conv2d_op PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_faster_tokenizer_op PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_conv2d_op_depthwise_conv PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_conv2d_api PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_conv_nn_grad PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
......
......@@ -388,6 +388,34 @@ class TestBertTokenizerOp(unittest.TestCase):
exe.run(paddle.static.default_main_program(), feed={'x': self.text})
paddle.disable_static()
def test_data_parallel(self):
self.max_seq_len = 128
self.pad_to_max_seq_len = True
self.is_split_into_words = False
model = paddle.DataParallel(self.faster_tokenizer)
input_ids, token_type_ids = model(
text=self.text_tensor,
do_lower_case=self.bert_tokenizer.do_lower_case,
max_seq_len=self.max_seq_len,
pad_to_max_seq_len=self.pad_to_max_seq_len,
is_split_into_words=self.is_split_into_words)
input_ids = input_ids.numpy()
token_type_ids = token_type_ids.numpy()
encoded_inputs = self.bert_tokenizer(
self.text,
max_seq_len=self.max_seq_len,
pad_to_max_seq_len=self.pad_to_max_seq_len,
is_split_into_words=self.is_split_into_words)
py_input_ids = np.array(encoded_inputs[0]["input_ids"]).reshape([1, -1])
py_token_type_ids = np.array(encoded_inputs[0][
"token_type_ids"]).reshape([1, -1])
self.assertTrue(np.allclose(input_ids, py_input_ids, rtol=0, atol=0.01))
self.assertTrue(
np.allclose(
token_type_ids, py_token_type_ids, rtol=0, atol=0.01))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册