提交 3e85dbb6 编写于 作者: C cyber-pioneer

add dropout op map

上级 4b8b4c71
...@@ -414,6 +414,17 @@ ...@@ -414,6 +414,17 @@
- op : dropout - op : dropout
backward : dropout_grad backward : dropout_grad
inputs :
x : X
outputs :
out : Out
mask : Mask
attrs :
p : dropout_prob
is_test : is_test
mode : dropout_implementation
seed : seed
fix_seed : fix_seed
extra : extra :
attrs : [bool fix_seed = false, int seed = 0] attrs : [bool fix_seed = false, int seed = 0]
...@@ -790,6 +801,14 @@ ...@@ -790,6 +801,14 @@
- op : layer_norm - op : layer_norm
backward : layer_norm_grad backward : layer_norm_grad
inputs :
x : X
scale : Scale
bias : Bias
outputs :
out : Y
mean : Mean
variance : Variance
extra : extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false] attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
...@@ -940,6 +959,17 @@ ...@@ -940,6 +959,17 @@
outputs : outputs :
out : Out out : Out
- op : mean (reduce_mean)
backward : reduce_mean_grad
inputs :
x : X
outputs :
out : Out
attrs :
{axis : dim, keepdim : keep_dim}
extra :
attrs : [bool use_mkldnn = false]
- op : meshgrid - op : meshgrid
backward : meshgrid_grad backward : meshgrid_grad
inputs : inputs :
...@@ -1145,17 +1175,6 @@ ...@@ -1145,17 +1175,6 @@
extra : extra :
attrs : [bool use_mkldnn = false] attrs : [bool use_mkldnn = false]
- op : mean (reduce_mean)
backward : reduce_mean_grad
inputs :
x : X
outputs :
out : Out
attrs :
{axis : dim, keepdim : keep_dim}
extra :
attrs : [bool use_mkldnn = false]
- op : reduce_min - op : reduce_min
backward : reduce_min_grad backward : reduce_min_grad
extra : extra :
......
...@@ -597,11 +597,13 @@ def _lower_composite(block, blacklist=[]): ...@@ -597,11 +597,13 @@ def _lower_composite(block, blacklist=[]):
# if output var of composite rule is None, this means this var is not needed # if output var of composite rule is None, this means this var is not needed
none_vars_to_remove = set() none_vars_to_remove = set()
change = None
# Step2: Process all ops in the target block # Step2: Process all ops in the target block
for op_idx in range(len(block.ops)): for op_idx in range(len(block.ops)):
op = block.ops[op_idx] op = block.ops[op_idx]
ops_to_remove.append(op_idx) ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist: if lookup_fn(op.type) is not None and op.type not in blacklist:
change = True
input_args = prepare_python_api_arguments(op) input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table) bind(input_args, to_bind, value_table)
...@@ -681,6 +683,10 @@ def _lower_composite(block, blacklist=[]): ...@@ -681,6 +683,10 @@ def _lower_composite(block, blacklist=[]):
block.desc._remove_var(var_name.encode()) block.desc._remove_var(var_name.encode())
del block.vars[var_name] del block.vars[var_name]
block._sync_with_cpp() block._sync_with_cpp()
# composite ops may contain other ops, thus, call _lower_composite again.
if change:
_lower_composite(block, blacklist)
return return
elif isinstance(block, typing.Sequence): elif isinstance(block, typing.Sequence):
......
...@@ -232,14 +232,20 @@ def get_output_vars_from_comosite(op): ...@@ -232,14 +232,20 @@ def get_output_vars_from_comosite(op):
origin_output_name = op_map[name]["outputs"][item] origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names: if origin_output_name not in origin_output_names:
continue continue
origin_output_var = get_var_block(op.block, op.output(origin_output_name)) origin_output_var = get_var_block(
op.block, op.output(origin_output_name)
)
res.append(origin_output_var) res.append(origin_output_var)
elif len(origin_output_names) == 1: elif len(origin_output_names) == 1:
# When origin output num is 1, map info is not needed. # When origin output num is 1, map info is not needed.
origin_output_var = get_var_block(op.block, op.output(origin_output_names[0])) origin_output_var = get_var_block(
op.block, op.output(origin_output_names[0])
)
res.append(origin_output_var) res.append(origin_output_var)
else: else:
raise ValueError("When replace op with composite rule, there must exist output map info from origin op to composite rule.") raise ValueError(
"When replace op with composite rule, there must exist output map info from origin op to composite rule."
)
return res return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册