未验证 提交 0ce66c1c 编写于 作者: Z zhaoyingli 提交者: GitHub

[NewIR] support c_sync_calc_stream/c_sync_comm_stream/send_v2/recv_v2 (#56557)

* [AutoParallel][NewIR] support calc_sync/comm_sync/send_v2/recv_v2

* pre-commit

* rm unittest

* tiny fix

* api_gen support send_v2's output is empty

* fix format

* python_c_gen support send_v2
上级 df9d9c59
...@@ -275,16 +275,20 @@ class CodeGen: ...@@ -275,16 +275,20 @@ class CodeGen:
return 'return;' return 'return;'
def _gen_one_impl(self, op_info, op_name, is_mutable_attr): def _gen_one_impl(self, op_info, op_name, is_mutable_attr):
ret_type = self._gen_ret_type(op_info)
in_combine, in_combine_op_list = self._gen_in_combine(op_info) in_combine, in_combine_op_list = self._gen_in_combine(op_info)
compute_op, op_inst_name = self._gen_compute_op( compute_op, op_inst_name = self._gen_compute_op(
op_info, op_name, in_combine_op_list, is_mutable_attr op_info, op_name, in_combine_op_list, is_mutable_attr
) )
if ret_type == 'void':
compute_op += f' (void){op_inst_name};'
out_split, ret_list = self._gen_out_split_and_ret_list( out_split, ret_list = self._gen_out_split_and_ret_list(
op_info, op_inst_name op_info, op_inst_name
) )
ret = API_IMPL_TEMPLATE.format( ret = API_IMPL_TEMPLATE.format(
ret_type=self._gen_ret_type(op_info), ret_type=ret_type,
api_name=op_name, api_name=op_name,
args=self._gen_api_args(op_info, False, is_mutable_attr), args=self._gen_api_args(op_info, False, is_mutable_attr),
in_combine=in_combine, in_combine=in_combine,
......
...@@ -81,6 +81,29 @@ PyObject *static_api_{api_name}(PyObject *self, PyObject *args, PyObject *kwargs ...@@ -81,6 +81,29 @@ PyObject *static_api_{api_name}(PyObject *self, PyObject *args, PyObject *kwargs
}} }}
""" """
NO_OUTPUT_API_IMPL_TEMPLATE = """
PyObject *static_api_{api_name}(PyObject *self, PyObject *args, PyObject *kwargs) {{
try {{
VLOG(6) << "Add {api_name} op into program";
VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);
// Get OpResult from args
{inputs}
// Parse Attributes
{attrs}
// Call ir static api
paddle::dialect::{api_name}({args});
return nullptr;
}} catch (...) {{
ThrowExceptionToPython(std::current_exception());
return nullptr;
}}
}}
"""
INPUT_TEMPLATE = """ INPUT_TEMPLATE = """
PyObject *{name}_obj = PyTuple_GET_ITEM(args, {index}); PyObject *{name}_obj = PyTuple_GET_ITEM(args, {index});
auto {name} = {cast_func}({name}_obj, "{api_name}", {index});""" auto {name} = {cast_func}({name}_obj, "{api_name}", {index});"""
...@@ -263,7 +286,15 @@ class PythonCCodeGen(CodeGen): ...@@ -263,7 +286,15 @@ class PythonCCodeGen(CodeGen):
attr_name_list = op_info.attribute_name_list attr_name_list = op_info.attribute_name_list
mutable_attr_name_list = op_info.mutable_attribute_name_list mutable_attr_name_list = op_info.mutable_attribute_name_list
no_mutable_attr_name_list = op_info.non_mutable_attribute_name_list no_mutable_attr_name_list = op_info.non_mutable_attribute_name_list
if len(mutable_attr_name_list) > 0:
if op_name == "send_v2":
ret = NO_OUTPUT_API_IMPL_TEMPLATE.format(
api_name=op_name,
inputs=self._gen_inputs(op_info, op_name),
attrs=self._gen_attrs_without_mutable(op_info, op_name),
args=', '.join(input_name_list + attr_name_list),
)
elif len(mutable_attr_name_list) > 0:
ret = MUTABLE_ATTR_API_IMPL_TEMPLATE.format( ret = MUTABLE_ATTR_API_IMPL_TEMPLATE.format(
api_name=op_name, api_name=op_name,
inputs=self._gen_inputs(op_info, op_name), inputs=self._gen_inputs(op_info, op_name),
......
...@@ -355,6 +355,66 @@ ...@@ -355,6 +355,66 @@
inplace: null inplace: null
backward: null backward: null
- name: send_v2
inputs:
- typename: Tensor
name: x
optional: false
no_need_buffer: false
data_transform: {}
attrs:
- {typename: int, name: ring_id, default_value: '0'}
- {typename: int, name: peer, default_value: '0'}
- {typename: bool, name: use_calc_stream, default_value: 'false'}
- {typename: bool, name: dynamic_shape, default_value: 'false'}
outputs: []
no_need_buffer: null
data_transform: null
infer_meta:
func: SendV2InferMeta
param: [peer, ring_id]
kernel:
func: [send_v2]
param: [x, ring_id, dynamic_shape, peer, use_calc_stream]
backend: null
layout: null
data_type: null
dispatch: {send_v2: null}
force_backend: null
inplace: null
view: null
backward: null
- name: recv_v2
inputs: []
attrs:
- {typename: 'int[]', name: out_shape, default_value: '{}'}
- {typename: DataType, name: dtype, default_value: 'DataType::FLOAT32'}
- {typename: int, name: peer, default_value: '0'}
- {typename: int, name: ring_id, default_value: '0'}
- {typename: bool, name: use_calc_stream, default_value: 'false'}
- {typename: bool, name: dynamic_shape, default_value: 'false'}
outputs:
- {typename: Tensor, name: out, optional: false, intermediate: false}
no_need_buffer: null
data_transform: null
infer_meta:
func: RecvV2InferMeta
param: [peer, dtype, out_shape]
kernel:
func: [recv_v2]
param: [ring_id, dynamic_shape, peer, out_shape, dtype, use_calc_stream]
backend: null
layout: null
data_type:
ordered: false
candidates: [dtype]
to_complex_flag: [false]
dispatch: {recv_v2: null}
force_backend: null
inplace: null
view: null
- name : set_value - name : set_value
inputs: inputs:
- {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} } - {typename: Tensor, name: x, optional: false, no_need_buffer: false, data_transform: {} }
......
...@@ -24,7 +24,10 @@ const std::unordered_set<std::string> LegacyOpList = { ...@@ -24,7 +24,10 @@ const std::unordered_set<std::string> LegacyOpList = {
"pd.c_broadcast_", "pd.c_broadcast_",
"pd.fused_bn_add_activation_", "pd.fused_bn_add_activation_",
"pd.fused_bn_add_activation_grad", "pd.fused_bn_add_activation_grad",
}; "pd.c_sync_calc_stream_",
"pd.c_sync_comm_stream_",
"pd.send_v2",
"pd.recv_v2"};
enum class AttrType { enum class AttrType {
UNDEFINED = 0, UNDEFINED = 0,
......
...@@ -142,6 +142,26 @@ ...@@ -142,6 +142,26 @@
kernel : kernel :
func : c_concat func : c_concat
- op : c_sync_calc_stream
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : c_sync_calc_stream
inplace : (x -> out)
- op : c_sync_comm_stream
args : (Tensor x)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : c_sync_comm_stream
inplace : (x -> out)
- op : cast - op : cast
args : (Tensor x, DataType dtype) args : (Tensor x, DataType dtype)
output : Tensor(out) output : Tensor(out)
......
...@@ -3034,6 +3034,18 @@ ...@@ -3034,6 +3034,18 @@
outputs : outputs :
out : Out out : Out
- op: c_sync_calc_stream
inputs :
x : X
outputs :
out : Out
- op: c_sync_comm_stream
inputs :
x : X
outputs :
out : Out
- op: channel_shuffle - op: channel_shuffle
inputs: inputs:
{x: X} {x: X}
...@@ -3071,6 +3083,10 @@ ...@@ -3071,6 +3083,10 @@
outputs : outputs :
out : Out out : Out
- op: recv_v2
outputs :
out : Out
- op: reindex_graph (graph_reindex) - op: reindex_graph (graph_reindex)
inputs : inputs :
{x : X, neighbors : Neighbors, count : Count, hashtable_value : HashTable_Value, hashtable_index : HashTable_Index} {x : X, neighbors : Neighbors, count : Count, hashtable_value : HashTable_Value, hashtable_index : HashTable_Index}
...@@ -3083,6 +3099,10 @@ ...@@ -3083,6 +3099,10 @@
outputs: outputs:
{out: Out, noise: Noise} {out: Out, noise: Noise}
- op: send_v2
inputs :
x : X
- op: set_value - op: set_value
backward: set_value_grad backward: set_value_grad
inputs: inputs:
......
...@@ -403,6 +403,50 @@ void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) { ...@@ -403,6 +403,50 @@ void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out) {
out->set_dtype(x.dtype()); out->set_dtype(x.dtype());
} }
void SendV2InferMeta(const int peer, const int ring_id) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument(
"The peer (%d) for send_v2 op must be non-negative.", peer));
PADDLE_ENFORCE_GE(
ring_id,
0,
errors::InvalidArgument(
"The ring_id (%d) for send_v2 op must be non-negative.", ring_id));
}
void RecvV2InferMeta(int peer,
DataType dtype,
const std::vector<int>& out_shape,
MetaTensor* out) {
PADDLE_ENFORCE_GE(
peer,
0,
errors::InvalidArgument(
"The peer (%d) for p_recv op must be non-negative.", peer));
PADDLE_ENFORCE_GE(out_shape.size(),
1,
errors::InvalidArgument(
"The size of the output shape must be greater than 0 "
"but the value given is %d.",
out_shape.size()));
for (size_t i = 0; i < out_shape.size(); ++i) {
PADDLE_ENFORCE_GE(
out_shape[i],
1,
errors::InvalidArgument("The shape attribute for recv must be set "
"explicitly, but the %dth element is %d which "
"is less than 1. Or dynamic_shape should be "
"set to True for both send_v2 and recv_v2.",
i,
out_shape[i]));
}
out->set_dtype(dtype);
}
void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) { void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out) {
auto dims = x.dims(); auto dims = x.dims();
auto rank = dims.size(); auto rank = dims.size();
......
...@@ -73,6 +73,13 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out); ...@@ -73,6 +73,13 @@ void CastInferMeta(const MetaTensor& x, DataType out_dtype, MetaTensor* out);
void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out); void CConcatInferMeta(const MetaTensor& x, int nranks, MetaTensor* out);
void SendV2InferMeta(const int peer, const int ring_id);
void RecvV2InferMeta(int peer,
DataType dtype,
const std::vector<int>& out_shape,
MetaTensor* out);
void ChannelShuffleInferMeta(const MetaTensor& x, void ChannelShuffleInferMeta(const MetaTensor& x,
int groups, int groups,
const std::string& data_format, const std::string& data_format,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册