Unverified commit 2857fdbb, authored by zhaoyingli, committed by GitHub

[NewIR] Update send recv infermeta and add unittest (#56794)

* [NewIR] Update send recv infermeta and add unittest

* rm new ir flag

* rm fluid api

* skip running startup prog

* update flag name

* update recv_v2 yaml

* fix conflict

* unittest only for pp

* fix cmakelist

* unittest check precision

* control random

* fix cmakelist
Parent: 093f6ecd
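In summary, this commit (1) re-homes SendV2InferMeta and RecvV2InferMeta so that RecvV2InferMeta's signature carries the op's full attribute list (ring_id, dynamic_shape, peer, out_shape, dtype); (2) updates the recv_v2 yaml so infer_meta receives the same ordered parameters as the kernel, minus use_calc_stream; (3) makes the auto-parallel Engine temporarily disable the new IR executor while running the startup program; and (4) adds a GPT unit test that compares losses between the legacy program executor and the new IR executor under dp, mp, and pp.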
@@ -283,11 +283,12 @@ class PythonCCodeGen(CodeGen):
    def _gen_one_impl(self, op_info, op_name):
        input_name_list = op_info.input_name_list
        output_name_list = op_info.output_name_list
        attr_name_list = op_info.attribute_name_list
        mutable_attr_name_list = op_info.mutable_attribute_name_list
        no_mutable_attr_name_list = op_info.non_mutable_attribute_name_list
        if op_name == "send_v2":
            if len(output_name_list) == 0:
                ret = NO_OUTPUT_API_IMPL_TEMPLATE.format(
                    api_name=op_name,
                    inputs=self._gen_inputs(op_info, op_name),
......
@@ -92,7 +92,7 @@
  output : Tensor(out)
  infer_meta :
    func : RecvV2InferMeta
    param : [peer, dtype, out_shape]
    param : [ring_id, dynamic_shape, peer, out_shape, dtype]
  kernel :
    func : recv_v2
    param : [ring_id, dynamic_shape, peer, out_shape, dtype, use_calc_stream]
......
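The updated `param` list for `infer_meta` mirrors the new `RecvV2InferMeta` signature shown in the next hunk: the same ordered attributes as the kernel, except `use_calc_stream`, which only affects execution, not shape or dtype inference.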
@@ -181,6 +181,48 @@ void PRecvArrayInferMeta(int peer,
  out->set_dtype(dtype);
}

void RecvV2InferMeta(const int ring_id,
                     const bool dynamic_shape,
                     const int peer,
                     const std::vector<int>& out_shape,
                     DataType dtype,
                     MetaTensor* out) {
  PADDLE_ENFORCE_GE(
      peer,
      0,
      errors::InvalidArgument(
          "The peer (%d) for recv_v2 op must be non-negative.", peer));
  PADDLE_ENFORCE_GE(
      ring_id,
      0,
      errors::InvalidArgument(
          "The ring_id (%d) for recv_v2 op must be non-negative.", ring_id));
  PADDLE_ENFORCE_GE(out_shape.size(),
                    1,
                    errors::InvalidArgument(
                        "The size of the output shape must be greater than 0 "
                        "but the value given is %d.",
                        out_shape.size()));
  if (!dynamic_shape) {
    for (size_t i = 0; i < out_shape.size(); ++i) {
      PADDLE_ENFORCE_GE(out_shape[i],
                        1,
                        errors::InvalidArgument(
                            "The shape attribute for recv_v2 must be set "
                            "explicitly, but the %dth element is %d which "
                            "is less than 1. Or dynamic_shape should be "
                            "set to True for both send_v2 and recv_v2.",
                            i,
                            out_shape[i]));
    }
    out->set_dims(phi::make_ddim(out_shape));
  }
  out->set_dtype(dtype);
}

void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
                                      float mean,
                                      float std,
......
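For readers skimming the hunk, here is a minimal Python sketch of what the new RecvV2InferMeta checks and infers (an illustration only, not Paddle code; the helper name is ours):

def recv_v2_infer_meta_sketch(ring_id, dynamic_shape, peer, out_shape, dtype):
    # Both identifiers must be non-negative, mirroring the PADDLE_ENFORCE_GE checks.
    assert peer >= 0, f"The peer ({peer}) for recv_v2 op must be non-negative."
    assert ring_id >= 0, f"The ring_id ({ring_id}) for recv_v2 op must be non-negative."
    assert len(out_shape) >= 1, "The size of the output shape must be greater than 0."
    dims = None
    if not dynamic_shape:
        # With a static shape, every element must be set explicitly (>= 1);
        # otherwise dynamic_shape must be True on both send_v2 and recv_v2.
        for i, dim in enumerate(out_shape):
            assert dim >= 1, f"out_shape[{i}] is {dim}, which is less than 1."
        dims = tuple(out_shape)  # corresponds to out->set_dims(...)
    # dtype is always propagated, corresponding to out->set_dtype(dtype).
    return dims, dtype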
@@ -76,6 +76,13 @@ void PRecvArrayInferMeta(int peer,
                         const std::vector<int>& out_shape,
                         MetaTensor* out);

void RecvV2InferMeta(const int ring_id,
                     const bool dynamic_shape,
                     const int peer,
                     const std::vector<int>& out_shape,
                     DataType dtype,
                     MetaTensor* out);

void TruncatedGaussianRandomInferMeta(const std::vector<int>& shape,
                                      float mean,
                                      float std,
......
@@ -405,50 +405,6 @@ void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
  out->set_dtype(x.dtype());
}

void SendV2InferMeta(const int peer, const int ring_id) {
  PADDLE_ENFORCE_GE(
      peer,
      0,
      errors::InvalidArgument(
          "The peer (%d) for send_v2 op must be non-negative.", peer));
  PADDLE_ENFORCE_GE(
      ring_id,
      0,
      errors::InvalidArgument(
          "The ring_id (%d) for send_v2 op must be non-negative.", ring_id));
}

void RecvV2InferMeta(int peer,
                     DataType dtype,
                     const std::vector<int>& out_shape,
                     MetaTensor* out) {
  PADDLE_ENFORCE_GE(
      peer,
      0,
      errors::InvalidArgument(
          "The peer (%d) for p_recv op must be non-negative.", peer));
  PADDLE_ENFORCE_GE(out_shape.size(),
                    1,
                    errors::InvalidArgument(
                        "The size of the output shape must be greater than 0 "
                        "but the value given is %d.",
                        out_shape.size()));
  for (size_t i = 0; i < out_shape.size(); ++i) {
    PADDLE_ENFORCE_GE(
        out_shape[i],
        1,
        errors::InvalidArgument("The shape attribute for recv must be set "
                                "explicitly, but the %dth element is %d which "
                                "is less than 1. Or dynamic_shape should be "
                                "set to True for both send_v2 and recv_v2.",
                                i,
                                out_shape[i]));
  }
  out->set_dtype(dtype);
}

void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
  auto dims = x.dims();
  auto rank = dims.size();
@@ -3045,6 +3001,19 @@ void PSendArrayInferMeta(const MetaTensor& x, int peer) {
          "The peer (%d) for p_send op must be non-negative.", peer));
}

void SendV2InferMeta(const int peer, const int ring_id) {
  PADDLE_ENFORCE_GE(
      peer,
      0,
      errors::InvalidArgument(
          "The peer (%d) for send_v2 op must be non-negative.", peer));
  PADDLE_ENFORCE_GE(
      ring_id,
      0,
      errors::InvalidArgument(
          "The ring_id (%d) for send_v2 op must be non-negative.", ring_id));
}

void PoolInferMeta(const MetaTensor& x,
                   const std::vector<int>& kernel_size,
                   const std::vector<int>& strides,
......
@@ -73,13 +73,6 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);

void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);

void SendV2InferMeta(const int peer, const int ring_id);

void RecvV2InferMeta(int peer,
                     DataType dtype,
                     const std::vector<int>& out_shape,
                     MetaTensor* out);

void ChannelShuffleInferMeta(const MetaTensor& x,
                             int groups,
                             const std::string& data_format,
......
@@ -448,6 +441,8 @@ void PSendInferMeta(const MetaTensor& x, int peer);

void PSendArrayInferMeta(const MetaTensor& x, int peer);

void SendV2InferMeta(const int peer, const int ring_id);

void QrInferMeta(const MetaTensor& x,
                 const std::string& mode,
                 MetaTensor* q,
......
@@ -833,6 +833,14 @@ class Engine:
                dist_main_program, self._place, dist_context
            )

        # NOTE(zhaoyinglia): Temporarily skip the startup program when using new IR.
        use_new_ir = False
        if auto_utils.use_new_ir():
            use_new_ir = True
            paddle.framework.set_flags(
                {"FLAGS_enable_new_ir_in_executor": False}
            )

        if self._executor is None:
            self._executor = paddle.static.Executor(self._place)
            uninitialized = []
@@ -860,6 +868,11 @@ class Engine:
            ]
            self._executor.run(dist_startup_prog)

        if use_new_ir:
            paddle.framework.set_flags(
                {"FLAGS_enable_new_ir_in_executor": True}
            )

    def fit(
        self,
        train_data,
......
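The two Engine hunks above implement a disable-then-restore pattern around the startup program. A hedged sketch of the same idea in isolation (the helper and its try/finally framing are ours, not the Engine's code):

import paddle

def run_startup_without_new_ir(executor, startup_prog, use_new_ir):
    # Temporarily fall back to the legacy executor for the startup program.
    if use_new_ir:
        paddle.framework.set_flags({"FLAGS_enable_new_ir_in_executor": False})
    try:
        executor.run(startup_prog)
    finally:
        # Restore the new IR executor for the main program.
        if use_new_ir:
            paddle.framework.set_flags({"FLAGS_enable_new_ir_in_executor": True})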
@@ -2423,6 +2423,19 @@ def use_new_executor():
    ]


def use_new_ir():
    enable_new_ir_in_executor = os.environ.get(
        'FLAGS_enable_new_ir_in_executor', None
    )
    return enable_new_ir_in_executor in [
        1,
        '1',
        True,
        'True',
        'true',
    ]


def get_pp_stage(dist_context, rank):
    pp_idx = None
    for idx, process_mesh in enumerate(dist_context.process_meshes):
......
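A short usage sketch: use_new_ir() only inspects the process environment, so the flag must be exported before the Engine consults it; values such as '1', 'True', or 'true' all count as enabled (the import path below is an assumption for illustration):

import os

os.environ['FLAGS_enable_new_ir_in_executor'] = '1'
# from paddle.distributed.auto_parallel.static import utils as auto_utils  # assumed import path
# auto_utils.use_new_ir()  ->  True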
@@ -78,20 +78,23 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
                         PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_pass_quantization MODULES test_pass_quantization)
  set_tests_properties(test_pass_quantization
                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 60)
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 60)
  py_test_modules(test_reshard_s_to_r MODULES test_reshard_s_to_r)
  set_tests_properties(test_reshard_s_to_r
                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
  set_tests_properties(test_reshard_r_to_s
                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
  set_tests_properties(test_reshard_r_to_p
                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_semi_auto_parallel_basic MODULES
                  test_semi_auto_parallel_basic)
  set_tests_properties(test_semi_auto_parallel_basic
                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  py_test_modules(test_gpt_with_newir MODULES test_gpt_with_newir)
  set_tests_properties(test_gpt_with_newir
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 100)
  # End of unittests WITH multi cards and timeout

  # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import unittest

import numpy as np
from get_gpt_model import FakeDataset, generate_model

import paddle
from paddle.distributed import ParallelEnv
from paddle.distributed.fleet import auto

paddle.enable_static()


def apply_pass():
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    return strategy


def reset_prog():
    paddle.framework.switch_main_program(paddle.static.Program())
    paddle.framework.switch_startup_program(paddle.static.Program())
    paddle.utils.unique_name.switch()


class TestNewIR(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
        self.batch_num = 5
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)
        os.environ['FLAGS_new_executor_micro_batching'] = 'True'
        paddle.set_flags({'FLAGS_embedding_deterministic': 1})
        paddle.set_flags({'FLAGS_cudnn_deterministic': 1})

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        paddle.distributed.fleet.init(is_collective=True)
        place = paddle.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, mode):
        reset_prog()

        strategy = apply_pass()
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=None)
        model, loss = generate_model(mode)

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_results(self, ref_losses, check_losses):
        np.testing.assert_equal(
            ref_losses,
            check_losses,
            err_msg='pass {} has wrong results!\nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )

    def enable_new_ir(self, flag):
        paddle.set_flags({'FLAGS_enable_new_ir_in_executor': flag})  # for c++
        os.environ['FLAGS_enable_new_ir_in_executor'] = str(flag)  # for python

    def test_dp(self):
        self.enable_new_ir(False)
        engine_dp_prog = self.get_engine("dp")
        out_dp_prog = engine_dp_prog.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        self.enable_new_ir(True)
        engine_dp_ir = self.get_engine("dp")
        out_dp_ir = engine_dp_ir.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        self.check_results(
            out_dp_prog.history["loss"][0], out_dp_ir.history["loss"][0]
        )

    def test_mp(self):
        self.enable_new_ir(False)
        engine_mp_prog = self.get_engine("mp")
        out_mp_prog = engine_mp_prog.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        self.enable_new_ir(True)
        engine_mp_ir = self.get_engine("mp")
        out_mp_ir = engine_mp_ir.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        self.check_results(
            out_mp_prog.history["loss"][0], out_mp_ir.history["loss"][0]
        )

    def test_pp(self):
        # naive pipeline parallel without schedule
        self.enable_new_ir(False)
        engine_pp_prog = self.get_engine("pp")
        out_pp_prog = engine_pp_prog.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        self.enable_new_ir(True)
        # send_v2/recv_v2 dynamic_shape is True
        engine_pp_ir = self.get_engine("pp")
        out_pp_ir = engine_pp_ir.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )

        if paddle.distributed.get_rank() == 1:
            self.check_results(
                out_pp_prog.history["loss"][0], out_pp_ir.history["loss"][0]
            )

        # send_v2/recv_v2 dynamic_shape is False
        engine_pp_prog1 = self.get_engine("pp")
        dataloader_pp_prog = engine_pp_prog1.dataloader(
            self.dataset,
            batch_size=self.batch_size,
            sample_split=3,
            mode="train",
        )
        engine_pp_prog1.prepare(mode="train")
        for op in engine_pp_prog1.main_program.global_block().ops:
            if op.type in ["send_v2", "recv_v2"]:
                op.desc._set_attr("dynamic_shape", False)
        for data in dataloader_pp_prog:
            out_pp_prog1 = engine_pp_prog1.run(data, mode="train")

        if paddle.distributed.get_rank() == 1:
            self.check_results(
                out_pp_prog1["loss"], out_pp_ir.history["loss"][0]
            )


if __name__ == "__main__":
    unittest.main()
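Note that test_pp exercises both receive paths of the new infermeta: the new IR run keeps dynamic_shape=True (the shape is exchanged at runtime), while the second program-mode run flips the send_v2/recv_v2 dynamic_shape attribute to False so the statically declared out_shape path is validated as well. Losses are compared on rank 1, the pipeline stage that holds the loss.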
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest


class TestGPTNewIR(unittest.TestCase):
    def test_gpt_newir(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "gpt_with_newir.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()