提交 3e85dbb6 编写于 作者: C cyber-pioneer

add dropout op map

上级 4b8b4c71
......@@ -414,6 +414,17 @@
- op : dropout
backward : dropout_grad
inputs :
x : X
outputs :
out : Out
mask : Mask
attrs :
p : dropout_prob
is_test : is_test
mode : dropout_implementation
seed : seed
fix_seed : fix_seed
extra :
attrs : [bool fix_seed = false, int seed = 0]
......@@ -790,6 +801,14 @@
- op : layer_norm
backward : layer_norm_grad
inputs :
x : X
scale : Scale
bias : Bias
outputs :
out : Y
mean : Mean
variance : Variance
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
......@@ -940,6 +959,17 @@
outputs :
out : Out
- op : mean (reduce_mean)
backward : reduce_mean_grad
inputs :
x : X
outputs :
out : Out
attrs :
{axis : dim, keepdim : keep_dim}
extra :
attrs : [bool use_mkldnn = false]
- op : meshgrid
backward : meshgrid_grad
inputs :
......@@ -1145,17 +1175,6 @@
extra :
attrs : [bool use_mkldnn = false]
- op : mean (reduce_mean)
backward : reduce_mean_grad
inputs :
x : X
outputs :
out : Out
attrs :
{axis : dim, keepdim : keep_dim}
extra :
attrs : [bool use_mkldnn = false]
- op : reduce_min
backward : reduce_min_grad
extra :
......
......@@ -597,11 +597,13 @@ def _lower_composite(block, blacklist=[]):
# if output var of composite rule is None, this means this var is not needed
none_vars_to_remove = set()
change = None
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
change = True
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
......@@ -681,6 +683,10 @@ def _lower_composite(block, blacklist=[]):
block.desc._remove_var(var_name.encode())
del block.vars[var_name]
block._sync_with_cpp()
# composite ops may contain other ops, thus, call _lower_composite again.
if change:
_lower_composite(block, blacklist)
return
elif isinstance(block, typing.Sequence):
......
......@@ -232,14 +232,20 @@ def get_output_vars_from_comosite(op):
origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names:
continue
origin_output_var = get_var_block(op.block, op.output(origin_output_name))
origin_output_var = get_var_block(
op.block, op.output(origin_output_name)
)
res.append(origin_output_var)
elif len(origin_output_names) == 1:
# When origin output num is 1, map info is not needed.
origin_output_var = get_var_block(op.block, op.output(origin_output_names[0]))
origin_output_var = get_var_block(
op.block, op.output(origin_output_names[0])
)
res.append(origin_output_var)
else:
raise ValueError("When replace op with composite rule, there must exist output map info from origin op to composite rule.")
raise ValueError(
"When replace op with composite rule, there must exist output map info from origin op to composite rule."
)
return res
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册