Unverified commit f7f67b72, authored by zqw_1997, committed by GitHub

Add mean composite rule (#50298)

* beta

* small commit

* add batch_norm composite rule

move composite test case

remove unused var

add composite op blacklist

* small change v2

* finish the test_composite_mean and test_composite_mean_grad

* add ops assertion to the tests

* add cinn test

* fix the error and inappropriate usage in func: mean_composite

* remove the reference to the outer lib in primitives.py

* modify sample code of reduce_sum

* fix composite mean op map

* modify testcases to test more float types

* remove cpu float16 test

* cinn test fix

* remove reduce_max

* change the name sum to sum_x

* change the use of reduce_sum to sum

---------
Co-authored-by: cyber-pioneer <chenzhuo@tju.edu.cn>
Parent commit: 3c14b38e
@@ -574,18 +574,18 @@ def reduce_sum(input, dim=None, keep_dim=False, name=None):
             #      [0.1, 0.2, 0.6, 0.7]]
             # Each example is followed by the corresponding output tensor.
             x = fluid.data(name='x', shape=[2, 4], dtype='float32')
-            fluid.layers.reduce_sum(x)  # [3.5]
-            fluid.layers.reduce_sum(x, dim=0)  # [0.3, 0.5, 1.1, 1.6]
-            fluid.layers.reduce_sum(x, dim=-1)  # [1.9, 1.6]
-            fluid.layers.reduce_sum(x, dim=1, keep_dim=True)  # [[1.9], [1.6]]
+            fluid.layers.nn.reduce_sum(x)  # [3.5]
+            fluid.layers.nn.reduce_sum(x, dim=0)  # [0.3, 0.5, 1.1, 1.6]
+            fluid.layers.nn.reduce_sum(x, dim=-1)  # [1.9, 1.6]
+            fluid.layers.nn.reduce_sum(x, dim=1, keep_dim=True)  # [[1.9], [1.6]]

             # y is a Tensor variable with shape [2, 2, 2] and elements as below:
             #      [[[1, 2], [3, 4]],
             #      [[5, 6], [7, 8]]]
             # Each example is followed by the corresponding output tensor.
             y = fluid.data(name='y', shape=[2, 2, 2], dtype='float32')
-            fluid.layers.reduce_sum(y, dim=[1, 2])  # [10, 26]
-            fluid.layers.reduce_sum(y, dim=[0, 1])  # [16, 20]
+            fluid.layers.nn.reduce_sum(y, dim=[1, 2])  # [10, 26]
+            fluid.layers.nn.reduce_sum(y, dim=[0, 1])  # [16, 20]
    """
    reduce_all, dim = _get_reduce_dim(dim, input)
...
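Note (not part of the diff): the same sums from the docstring example can be reproduced with the Paddle 2.x `paddle.sum` API, which corresponds to the `sum` primitive that the mean composite rule below is built on. A minimal sketch, assuming a recent Paddle release is installed:

```python
import paddle

x = paddle.to_tensor([[0.2, 0.3, 0.5, 0.9],
                      [0.1, 0.2, 0.6, 0.7]])
print(paddle.sum(x))                        # 3.5
print(paddle.sum(x, axis=0))                # [0.3, 0.5, 1.1, 1.6]
print(paddle.sum(x, axis=-1))               # [1.9, 1.6]
print(paddle.sum(x, axis=1, keepdim=True))  # [[1.9], [1.6]]
```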
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import platform
import unittest
import numpy as np
import paddle
import paddle.tensor as tensor
from paddle.fluid import core
TOLERANCE = {
    "float16": {"rtol": 1e-3, "atol": 1e-3},
    "float32": {"rtol": 1e-6, "atol": 1e-6},
    "float64": {"rtol": 1e-15, "atol": 1e-15},
}

keepdim_conds = [True, False]
axes_condis = [-1, 0, 1]


def apply_to_static(net, use_cinn):
    build_strategy = paddle.static.BuildStrategy()
    build_strategy.build_cinn_pass = use_cinn
    return paddle.jit.to_static(net, build_strategy=build_strategy)


def generate_data(shape, dtype="float32"):
    np_data = np.random.random(shape).astype(dtype)
    return np_data


class PrimeNet(
    paddle.nn.Layer,
):
    def __init__(self):
        super(PrimeNet, self).__init__()
        self.fc = paddle.nn.Linear(4, 4)

    def forward(self, x):
        # y = self.fc(x)
        out = tensor.mean(x)
        return out


class TestPrimForward(unittest.TestCase):
    """
    This case only tests prim_forward + to_static + cinn. Thus we need to
    set this flag as False to avoid prim_backward.
    core.set_prim_backward(False)
    """

    def setUp(self):
        paddle.seed(2022)
        self.shapes = [[2, 4], [64, 16, 4]]
        self.dtypes = ["float16", "float32", "float64"]

    def train(self, use_prim, data):
        for keep_dim in keepdim_conds:
            for axis in axes_condis:
                return self._train(use_prim, data, axis, keep_dim)

    def _train(self, use_prim, data, axis, keep_dim):
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
        core._set_prim_forward_enabled(use_prim)
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(data)
            loss = paddle.mean(out, axis, keep_dim)
            loss.backward()
            sgd.step()
            sgd.clear_grad()
            res.append(out.numpy())

        self.check_prim(net, use_prim)
        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
        # Ensure that reduce_mean is split into small ops
        self.assertTrue('reduce_mean' not in fwd_ops)

    def test_cinn_prim_forward(self):
        for shape in self.shapes:
            for dtype in self.dtypes:
                # the mean kernel on CPU does not support float16
                if paddle.device.get_device() == "cpu" and dtype == "float16":
                    print("need pass this case")
                    continue
                data = generate_data(shape, dtype)
                data_t = paddle.to_tensor(data)
                data_t.stop_gradient = False
                dy_res = self.train(use_prim=False, data=data_t)
                cinn_res = self.train(use_prim=True, data=data_t)

                np.testing.assert_allclose(
                    cinn_res,
                    dy_res,
                    rtol=TOLERANCE[dtype]['rtol'],
                    atol=TOLERANCE[dtype]['atol'],
                )


class TestPrimForwardAndBackward(unittest.TestCase):
    """
    Test PrimeNet with @to_static + prim forward + prim backward + cinn v.s Dygraph
    """

    def setUp(self):
        paddle.seed(2022)
        self.shapes = [[2, 4], [64, 16, 4]]
        self.dtypes = ["float16", "float32", "float64"]

    def train(self, use_prim, data):
        for keep_dim in keepdim_conds:
            for axis in axes_condis:
                return self._train(use_prim, data, axis, keep_dim)

    def _train(self, use_prim, data, axis, keep_dim):
        paddle.seed(2022)
        net = PrimeNet()
        sgd = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=net.parameters()
        )
        core._set_prim_all_enabled(use_prim)
        if use_prim:
            net = apply_to_static(net, use_prim)

        res = []
        for _ in range(10):
            out = net(data)
            loss = paddle.mean(out, axis, keep_dim)
            loss.backward()
            sgd.step()
            sgd.clear_grad()
            res.append(out.numpy())

        self.check_prim(net, use_prim)
        return res

    def check_prim(self, net, use_prim):
        if not use_prim:
            return
        fwd_ops = [op.type for op in net.forward.main_program.block(0).ops]
        # Ensure that reduce_mean is split into small ops
        self.assertTrue('reduce_mean' not in fwd_ops)

    def test_cinn_prim(self):
        plat = platform.system()
        if plat == "Linux":
            for shape in self.shapes:
                for dtype in self.dtypes:
                    # the mean kernel on CPU does not support float16
                    if (
                        paddle.device.get_device() == "cpu"
                        and dtype == "float16"
                    ):
                        print("need pass this case")
                        continue
                    data = generate_data(shape, dtype)
                    data_t = paddle.to_tensor(data)
                    data_t.stop_gradient = False
                    dy_res = self.train(use_prim=False, data=data_t)
                    cinn_res = self.train(use_prim=True, data=data_t)

                    np.testing.assert_allclose(
                        cinn_res,
                        dy_res,
                        rtol=TOLERANCE[dtype]['rtol'],
                        atol=TOLERANCE[dtype]['atol'],
                    )
        else:
            pass


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import TOLERANCE
import paddle
import paddle.tensor as tensor
from paddle.fluid import core
def generate_data(shape, dtype="float32"):
    np_data = np.random.random(shape).astype(dtype)
    return np_data


class Attr:
    def __init__(self) -> None:
        self.dtype = "float32"
        self.keepdim = False
        self.axis = None
        self.shape = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
        return

    def set_keepdim(self, keepdim) -> None:
        self.keepdim = keepdim
        return

    def set_axis(self, axis) -> None:
        self.axis = axis
        return

    def set_shape(self, shape) -> None:
        self.shape = shape
        return

    def get_rtol(self, flag):
        rtol = TOLERANCE[self.dtype][flag].get("rtol")
        return rtol

    def get_atol(self, flag):
        atol = TOLERANCE[self.dtype][flag].get("atol")
        return atol


attrs = Attr()


def fn(x):
    return tensor.mean(x, axis=attrs.axis, keepdim=attrs.keepdim)


def expect_forward(inputs):
    return fn(inputs)


class TestCompositeMean(unittest.TestCase):
    def setUp(self):
        self.dtypes = ["float16", "float32", "float64"]
        self.keepdim = [False, True]
        self.shapes = [[16, 16, 64, 64], [2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]

    def cal_composite(self, inputs):
        paddle.enable_static()
        core._set_prim_forward_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            y = fn(x)
            blocks = main_program.blocks

            fwd_ops = [op.type for op in blocks[0].ops]
            # Ensure that reduce_mean is in the original block
            self.assertTrue('reduce_mean' in fwd_ops)

            paddle.incubate.autograd.to_prim(blocks)

            fwd_ops_new = [op.type for op in blocks[0].ops]
            # Ensure that reduce_mean is split into small ops
            self.assertTrue('reduce_mean' not in fwd_ops_new)

        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[y])
        paddle.disable_static()
        core._set_prim_forward_enabled(False)
        return res

    def compare_forward(self):
        np_data = generate_data(attrs.shape, attrs.dtype)
        tensor_data = paddle.to_tensor(np_data)

        expect = expect_forward(tensor_data).numpy()
        actual = self.cal_composite(np_data)[0]
        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("forward"),
            atol=attrs.get_atol("forward"),
        )

    def test_forward(self):
        for i in self.axes:
            for j in self.dtypes:
                for t in self.shapes:
                    for k in self.keepdim:
                        # the mean kernel on CPU does not support float16
                        if (
                            paddle.device.get_device() == "cpu"
                            and j == "float16"
                        ):
                            print("need pass this case")
                            continue
                        attrs.set_axis(i)
                        attrs.set_dtype(j)
                        attrs.set_shape(t)
                        attrs.set_keepdim(k)
                        self.compare_forward()


if __name__ == '__main__':
    unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from utils import TOLERANCE
import paddle
import paddle.tensor as tensor
from paddle.fluid import core
def generate_data(shape, dtype="float32"):
    np_data = np.random.random(shape).astype(dtype)
    return np_data


class Attr:
    def __init__(self) -> None:
        self.dtype = "float32"
        self.keepdim = False
        self.axis = None
        self.shape = None

    def set_dtype(self, dtype) -> None:
        self.dtype = dtype
        return

    def set_keepdim(self, keepdim) -> None:
        self.keepdim = keepdim
        return

    def set_axis(self, axis) -> None:
        self.axis = axis
        return

    def set_shape(self, shape) -> None:
        self.shape = shape
        return

    def get_rtol(self, flag):
        rtol = TOLERANCE[self.dtype][flag].get("rtol")
        return rtol

    def get_atol(self, flag):
        atol = TOLERANCE[self.dtype][flag].get("atol")
        return atol


attrs = Attr()


def fn(x):
    return tensor.mean(x, axis=attrs.axis, keepdim=attrs.keepdim)


def expect_grad(inputs):
    paddle.disable_static()
    inputs.stop_gradient = False
    res = fn(inputs)

    gradients = paddle.grad(res, inputs)
    return gradients


class TestCompositeMean(unittest.TestCase):
    def setUp(self):
        self.dtypes = ["float16", "float32", "float64"]
        self.keepdim = [False, True]
        self.shapes = [[16, 16, 64, 64], [2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
        core._set_prim_forward_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            x.stop_gradient = False
            y = fn(x)
            blocks = main_program.blocks

            fwd_ops = [op.type for op in blocks[0].ops]
            # Ensure that reduce_mean is in the original block
            self.assertTrue('reduce_mean' in fwd_ops)

            paddle.incubate.autograd.to_prim(blocks)

            fwd_ops_new = [op.type for op in blocks[0].ops]
            # Ensure that reduce_mean is split into small ops
            self.assertTrue('reduce_mean' not in fwd_ops_new)

            z = paddle.static.gradients([y], x)
            fwd_ops_grad = [op.type for op in blocks[0].ops]
            # Ensure that reduce_mean_grad is not in the grad block
            self.assertTrue('reduce_mean_grad' not in fwd_ops_grad)

        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
        core._set_prim_forward_enabled(False)
        return res

    def compare_backward(self):
        np_data = generate_data(attrs.shape, attrs.dtype)
        tensor_data = paddle.to_tensor(np_data)

        expect = expect_grad(tensor_data)[0].numpy()
        actual = self.cal_composite_grad(np_data)[0]
        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("backward"),
            atol=attrs.get_atol("backward"),
        )

    def test_backward(self):
        for i in self.axes:
            for j in self.dtypes:
                for t in self.shapes:
                    for k in self.keepdim:
                        # the mean kernel on CPU does not support float16
                        if (
                            paddle.device.get_device() == "cpu"
                            and j == "float16"
                        ):
                            print("need pass this case")
                            continue
                        attrs.set_axis(i)
                        attrs.set_dtype(j)
                        attrs.set_shape(t)
                        attrs.set_keepdim(k)
                        self.compare_backward()


class TestCompositeMeanPrimBackward(unittest.TestCase):
    "test composite mean and prim backward"

    def setUp(self):
        self.dtypes = ["float16", "float32", "float64"]
        self.keepdim = [False, True]
        self.shapes = [[16, 16, 64, 64], [2, 3, 4], [2, 3]]
        self.axes = [-1, 0, 1]

    def cal_composite_grad(self, inputs):
        paddle.enable_static()
        core._set_prim_all_enabled(True)
        startup_program = paddle.static.Program()
        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            x = paddle.static.data(
                'x', shape=inputs.shape, dtype=str(inputs.dtype)
            )
            x.stop_gradient = False
            y = fn(x)
            blocks = main_program.blocks
            paddle.incubate.autograd.to_prim(blocks)
            z = paddle.static.gradients([y], x)

        exe = paddle.static.Executor()
        exe.run(startup_program)
        res = exe.run(main_program, feed={'x': inputs}, fetch_list=[z])
        paddle.disable_static()
        core._set_prim_all_enabled(False)
        return res

    def compare_backward(self):
        np_data = generate_data(attrs.shape, attrs.dtype)
        tensor_data = paddle.to_tensor(np_data)

        expect = expect_grad(tensor_data)[0].numpy()
        actual = self.cal_composite_grad(np_data)[0]
        assert expect.dtype == actual.dtype
        np.testing.assert_allclose(
            expect,
            actual,
            rtol=attrs.get_rtol("prim_backward"),
            atol=attrs.get_atol("prim_backward"),
        )

    def test_prim_backward(self):
        for i in self.axes:
            for j in self.dtypes:
                for t in self.shapes:
                    for k in self.keepdim:
                        # the mean kernel on CPU does not support float16
                        if (
                            paddle.device.get_device() == "cpu"
                            and j == "float16"
                        ):
                            print("need pass this case")
                            continue
                        attrs.set_axis(i)
                        attrs.set_dtype(j)
                        attrs.set_shape(t)
                        attrs.set_keepdim(k)
                        self.compare_backward()


if __name__ == '__main__':
    unittest.main()
@@ -17,6 +17,8 @@
 # 2. The name and args of target op must be corresponding with standard description of op in
 # ops.yaml or legacy_ops.yaml.

+import functools
+import operator
 from .primitives import *  # noqa: F403
 from .primreg import REGISTER_COMPOSITE, lookup_composite

@@ -130,3 +132,20 @@ def gelu_composite(x, approximate):
     cdf = half * (one + erf(x * full(x.shape, M_SQRT1_2, x.dtype)))
     out = x * cdf
     return out

@REGISTER_COMPOSITE('reduce_mean')
def mean_composite(x, axis, keepdim):
    """define composite rule of op mean"""
    axes = axis or list(range(0, len(x.shape)))
    axes = [axes] if isinstance(axes, int) else axes
    sum_x = sum(x, axis=axes, keepdim=keepdim)
    value_to_fill = functools.reduce(
        operator.mul, [x.shape[axis] for axis in axes]
    )
    norm = fill_constant(
        shape=sum_x.shape,
        value=value_to_fill,
        dtype=sum_x.dtype,
    )
    return divide(sum_x, norm)
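For intuition, here is a small NumPy-only sketch (not part of the patch; the helper name `mean_via_sum` is made up for illustration) that mirrors what `mean_composite` computes: a sum over the requested axes divided by the product of the reduced dimension sizes.

```python
import functools
import operator

import numpy as np


def mean_via_sum(x, axis=None, keepdim=False):
    # Mirrors the composite rule: mean(x) == sum(x, axes) / prod(sizes of reduced axes).
    axes = list(range(x.ndim)) if axis is None else axis
    axes = [axes] if isinstance(axes, int) else list(axes)
    sum_x = np.sum(x, axis=tuple(axes), keepdims=keepdim)
    norm = functools.reduce(operator.mul, [x.shape[a] for a in axes])
    return sum_x / norm


x = np.random.random((2, 3, 4)).astype("float64")
for axis, keepdim in [(None, False), (-1, True), ([0, 2], False)]:
    ref_axis = tuple(axis) if isinstance(axis, list) else axis
    np.testing.assert_allclose(
        mean_via_sum(x, axis, keepdim),
        np.mean(x, axis=ref_axis, keepdims=keepdim),
        rtol=1e-12,
    )
```

The composite rule itself materializes the divisor with `fill_constant` so that `divide` operates on tensors of matching shape and dtype, rather than on a Python scalar as in this sketch.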