Unverified commit cf5de26f, authored by Wilber, committed by GitHub

ut support block (#37909)

Parent b48545ee
@@ -14,6 +14,7 @@
 from typing import Optional, List, Callable, Dict, Any, Set
 import numpy as np
+import enum
 import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
@@ -57,6 +58,12 @@ class TensorConfig:
         return str({'shape': self.shape, 'lod': self.lod, 'dtype': self.dtype})
 
 
+class VarType(enum.Enum):
+    LOD_TENSOR = 1
+    LOD_TENSOR_ARRAY = 2
+    STEP_SCOPES = 3
+
+
 class OpConfig:
     ''' A config builder for generating an Op. '''
@@ -65,10 +72,14 @@ class OpConfig:
                  inputs: Dict[str, List[str]],
                  outputs: Dict[str, List[str]],
                  attrs: Dict[str, Any]=None,
+                 outputs_var_type: Dict[str, VarType]=None,
+                 outputs_dtype: Dict[str, np.dtype]=None,
                  **kwargs):
         self.type = type
         self.inputs = inputs
         self.outputs = outputs
+        self.outputs_dtype = outputs_dtype
+        self.outputs_var_type = outputs_var_type
         self.attrs = attrs
         if self.attrs is None:
             self.attrs = dict()
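The two new optional maps, outputs_var_type and outputs_dtype, let a test pin down the variable type and dtype of individual op outputs instead of falling back to the float32 LOD_TENSOR default. A minimal sketch of the intended usage, built from the classes in this file; the op type and tensor names are made up for illustration:

    # Hypothetical op whose output should be declared as an int64 LoDTensorArray.
    array_op = OpConfig(
        type="my_array_op",  # illustrative, not a real Paddle op
        inputs={"X": ["input_data"]},
        outputs={"Out": ["array_out"]},
        attrs={},
        outputs_var_type={"array_out": VarType.LOD_TENSOR_ARRAY},
        outputs_dtype={"array_out": np.int64})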
@@ -80,6 +91,88 @@ class OpConfig:
         return log_str
 
 
+_OP_WITHOUT_KERNEL_SET = {
+    'feed', 'fetch', 'recurrent', 'go', 'rnn_memory_helper_grad',
+    'conditional_block', 'while', 'send', 'recv', 'listen_and_serv',
+    'fl_listen_and_serv', 'ncclInit', 'select', 'checkpoint_notify',
+    'gen_bkcl_id', 'c_gen_bkcl_id', 'gen_nccl_id', 'c_gen_nccl_id',
+    'c_comm_init', 'c_sync_calc_stream', 'c_sync_comm_stream',
+    'queue_generator', 'dequeue', 'enqueue', 'heter_listen_and_serv',
+    'c_wait_comm', 'c_wait_compute', 'c_gen_hccl_id', 'c_comm_init_hccl',
+    'copy_cross_scope'
+}
+
+
+class BlockConfig:
+    ''' A config builder for generating a Block. '''
+
+    def __init__(self,
+                 ops: List[OpConfig],
+                 vars: List[str],
+                 vars_dtype: Dict[str, np.dtype]=None,
+                 vars_var_type: Dict[str, VarType]=None,
+                 vars_lod_level: Dict[str, int]=None):
+        self.ops = ops
+        self.vars = vars
+        self.vars_dtype = vars_dtype
+        self.vars_var_type = vars_var_type
+        self.vars_lod_level = vars_lod_level
+
+    def fill_block_desc(self, block_desc):
+        # Declare every block-local variable, defaulting to a float32 LOD_TENSOR.
+        for name in self.vars:
+            var_desc = block_desc.var(cpt.to_bytes(name))
+            var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR)
+            if self.vars_lod_level is not None and name in self.vars_lod_level:
+                var_desc.set_lod_level(self.vars_lod_level[name])
+            if self.vars_var_type is not None and name in self.vars_var_type:
+                if self.vars_var_type[name] == VarType.LOD_TENSOR_ARRAY:
+                    var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR_ARRAY)
+                elif self.vars_var_type[name] == VarType.STEP_SCOPES:
+                    # STEP_SCOPES variables carry no dtype, so skip setting one.
+                    var_desc.set_type(core.VarDesc.VarType.STEP_SCOPES)
+                    continue
+            var_desc.set_dtype(convert_np_dtype_to_dtype_(np.float32))
+            if self.vars_dtype is not None and name in self.vars_dtype:
+                var_desc.set_dtype(
+                    convert_np_dtype_to_dtype_(self.vars_dtype[name]))
+
+        # Append each op and create any output vars it introduces.
+        for op_config in self.ops:
+            op_desc = block_desc.append_op()
+            op_desc.set_type(op_config.type)
+            for name, values in op_config.inputs.items():
+                op_desc.set_input(name, values)
+            for name, values in op_config.attrs.items():
+                op_desc._set_attr(name, values)
+            for name, values in op_config.outputs.items():
+                op_desc.set_output(name, values)
+                for v in values:
+                    if block_desc.has_var_recursive(cpt.to_bytes(v)):
+                        continue
+                    var_desc = block_desc.var(cpt.to_bytes(v))
+                    var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR)
+                    if (op_config.outputs_var_type is not None and
+                            v in op_config.outputs_var_type):
+                        if op_config.outputs_var_type[v] == VarType.LOD_TENSOR_ARRAY:
+                            var_desc.set_type(
+                                core.VarDesc.VarType.LOD_TENSOR_ARRAY)
+                        elif op_config.outputs_var_type[v] == VarType.STEP_SCOPES:
+                            var_desc.set_type(core.VarDesc.VarType.STEP_SCOPES)
+                            continue
+                    var_desc.set_dtype(convert_np_dtype_to_dtype_(np.float32))
+                    if (op_config.outputs_dtype is not None and
+                            v in op_config.outputs_dtype):
+                        var_desc.set_dtype(
+                            convert_np_dtype_to_dtype_(op_config.outputs_dtype[v]))
+            # Ops without a kernel (e.g. control flow) cannot run shape/type inference.
+            if op_config.type not in _OP_WITHOUT_KERNEL_SET:
+                op_desc.infer_var_type(block_desc)
+                op_desc.infer_shape(block_desc)
+                op_desc.check_attrs()
 class ProgramConfig:
     ''' A config builder for generating a Program. '''
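BlockConfig is the block-level counterpart of OpConfig: it lists the variables a sub-block declares and the ops it runs, and fill_block_desc materializes both into a BlockDesc. A short sketch of describing a one-op block with the classes above; the tensor names are illustrative:

    # A block that adds a tensor to itself; 'branch_in' is declared block-local.
    body_block = BlockConfig(
        ops=[
            OpConfig(
                type="elementwise_add",
                inputs={"X": ["branch_in"], "Y": ["branch_in"]},
                outputs={"Out": ["branch_out"]},
                attrs={"axis": -1})
        ],
        vars=["branch_in"],
        vars_dtype={"branch_in": np.float32})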
@@ -137,6 +230,8 @@ def create_fake_model(program_config):
         var_desc.set_dtype(convert_np_dtype_to_dtype_(tensor_config.dtype))
         var_desc.set_shape(tensor_config.shape)
         var_desc.set_need_check_feed(True)
+        if tensor_config.lod is not None:
+            var_desc.set_lod_level(len(tensor_config.lod))
         op_desc = main_block_desc._prepend_op()
         op_desc.set_type("feed")
         op_desc.set_input('X', ["feed"])
@@ -177,16 +272,36 @@ def create_fake_model(program_config):
         for name, values in op_config.inputs.items():
             op_desc.set_input(name, values)
         for name, values in op_config.attrs.items():
-            op_desc._set_attr(name, values)
+            if name == 'sub_block':
+                # A BlockConfig attr becomes a real sub-block of the program.
+                sub_block_desc = main_program_desc.append_block(main_block_desc)
+                values.fill_block_desc(sub_block_desc)
+                op_desc._set_attr(name, sub_block_desc)
+            else:
+                op_desc._set_attr(name, values)
         for name, values in op_config.outputs.items():
             op_desc.set_output(name, values)
             for v in values:
+                if main_block_desc.has_var_recursive(cpt.to_bytes(v)):
+                    continue
                 var_desc = main_block_desc.var(cpt.to_bytes(v))
                 var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR)
+                if (op_config.outputs_var_type is not None and
+                        v in op_config.outputs_var_type):
+                    if op_config.outputs_var_type[v] == VarType.LOD_TENSOR_ARRAY:
+                        var_desc.set_type(core.VarDesc.VarType.LOD_TENSOR_ARRAY)
+                    elif op_config.outputs_var_type[v] == VarType.STEP_SCOPES:
+                        var_desc.set_type(core.VarDesc.VarType.STEP_SCOPES)
+                        continue
-                var_desc.set_dtype(
-                    convert_np_dtype_to_dtype_(tensor_config.dtype))
+                var_desc.set_dtype(convert_np_dtype_to_dtype_(np.float32))
+                if (op_config.outputs_dtype is not None and
+                        v in op_config.outputs_dtype):
+                    var_desc.set_dtype(
+                        convert_np_dtype_to_dtype_(op_config.outputs_dtype[v]))
-        op_desc.infer_var_type(main_block_desc)
-        op_desc.infer_shape(main_block_desc)
+        if op_config.type not in _OP_WITHOUT_KERNEL_SET:
+            op_desc.infer_var_type(main_block_desc)
+            op_desc.infer_shape(main_block_desc)
+            op_desc.check_attrs()
     for index, name in enumerate(program_config.outputs):
         var_desc = main_block_desc.var(cpt.to_bytes("fetch"))
...
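Taken together, a unit test can now describe a control-flow op whose 'sub_block' attribute is a BlockConfig; create_fake_model appends a fresh block to the program desc and fills it via fill_block_desc. A hedged end-to-end sketch, assuming conditional_block's usual Cond/Input/Out/Scope slots (check the op definition before relying on the exact names):

    cond_block = BlockConfig(
        ops=[OpConfig(
            type="assign",
            inputs={"X": ["branch_in"]},
            outputs={"Out": ["branch_out"]})],
        vars=["branch_out"])

    cond_op = OpConfig(
        type="conditional_block",
        inputs={"Cond": ["cond_flag"], "Input": ["branch_in"]},
        outputs={"Out": ["branch_out"], "Scope": ["cond_scope"]},
        attrs={"sub_block": cond_block,  # filled in by fill_block_desc
               "is_scalar_condition": True},
        outputs_var_type={"cond_scope": VarType.STEP_SCOPES})

Note that conditional_block appears in _OP_WITHOUT_KERNEL_SET, so create_fake_model skips infer_var_type/infer_shape/check_attrs for it; that is what makes a graph like this constructible without executing the op.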