Unverified commit ff86aeab authored by cyber-pioneer, committed by GitHub

fix composite op map (#50397)

* map output from composite rule to origin op

add mean layer_norm dropout op map

add input map check

composite softmax support input shape []

* composite softmax support shape []

* polish log

* solve conflict

* polish code

* polish op map output

* add check dtype
Parent 8decfb78
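For context, the core of this change is that each composite rule's return values are now matched, slot by slot, against the origin op's output slots declared in the new outputs maps in op_compat.yaml below. A minimal illustration for batch_norm, reconstructed from this diff (the list-of-pairs form is only illustrative, not a data structure in the codebase):

# Composite-rule return values (left) paired with origin batch_norm output slots (right),
# in the order the rule returns them and op_compat.yaml declares them.
batch_norm_output_pairs = [
    ("y",           "Y"),              # out            : Y
    ("run_mean_",   "MeanOut"),        # mean_out       : MeanOut
    ("run_var_",    "VarianceOut"),    # variance_out   : VarianceOut
    ("batch_mean_", "SavedMean"),      # saved_mean     : SavedMean
    ("batch_var_",  "SavedVariance"),  # saved_variance : SavedVariance
    (None,          "ReserveSpace"),   # the rule returns None for reserve_space
]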
......@@ -145,6 +145,13 @@
variance : Variance
scale : Scale
bias : Bias
outputs :
out : Y
mean_out: MeanOut
variance_out: VarianceOut
saved_mean: SavedMean
saved_variance: SavedVariance
reserve_space: ReserveSpace
extra :
attrs : [bool use_mkldnn = false, bool fuse_with_relu = false]
......@@ -421,6 +428,17 @@
- op : dropout
backward : dropout_grad
inputs :
x : X
outputs :
out : Out
mask : Mask
attrs :
p : dropout_prob
is_test : is_test
mode : dropout_implementation
seed : seed
fix_seed : fix_seed
extra :
attrs : [bool fix_seed = false, int seed = 0]
......@@ -808,6 +826,14 @@
- op : layer_norm
backward : layer_norm_grad
inputs :
x : X
scale : Scale
bias : Bias
outputs :
out : Y
mean : Mean
variance : Variance
extra :
attrs : [bool use_mkldnn = false, str mkldnn_data_type = "float32", bool is_test = false]
......@@ -978,6 +1004,17 @@
outputs :
out : Out
- op : mean (reduce_mean)
backward : reduce_mean_grad
inputs :
x : X
outputs :
out : Out
attrs :
{axis : dim, keepdim : keep_dim}
extra :
attrs : [bool use_mkldnn = false]
- op : meshgrid
backward : meshgrid_grad
inputs :
......@@ -1196,11 +1233,6 @@
extra :
attrs : [bool use_mkldnn = false]
- op : reduce_mean
backward : reduce_mean_grad
extra :
attrs : [bool use_mkldnn = false]
- op : reduce_min
backward : reduce_min_grad
extra :
......
......@@ -449,14 +449,14 @@ def _test_use_sync(value):
# ops in forward_blacklist will not be replaced by composite ops.
prim_config = {"forward_blacklist": []}
prim_config = {"forward_blacklist": set(), "composite_ops_record": set()}
def _set_prim_forward_blacklist(ops=None):
if ops is None:
prim_config["forward_blacklist"] = []
elif isinstance(ops, str):
prim_config["forward_blacklist"].append(ops)
prim_config["forward_blacklist"].add(ops)
elif isinstance(ops, (list, tuple)):
for item in ops:
if not isinstance(item, str):
......@@ -464,7 +464,7 @@ def _set_prim_forward_blacklist(ops=None):
"ops set in forward_blacklist must belong to [str, str of tuple or list]"
)
else:
prim_config["forward_blacklist"].append(item)
prim_config["forward_blacklist"].add(item)
else:
raise TypeError(
"ops set in forward_blacklist must belong to [str, str of tuple or list]"
......
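A short usage sketch of the blacklist helpers touched above, assuming they are exposed from paddle.fluid.core as the prim_config import further below suggests (hedged, illustrative only):

from paddle.fluid import core

# Each call extends the blacklist; now that it is a set, repeated names are deduplicated.
core._set_prim_forward_blacklist("softmax")
core._set_prim_forward_blacklist(["dropout", "softmax"])
print(core.prim_config["forward_blacklist"])  # e.g. {'softmax', 'dropout'}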
......@@ -68,7 +68,7 @@ def expect_forward(inputs):
class TestCompositeSoftmax(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]]
self.shapes = [[], [2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite(self, inputs):
......@@ -101,6 +101,9 @@ class TestCompositeSoftmax(unittest.TestCase):
return res
def compare_forward(self):
if not attrs.shape and attrs.axis not in [-1, 0]:
# op softmax does not support this case
return
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
......
......@@ -143,7 +143,7 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
def setUp(self):
core._set_prim_backward_enabled(True)
self.dtypes = ["float32", "float64"]
self.shapes = [[2, 3, 4], [2, 3]]
self.shapes = [[], [2, 3, 4], [2, 3]]
self.axes = [-1, 0, 1]
def cal_composite_grad(self, inputs):
......@@ -169,6 +169,9 @@ class TestCompositeSoftmaxPrimBackward(unittest.TestCase):
return res
def compare_backward(self):
if not attrs.shape and attrs.axis not in [-1, 0]:
# op softmax does not support this case
return
np_data = generate_data(attrs.shape)
tensor_data = paddle.to_tensor(np_data)
......
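A quick, hedged illustration of the new 0-D case these tests cover; the value 1.0 follows from the (x + 1e-5) / (x + 1e-5) trick in the composite rule shown further below:

import paddle

x = paddle.to_tensor(3.0)                     # 0-D tensor, shape []
y = paddle.nn.functional.softmax(x, axis=-1)  # for shape [], axis must be -1 or 0
print(y.shape, float(y))                      # expected: [] 1.0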
......@@ -18,8 +18,6 @@
# ops.yaml or legacy_ops.yaml.
import paddle
from .primitives import * # noqa: F403
from .primreg import REGISTER_COMPOSITE, lookup_composite
......@@ -32,6 +30,10 @@ def _composite(op, *args):
@REGISTER_COMPOSITE('softmax')
def softmax_composite(x, axis):
"""define composite rule of op softmax"""
if not x.shape:
# do not return the constant 1 directly, so that gradients can still flow back through x
res = divide(x + 1e-5, x + 1e-5)
return res
max_temp = max(x, axis, keepdim=True)
max_temp.stop_gradient = True
molecular = exp(x - max_temp)
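To make the rule above concrete, here is a hedged sketch of the full decomposition written with public paddle ops (the sum-and-divide tail of the rule is cut off by this hunk; the version below is the standard completion and is for illustration only, not the rule's internal primitives):

import paddle

def softmax_sketch(x, axis=-1):
    if not x.shape:
        # 0-D input: softmax is identically 1, but keep x in the graph so gradients flow
        return (x + 1e-5) / (x + 1e-5)
    max_temp = paddle.max(x, axis=axis, keepdim=True)
    max_temp.stop_gradient = True  # the max shift is for numerical stability only
    molecular = paddle.exp(x - max_temp)
    denominator = paddle.sum(molecular, axis=axis, keepdim=True)
    return molecular / denominator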
......@@ -95,16 +97,15 @@ def composite_batchnorm(
y = reshape(scale, stats_shape) * x_hat + reshape(bias, stats_shape)
# add assign ops to detach tensors, to avoid unsafe in-place changes outside the rule.
batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
batch_var_ = assign(reshape(batch_var, run_var.shape))
run_mean_ = assign(run_mean)
run_var_ = assign(run_var)
batch_mean_ = paddle.assign(batch_mean)
batch_var_ = paddle.assign(batch_var)
run_mean_ = paddle.assign(run_mean)
run_var_ = paddle.assign(run_var)
# reserve_space is not needed in the composite rule, but still return None to stay consistent with the phi op definition.
reserve_space = None
if trainable_statistics or not is_test:
return run_mean_, None, batch_mean_, batch_var_, run_var_, y
else:
return run_mean_, batch_mean_, batch_var_, run_var_, y
return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space
@REGISTER_COMPOSITE('gelu')
......
......@@ -84,7 +84,7 @@ def generate_code(
else:
op_name = key
map_dct[op_name] = {"phi_name": op_name}
for element in ["inputs", "attrs"]:
for element in ["inputs", "outputs", "attrs"]:
if element in item.keys():
map_dct[op_name][element] = item[element]
for element in ["scalar", "int_array"]:
......
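With "outputs" now collected alongside "inputs" and "attrs", the generated map entry for, say, dropout should look roughly like this (a hedged reconstruction from the op_compat.yaml hunk above, not copied from generated code):

op_map_entry = {
    "dropout": {
        "phi_name": "dropout",
        "inputs": {"x": "X"},
        "outputs": {"out": "Out", "mask": "Mask"},
        "attrs": {
            "p": "dropout_prob",
            "is_test": "is_test",
            "mode": "dropout_implementation",
            "seed": "seed",
            "fix_seed": "fix_seed",
        },
    }
}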
......@@ -236,6 +236,8 @@ def to_prim(blocks):
f"Expect block or sequence of blocks, but got {type(blocks)}."
)
with framework.program_guard(main_program):
print("Running lowering for forward...")
print("Lowering composite forward ops begin...")
primx._lower_composite(blocks, prim_config["forward_blacklist"])
replace_ops = prim_config["composite_ops_record"]
print(f"Lowering composite forward ops finish: {replace_ops}")
return
......@@ -18,6 +18,7 @@ from collections import OrderedDict
import paddle
from paddle.fluid import framework as framework
from paddle.fluid.core import prim_config
from paddle.fluid.framework import Operator, default_main_program
from paddle.incubate.autograd.utils import as_tensors
......@@ -36,6 +37,7 @@ from .utils import (
flatten_and_remove_none,
get_input_var_list,
get_output_var_list,
map_output_for_composite,
prepare_python_api_arguments,
)
......@@ -596,19 +598,43 @@ def _lower_composite(block, blacklist=[]):
# if output var of composite rule is None, this means this var is not needed
none_vars_to_remove = set()
change = None
# Step2: Process all ops in the target block
for op_idx in range(len(block.ops)):
op = block.ops[op_idx]
ops_to_remove.append(op_idx)
if lookup_fn(op.type) is not None and op.type not in blacklist:
change = True
op_name = op.type
prim_config["composite_ops_record"].add(op_name)
input_args = prepare_python_api_arguments(op)
bind(input_args, to_bind, value_table)
orig_outs = expand_nested_list(map_output_for_composite(op))
new_outs = expand_nested_list(
as_tensors(lower_fn(op, *input_args))
)
assert len(orig_outs) == len(new_outs), (
f'when replacing origin op {op_name} with composite rule, the number of origin outs should equal the number of new outs, '
f'but len(orig_outs) = {len(orig_outs)} and len(new_outs) = {len(new_outs)}'
)
for orig_out, new_out in zip(
expand_nested_list(get_output_var_list(op)),
expand_nested_list(as_tensors(lower_fn(op, *input_args))),
orig_outs,
new_outs,
):
if new_out is not None:
if orig_out is None:
# to stay consistent with the phi op definition, orig_out may be None
continue
elif new_out is not None:
assert orig_out.dtype == new_out.dtype, (
f'when replacing origin op {op_name} with composite rule, origin out dtype should equal new out dtype, '
f'but orig_out.dtype={orig_out.dtype} and new_out.dtype={new_out.dtype}'
)
if orig_out.shape and new_out.shape:
assert orig_out.shape == new_out.shape, (
f'when replacing origin op {op_name} with composite rule, origin out shape should equal new out shape, '
f'but orig_out.shape={orig_out.shape} and new_out.shape={new_out.shape}'
)
assert not (orig_out is None) ^ (
new_out is None
), "orig_out and new_out should match."
......@@ -675,6 +701,10 @@ def _lower_composite(block, blacklist=[]):
block.desc._remove_var(var_name.encode())
del block.vars[var_name]
block._sync_with_cpp()
# composite ops may themselves contain composite ops, so call _lower_composite again.
if change:
_lower_composite(block, blacklist)
return
elif isinstance(block, typing.Sequence):
......
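The recursion added at the end is effectively a fixed-point loop: lowering runs again whenever a pass replaced at least one op, because a composite rule may itself emit ops that have composite rules. A minimal, hedged sketch of that behaviour (lower_once is a hypothetical single-pass helper, not a function in the codebase):

def lower_to_fixed_point(block, blacklist=()):
    # Repeat single-pass lowering until a pass replaces nothing.
    changed = True
    while changed:
        changed = lower_once(block, blacklist)  # hypothetical: returns True if any op was replaced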
......@@ -169,6 +169,7 @@ def _get_args_values(op, phi_name):
arg_type, arg_name = _solve_arg(item)
op_content = op_map[op.type]
if arg_type in ("Tensor", "Tensor[]"):
# args of Tensor or Tensor[] type are assumed to belong to inputs
if (
"inputs" in op_content.keys()
and arg_name in op_content["inputs"].keys()
......@@ -182,8 +183,12 @@ def _get_args_values(op, phi_name):
"attrs" in op_content.keys()
and arg_name in op_content["attrs"].keys()
):
attrs.append(op.attr(op_content["attrs"][arg_name]))
attrs.append(op.attr(arg_name))
arg_name = op_content["attrs"][arg_name]
# Note: in some cases, attrs may be optional, thus assign None. Such cases must be recorded.
if arg_name not in op.attr_names:
attrs.append(None)
else:
attrs.append(op.attr(arg_name))
return inputs, attrs
......@@ -202,7 +207,13 @@ def prepare_python_api_arguments(op):
else:
phi_name = op.type
inputs, attrs = _get_args_values(op, phi_name)
res = [get_var_block(op.block, op.input(n)) for n in inputs]
res = []
for item in inputs:
if item in op.input_names:
res.append(get_var_block(op.block, op.input(item)))
else:
# Note: in some cases, inputs may be optional, thus assign None. Such cases must be recorded.
res.append(None)
if attrs:
res.extend(attrs)
return res
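As an example of the optional-input handling above: for a dropout op, assuming its phi signature is (x, seed_tensor, p, is_test, mode, seed, fix_seed) and no seed tensor is fed, the assembled argument list would look roughly as follows (names and values are illustrative, not taken from the source):

x_var = ...  # the Variable bound to input slot "X"
res = [
    x_var,               # x           -> input "X"
    None,                # seed_tensor -> optional input, absent here, so None is recorded
    0.5,                 # p           <- attr "dropout_prob"
    False,               # is_test
    "upscale_in_train",  # mode        <- attr "dropout_implementation"
    0,                   # seed
    False,               # fix_seed
]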
......@@ -218,6 +229,38 @@ def get_output_var_list(op):
]
def map_output_for_composite(op):
"""origin op outputs must be mapped into outputs of composite rule. map info has been defined in op_compat.yaml"""
origin_output_names = op.output_names
if origin_output_names is None:
return []
else:
name = op.type
res = []
if op_map[name].get("outputs"):
for item in op_map[name]["outputs"].keys():
origin_output_name = op_map[name]["outputs"][item]
if origin_output_name not in origin_output_names:
res.append(None)
# Note: in some cases, an output of the origin op is optional, so its name may not be in origin_output_names
continue
origin_output_var = get_var_block(
op.block, op.output(origin_output_name)
)
res.append(origin_output_var)
elif len(origin_output_names) == 1:
# When the origin op has exactly one output, map info is not needed.
origin_output_var = get_var_block(
op.block, op.output(origin_output_names[0])
)
res.append(origin_output_var)
else:
raise ValueError(
"When replace op with composite rule, there must exist output map info from origin op to composite rule."
)
return res
def flatten(inp):
if inp is None or isinstance(inp, paddle.fluid.framework.Variable):
return [inp]
......
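Finally, a hedged illustration of what map_output_for_composite returns for the maps added above (placeholder strings stand in for block Variables):

# dropout: the "outputs" map is present and both slots exist on the origin op
dropout_mapped   = ["out_var", "mask_var"]
# an op without an "outputs" map but with exactly one output slot falls back to that slot
single_output_op = ["only_out_var"]
# a mapped slot the op instance does not actually have (e.g. an optional ReserveSpace) yields None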