未验证 提交 ff40a7e5 编写于 作者: C cyber-pioneer 提交者: GitHub

[Prim] support batch_norm vjp (#51283)

* add bn vjp

* fix example

* fix code

* fix code

* fix cinn case

* fix code

* fix example

* fix code

* fix example

* fix example
上级 eefe601c
...@@ -23,6 +23,10 @@ limitations under the License. */ ...@@ -23,6 +23,10 @@ limitations under the License. */
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/fluid/framework/infershape_utils.h" #include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/multiary.h" #include "paddle/phi/infermeta/multiary.h"
...@@ -534,6 +538,70 @@ phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType( ...@@ -534,6 +538,70 @@ phi::KernelKey BatchNormDoubleGradOp::GetExpectedKernelType(
ctx.GetPlace()); ctx.GetPlace());
} }
class BatchNormCompositeGradOpMaker : public prim::CompositeGradOpMakerBase {
using prim::CompositeGradOpMakerBase::CompositeGradOpMakerBase;
public:
void Apply() override {
// inputs and outputs of batch_norm
paddle::Tensor x = this->GetSingleForwardInput("X");
paddle::Tensor scale = this->GetSingleForwardInput("Scale");
paddle::Tensor bias = this->GetSingleForwardInput("Bias");
paddle::Tensor mean = this->GetSingleForwardInput("Mean");
paddle::Tensor variance = this->GetSingleForwardInput("Variance");
paddle::Tensor y = this->GetSingleForwardOutput("Y");
paddle::Tensor mean_out = this->GetSingleForwardOutput("MeanOut");
paddle::Tensor variance_out = this->GetSingleForwardOutput("VarianceOut");
paddle::Tensor saved_mean = this->GetSingleForwardOutput("SavedMean");
paddle::Tensor saved_variance =
this->GetSingleForwardOutput("SavedVariance");
paddle::optional<paddle::Tensor> reserve_space;
paddle::Tensor y_grad = this->GetSingleOutputGrad("Y");
paddle::Tensor x_grad = this->GetSingleInputGrad("X");
paddle::Tensor scale_grad = this->GetSingleInputGrad("Scale");
paddle::Tensor bias_grad = this->GetSingleInputGrad("Bias");
auto dx_ptr = this->GetOutputPtr(&x_grad);
std::string dx_name = this->GetOutputName(x_grad);
auto dscale_ptr = this->GetOutputPtr(&scale_grad);
std::string dscale_name = this->GetOutputName(scale_grad);
auto dbias_ptr = this->GetOutputPtr(&bias_grad);
std::string dbias_name = this->GetOutputName(bias_grad);
// attrs of batch_norm
auto momentum = this->Attr<float>("momentum");
auto epsilon = this->Attr<float>("epsilon");
auto data_layout = this->Attr<std::string>("data_layout");
auto is_test = this->Attr<bool>("is_test");
auto use_global_stats = this->Attr<bool>("use_global_stats");
auto trainable_statistics = this->Attr<bool>("trainable_statistics");
VLOG(3) << "Runing batch_norm composite func";
prim::batch_norm_grad<prim::DescTensor>(x,
scale,
bias,
mean_out,
variance_out,
saved_mean,
saved_variance,
reserve_space,
y_grad,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
dx_ptr,
dscale_ptr,
dbias_ptr);
this->RecoverOutputName(x_grad, dx_name);
this->RecoverOutputName(scale_grad, dscale_name);
this->RecoverOutputName(bias_grad, dbias_name);
}
};
DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"}); DECLARE_INPLACE_OP_INFERER(BatchNormDoubleGradOpInplaceInferer, {"DY", "DDY"});
} // namespace operators } // namespace operators
...@@ -550,7 +618,8 @@ REGISTER_OPERATOR(batch_norm, ...@@ -550,7 +618,8 @@ REGISTER_OPERATOR(batch_norm,
ops::BatchNormOpMaker, ops::BatchNormOpMaker,
ops::BatchNormOpInferVarType, ops::BatchNormOpInferVarType,
ops::BatchNormGradMaker<paddle::framework::OpDesc>, ops::BatchNormGradMaker<paddle::framework::OpDesc>,
ops::BatchNormGradMaker<paddle::imperative::OpBase>); ops::BatchNormGradMaker<paddle::imperative::OpBase>,
ops::BatchNormCompositeGradOpMaker);
REGISTER_OPERATOR(batch_norm_grad, REGISTER_OPERATOR(batch_norm_grad,
ops::BatchNormGradOp, ops::BatchNormGradOp,
......
...@@ -1022,5 +1022,168 @@ void dropout_grad(const Tensor& mask, ...@@ -1022,5 +1022,168 @@ void dropout_grad(const Tensor& mask,
} }
} }
} }
template <typename T>
void batch_norm_grad(const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const paddle::optional<Tensor>& mean_out,
const paddle::optional<Tensor>& variance_out,
const Tensor& saved_mean,
const Tensor& saved_variance,
const paddle::optional<Tensor>& reserve_space,
const Tensor& out_grad,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
Tensor* x_grad,
Tensor* scale_grad,
Tensor* bias_grad) {
use_global_stats = is_test || use_global_stats;
DataLayout data_layout_ = phi::StringToDataLayout(data_layout);
Tensor x_data = x;
Tensor out_grad_data = out_grad;
if (x.dtype() == phi::DataType::FLOAT16) {
x_data = cast<T>(x, phi::DataType::FLOAT32);
}
if (out_grad.dtype() == phi::DataType::FLOAT16) {
out_grad_data = cast<T>(out_grad, phi::DataType::FLOAT32);
}
auto x_dims = x_data.dims();
const int C = (data_layout_ == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
int nume = 1;
for (auto i = 0; i < x_dims.size(); i++) {
nume = nume * x_dims[i];
}
const int nhw = nume / C;
if (x_dims.size() == 2 && data_layout_ == DataLayout::kNCHW) {
data_layout_ = DataLayout::kNHWC;
}
auto run_var = variance_out.get();
auto run_mean = mean_out.get();
Tensor mean_data;
Tensor rsqrt_var;
if (use_global_stats) {
auto eps =
full<T>(phi::vectorize(run_var.dims()), epsilon, run_var.dtype());
mean_data = run_mean;
rsqrt_var = 1 / (run_var + eps).pow(0.5);
} else {
mean_data = saved_mean;
rsqrt_var = saved_variance;
}
// inv_var = 1 / sqrt(var + eps)
// reduce_axis = [0, 2, 3] (NCHW) [0, 1, 2] (NHWC)
//
// d_bias = np.sum(d_y, reduce_axis)
// d_scale = np.sum((X - mean) / inv_var * dy, reduce_axis)
//
// train mode
// d_x = (1. / nhw) * scale * inv_var
// *(nhw * d_y - np.sum(d_y, reduce_axis) - (X - mean) * inv_var * inv_var *
// np.sum(d_y * (X - mean), reduce_axis))
//
// test mode
// d_x = d_y * scale * inv_var
std::vector<int> nchw_to_nhwc_dim = {0, 2, 3, 1};
std::vector<int> nhwc_to_nchw_dim = {0, 3, 1, 2};
auto reduce_axis = IntArray(std::vector<int>{0, 1, 2});
auto dtype = x_data.dtype();
switch (data_layout_) {
case DataLayout::kNCHW: {
auto nhwc_x = transpose<T>(x_data, nchw_to_nhwc_dim);
auto nhwc_out_grad = transpose<T>(out_grad_data, nchw_to_nhwc_dim);
auto x_sub_mean = nhwc_x - mean_data;
if (x_grad) {
if (use_global_stats) {
auto nhwc_x_grad = scale * rsqrt_var * nhwc_out_grad;
auto nchw_x_grad = transpose<T>(nhwc_x_grad, nhwc_to_nchw_dim);
set_output<T>(nchw_x_grad, x_grad);
} else {
auto part1 = scale * rsqrt_var;
auto mean_temp1 =
sum<T>(nhwc_out_grad, reduce_axis, dtype, false) / nhw;
auto tmp = nhwc_out_grad * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
auto part2 = nhwc_out_grad - mean_temp1 - x_sub_mean * mean_temp2;
auto x_grad_data = part1 * part2;
auto nchw_x_grad = transpose<T>(x_grad_data, nhwc_to_nchw_dim);
if (x.dtype() == phi::DataType::FLOAT16) {
nchw_x_grad = cast<T>(nchw_x_grad, x.dtype());
}
set_output<T>(nchw_x_grad, x_grad);
}
}
if (scale_grad) {
auto scale_grad_data = sum<T>(
nhwc_out_grad * x_sub_mean * rsqrt_var, reduce_axis, dtype, false);
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
auto bias_grad_data = sum<T>(nhwc_out_grad, reduce_axis, dtype, false);
set_output<T>(bias_grad_data, bias_grad);
}
break;
}
case DataLayout::kNHWC: {
if (x_grad) {
auto x_sub_mean = x_data - mean_data;
if (use_global_stats) {
auto x_grad_data = scale * rsqrt_var * out_grad_data;
set_output<T>(x_grad_data, x_grad);
} else {
auto part1 = scale * rsqrt_var;
auto mean_temp1 =
sum<T>(out_grad_data, reduce_axis, dtype, false) / nhw;
auto tmp = out_grad_data * x_sub_mean * rsqrt_var * rsqrt_var / nhw;
auto mean_temp2 = sum<T>(tmp, reduce_axis, dtype, false);
auto part2 = out_grad - mean_temp1 - x_sub_mean * mean_temp2;
auto x_grad_data = part1 * part2;
if (x.dtype() == phi::DataType::FLOAT16) {
x_grad_data = cast<T>(x_grad_data, x.dtype());
}
set_output<T>(x_grad_data, x_grad);
}
if (scale_grad) {
auto scale_grad_data = sum<T>(out_grad_data * x_sub_mean * rsqrt_var,
reduce_axis,
dtype,
false);
set_output<T>(scale_grad_data, scale_grad);
}
if (bias_grad) {
auto bias_grad_data =
sum<T>(out_grad_data, reduce_axis, dtype, false);
set_output<T>(bias_grad_data, bias_grad);
}
break;
}
}
default:
PADDLE_THROW(phi::errors::InvalidArgument("Unknown storage order: %s",
data_layout));
}
}
} // namespace prim } // namespace prim
} // namespace paddle } // namespace paddle
...@@ -131,6 +131,7 @@ ...@@ -131,6 +131,7 @@
func : batch_norm_grad func : batch_norm_grad
data_type : out_grad data_type : out_grad
optional : mean_out, variance_out, reserve_space optional : mean_out, variance_out, reserve_space
composite: batch_norm_grad(x, scale, bias, mean_out, variance_out, saved_mean, saved_variance, reserve_space, out_grad, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics)
backward : batch_norm_double_grad backward : batch_norm_double_grad
- backward_op : bce_loss_grad - backward_op : bce_loss_grad
......
...@@ -427,6 +427,7 @@ class TestResnet(unittest.TestCase): ...@@ -427,6 +427,7 @@ class TestResnet(unittest.TestCase):
def test_resnet_composite(self): def test_resnet_composite(self):
core._set_prim_backward_enabled(True) core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("batch_norm")
static_loss = self.train(to_static=True) static_loss = self.train(to_static=True)
core._set_prim_backward_enabled(False) core._set_prim_backward_enabled(False)
dygraph_loss = self.train(to_static=False) dygraph_loss = self.train(to_static=False)
......
...@@ -137,73 +137,97 @@ def expect_forward( ...@@ -137,73 +137,97 @@ def expect_forward(
) )
def cal_static(inputs, running_mean, running_variance, weight, bias, mode=None):
paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x1 = paddle.static.data(
'x1', shape=inputs.shape, dtype=str(inputs.dtype)
)
x2 = paddle.static.data(
'x2', shape=running_mean.shape, dtype=str(running_mean.dtype)
)
x3 = paddle.static.data(
'x3',
shape=running_variance.shape,
dtype=str(running_variance.dtype),
)
x4 = paddle.static.data(
'x4', shape=weight.shape, dtype=str(weight.dtype)
)
x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype))
y = fn(
x1,
x2,
x3,
x4,
x5,
attrs.training,
attrs.momentum,
attrs.epsilon,
attrs.data_format,
attrs.use_global_stats,
)
blocks = main_program.blocks
names = dict(
zip(
blocks[0].ops[0].output_names, blocks[0].ops[0].output_arg_names
)
)
vars_list = [
names[key]
for key in [
"Y",
"MeanOut",
"VarianceOut",
"SavedMean",
"SavedVariance",
]
]
fwd_ops = [op.type for op in blocks[0].ops]
# Ensure that batch_norm in original block
assert 'batch_norm' in fwd_ops
if mode:
primapi.to_prim(blocks)
fwd_ops_new = [op.type for op in blocks[0].ops]
# Ensure that batch_norm is splitted into small ops
assert 'batch_norm' not in fwd_ops_new
exe = paddle.static.Executor()
exe.run(startup_program)
# indeed SavedVariance is 1/sqrt(batch_var+eps)
Y, MeanOut, VarianceOut, SavedMean, SavedVariance = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=vars_list,
)
paddle.disable_static()
core._set_prim_all_enabled(False)
return Y, MeanOut, VarianceOut, SavedMean, SavedVariance
class TestCompositeBatchNorm(unittest.TestCase): class TestCompositeBatchNorm(unittest.TestCase):
def setUp(self): def setUp(self):
self.dtypes = ["float32", "float64"] self.dtypes = ["float32", "float64"]
self.training = [False, True] self.training = [False, True]
self.shapes = [[8, 8, 16, 16], [2, 1, 2, 3]] self.shapes = [[8, 8, 16, 16], [2, 3, 4, 4]]
self.momentum = [0.1, 0.9] self.momentum = [0.1, 0.9]
self.data_formats = ["NCHW", "NHWC"] self.data_formats = ["NCHW", "NHWC"]
self.use_global_stats = [None, True, False] self.use_global_stats = [None, True, False]
def cal_composite(
self, inputs, running_mean, running_variance, weight, bias
):
paddle.enable_static()
core._set_prim_all_enabled(True)
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x1 = paddle.static.data(
'x1', shape=inputs.shape, dtype=str(inputs.dtype)
)
x2 = paddle.static.data(
'x2', shape=running_mean.shape, dtype=str(running_mean.dtype)
)
x3 = paddle.static.data(
'x3',
shape=running_variance.shape,
dtype=str(running_variance.dtype),
)
x4 = paddle.static.data(
'x4', shape=weight.shape, dtype=str(weight.dtype)
)
x5 = paddle.static.data(
'x5', shape=bias.shape, dtype=str(bias.dtype)
)
y = fn(
x1,
x2,
x3,
x4,
x5,
attrs.training,
attrs.momentum,
attrs.epsilon,
attrs.data_format,
attrs.use_global_stats,
)
blocks = main_program.blocks
primapi.to_prim(blocks)
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=[y],
)
paddle.disable_static()
core._set_prim_all_enabled(False)
return res
def compare_forward(self): def compare_forward(self):
np_data = generate_data(attrs.shape, attrs.dtype) np_data = generate_data(attrs.shape, attrs.dtype)
tensor_data = paddle.to_tensor(np_data) tensor_data = paddle.to_tensor(np_data)
...@@ -234,32 +258,87 @@ class TestCompositeBatchNorm(unittest.TestCase): ...@@ -234,32 +258,87 @@ class TestCompositeBatchNorm(unittest.TestCase):
np_running_variance = np.ones(C, dtype=attrs.dtype) np_running_variance = np.ones(C, dtype=attrs.dtype)
np_weight = np.ones(C, dtype=attrs.dtype) * 2 np_weight = np.ones(C, dtype=attrs.dtype) * 2
np_bias = np.ones(C, dtype=attrs.dtype) np_bias = np.ones(C, dtype=attrs.dtype)
actual = self.cal_composite( res_origin = cal_static(
np_data, np_running_mean, np_running_variance, np_weight, np_bias np_data, np_running_mean, np_running_variance, np_weight, np_bias
)[0] )
assert expect.dtype == actual.dtype res_prim = cal_static(
np_data,
np_running_mean,
np_running_variance,
np_weight,
np_bias,
mode="prim",
)
# prim out vs dygraph mode out
assert expect.dtype == res_prim[0].dtype
np.testing.assert_allclose( np.testing.assert_allclose(
expect, expect,
actual, res_prim[0],
rtol=attrs.get_rtol("forward"), rtol=attrs.get_rtol("forward"),
atol=attrs.get_atol("forward"), atol=attrs.get_atol("forward"),
) )
# prim all outs vs origin static all outs
use_global_stats = attrs.use_global_stats
if use_global_stats is None:
use_global_stats = not attrs.training
trainable_statistics = False
else:
trainable_statistics = not use_global_stats
test_mode = (not attrs.training) and (not trainable_statistics)
global_stats = test_mode or use_global_stats
vars_name = [
"Y",
"MeanOut",
"VarianceOut",
"SavedMean",
"SavedVariance",
]
assert len(res_origin) == len(res_prim)
for idx in range(len(res_origin)):
if global_stats and idx >= 3:
# In this case saved_mean and saved_var are not expected.
continue
origin_item = res_origin[idx]
prim_item = res_prim[idx]
assert origin_item.dtype == prim_item.dtype
rtol = attrs.get_rtol("forward")
atol = attrs.get_atol("forward")
if attrs.dtype == "float64" and idx in (1, 2, 3):
atol = 1e-7
rtol = 1e-7
if not isinstance(
framework._current_expected_place(), core.CPUPlace
) and idx in (2, 3):
atol = 5e-3
rtol = 5e-3
np.testing.assert_allclose(
origin_item,
prim_item,
rtol=atol,
atol=rtol,
err_msg=f"Check diff failed of output: {vars_name[idx]}",
)
def test_forward(self): def test_forward(self):
for i in self.training: for i in self.training:
for j in self.dtypes: for j in self.dtypes:
for m in self.momentum: for k in self.use_global_stats:
attrs.set_training(i) attrs.set_training(i)
attrs.set_dtype(j) attrs.set_dtype(j)
attrs.set_momentum(m) attrs.set_use_global_stats(k)
self.compare_forward() self.compare_forward()
for n in self.shapes: for n in self.shapes:
for s in self.data_formats: for m in self.momentum:
for t in self.use_global_stats: for s in self.data_formats:
attrs.set_momentum(m)
attrs.set_shape(n) attrs.set_shape(n)
attrs.set_data_format(s) attrs.set_data_format(s)
attrs.set_use_global_stats(t)
self.compare_forward() self.compare_forward()
......
...@@ -142,6 +142,7 @@ class TestResnet(unittest.TestCase): ...@@ -142,6 +142,7 @@ class TestResnet(unittest.TestCase):
def test_prim(self): def test_prim(self):
# todo: to be removed after adjust of rtol # todo: to be removed after adjust of rtol
core._set_prim_forward_blacklist("batch_norm") core._set_prim_forward_blacklist("batch_norm")
core._add_skip_comp_ops("batch_norm")
dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False) dy2st_prim = train(to_static=True, enable_prim=True, enable_cinn=False)
# NOTE: Now dy2st is equal to dy2st_prim. With the splitting of kernels, the threshold here may need to be adjusted # NOTE: Now dy2st is equal to dy2st_prim. With the splitting of kernels, the threshold here may need to be adjusted
np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-6) np.testing.assert_allclose(self.dy2st, dy2st_prim, rtol=1e-6)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core
np.random.seed(2023)
class Arg:
dout = None
def generate_data(shape, dtype="float32"):
np_data = np.random.random(shape).astype(dtype)
return np_data
class Attr:
def __init__(self) -> None:
self.dtype = "float32"
self.shape = [8, 8, 16, 16]
self.training = True
self.momentum = 0.9
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
return
def set_shape(self, shape) -> None:
self.shape = shape
return
def set_training(self, training) -> None:
self.training = training
return
def set_momentum(self, momentum) -> None:
self.momentum = momentum
return
def set_epsilon(self, epsilon) -> None:
self.epsilon = epsilon
return
def set_data_format(self, data_format) -> None:
self.data_format = data_format
return
def set_use_global_stats(self, use_global_stats) -> None:
self.use_global_stats = use_global_stats
return
attrs = Attr()
def fn(
x,
running_mean,
running_variance,
weight,
bias,
training,
momentum,
epsilon,
data_format,
use_global_stats,
):
z = F.batch_norm(
x,
running_mean,
running_variance,
weight,
bias,
training=training,
momentum=momentum,
epsilon=epsilon,
data_format=data_format,
use_global_stats=use_global_stats,
)
out = z * paddle.to_tensor(Arg.dout)
res = paddle.mean(out)
return res
def expect_grad(
x,
running_mean,
running_variance,
weight,
bias,
training,
momentum,
epsilon,
data_format,
use_global_stats,
):
x.stop_gradient = False
res = fn(
x,
running_mean,
running_variance,
weight,
bias,
training,
momentum,
epsilon,
data_format,
use_global_stats,
)
gradients = paddle.grad(res, x)
return gradients
def cal_composite(inputs, running_mean, running_variance, weight, bias):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x1 = paddle.static.data(
'x1', shape=inputs.shape, dtype=str(inputs.dtype)
)
x1.stop_gradient = False
x2 = paddle.static.data(
'x2', shape=running_mean.shape, dtype=str(running_mean.dtype)
)
x3 = paddle.static.data(
'x3',
shape=running_variance.shape,
dtype=str(running_variance.dtype),
)
x4 = paddle.static.data(
'x4', shape=weight.shape, dtype=str(weight.dtype)
)
x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype))
y = fn(
x1,
x2,
x3,
x4,
x5,
attrs.training,
attrs.momentum,
attrs.epsilon,
attrs.data_format,
attrs.use_global_stats,
)
z = paddle.static.gradients([y], [x1])
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=[z],
)
paddle.disable_static()
return res
class TestCompositeBatchNorm(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32"]
self.training = [False, True]
self.shapes = [[8, 8, 16, 16], [2, 1, 2, 3]]
self.momentum = [0.1, 0.9]
self.epsilon = [1e-05, 2e-05]
self.data_formats = ["NCHW"]
self.use_global_stats = [None, True, False]
def compare_backward(self):
if attrs.training is True and attrs.use_global_stats is False:
# in this case, origin bn grad kernel is not the same as forward kernel.
return
np_data = generate_data(attrs.shape, attrs.dtype)
tensor_data = paddle.to_tensor(np_data)
Arg.dout = np.random.random(np_data.shape).astype(attrs.dtype)
C = np_data.shape[1]
running_mean = paddle.zeros(C, dtype=attrs.dtype)
running_variance = paddle.ones(C, dtype=attrs.dtype)
weight = paddle.ones(C, dtype=attrs.dtype) * 2
bias = paddle.ones(C, dtype=attrs.dtype)
expect = expect_grad(
tensor_data,
running_mean,
running_variance,
weight,
bias,
attrs.training,
attrs.momentum,
attrs.epsilon,
attrs.data_format,
attrs.use_global_stats,
)[0].numpy()
np_running_mean = np.zeros(C, dtype=attrs.dtype)
np_running_variance = np.ones(C, dtype=attrs.dtype)
np_weight = np.ones(C, dtype=attrs.dtype) * 2
np_bias = np.ones(C, dtype=attrs.dtype)
actual = cal_composite(
np_data, np_running_mean, np_running_variance, np_weight, np_bias
)[0]
assert expect.dtype == actual.dtype
np.testing.assert_allclose(
expect,
actual,
rtol=1e-5,
atol=1e-5,
)
def test_backward_prim_dygraph_vjp(self):
core.set_prim_eager_enabled(True)
for i in self.training:
for j in self.dtypes:
for m in self.momentum:
attrs.set_training(i)
attrs.set_dtype(j)
attrs.set_momentum(m)
self.compare_backward()
for n in self.shapes:
for t in self.use_global_stats:
attrs.set_shape(n)
attrs.set_use_global_stats(t)
self.compare_backward()
core.set_prim_eager_enabled(False)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.fluid import core, framework
np.random.seed(2023)
class Arg:
dout = None
def generate_data(shape, dtype="float32"):
np_data = np.random.random(shape).astype(dtype)
return np_data
class Attr:
def __init__(self) -> None:
self.dtype = "float32"
self.shape = [8, 8, 16, 16]
self.training = True
self.momentum = 0.9
self.epsilon = 1e-05
self.data_format = "NCHW"
self.use_global_stats = None
def set_dtype(self, dtype) -> None:
self.dtype = dtype
return
def set_shape(self, shape) -> None:
self.shape = shape
return
def set_training(self, training) -> None:
self.training = training
return
def set_momentum(self, momentum) -> None:
self.momentum = momentum
return
def set_epsilon(self, epsilon) -> None:
self.epsilon = epsilon
return
def set_data_format(self, data_format) -> None:
self.data_format = data_format
return
def set_use_global_stats(self, use_global_stats) -> None:
self.use_global_stats = use_global_stats
return
attrs = Attr()
def fn(
x,
running_mean,
running_variance,
weight,
bias,
training,
momentum,
epsilon,
data_format,
use_global_stats,
):
z = F.batch_norm(
x,
running_mean,
running_variance,
weight,
bias,
training=training,
momentum=momentum,
epsilon=epsilon,
data_format=data_format,
use_global_stats=use_global_stats,
)
out = z * paddle.to_tensor(Arg.dout)
res = paddle.mean(out)
return res
def expect_grad(
x,
running_mean,
running_variance,
weight,
bias,
training,
momentum,
epsilon,
data_format,
use_global_stats,
):
x.stop_gradient = False
weight.stop_gradient = False
bias.stop_gradient = False
res = fn(
x,
running_mean,
running_variance,
weight,
bias,
training,
momentum,
epsilon,
data_format,
use_global_stats,
)
gradients = paddle.grad(res, (x, weight, bias))
return gradients
def cal_composite(inputs, running_mean, running_variance, weight, bias):
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x1 = paddle.static.data(
'x1', shape=inputs.shape, dtype=str(inputs.dtype)
)
x1.stop_gradient = False
x2 = paddle.static.data(
'x2', shape=running_mean.shape, dtype=str(running_mean.dtype)
)
x3 = paddle.static.data(
'x3',
shape=running_variance.shape,
dtype=str(running_variance.dtype),
)
x4 = paddle.static.data(
'x4', shape=weight.shape, dtype=str(weight.dtype)
)
x4.stop_gradient = False
x5 = paddle.static.data('x5', shape=bias.shape, dtype=str(bias.dtype))
x5.stop_gradient = False
y = fn(
x1,
x2,
x3,
x4,
x5,
attrs.training,
attrs.momentum,
attrs.epsilon,
attrs.data_format,
attrs.use_global_stats,
)
blocks = main_program.blocks
paddle.incubate.autograd.primapi.to_prim(blocks)
z = paddle.static.gradients([y], [x1, x4, x5])
exe = paddle.static.Executor()
exe.run(startup_program)
res = exe.run(
main_program,
feed={
'x1': inputs,
'x2': running_mean,
'x3': running_variance,
'x4': weight,
'x5': bias,
},
fetch_list=[z],
)
paddle.disable_static()
return res
class TestCompositeBatchNorm(unittest.TestCase):
def setUp(self):
self.dtypes = ["float32", "float64"]
self.training = [False, True]
self.shapes = [[8, 8, 16, 16], [2, 4, 3, 3]]
self.momentum = [0.1, 0.9]
self.epsilon = [1e-05, 2e-05]
self.data_formats = ["NCHW", "NHWC"]
self.use_global_stats = [None, True, False]
def compare_backward(self):
np_data = generate_data(attrs.shape, attrs.dtype)
tensor_data = paddle.to_tensor(np_data)
Arg.dout = np.random.random(np_data.shape).astype(attrs.dtype)
if attrs.data_format == 'NCHW':
C = np_data.shape[1]
elif attrs.data_format == 'NHWC':
C = np_data.shape[-1]
else:
raise TypeError
running_mean = paddle.zeros(C, dtype=attrs.dtype)
running_variance = paddle.ones(C, dtype=attrs.dtype)
weight = paddle.ones(C, dtype=attrs.dtype) * 2
bias = paddle.ones(C, dtype=attrs.dtype)
res_origin = expect_grad(
tensor_data,
running_mean,
running_variance,
weight,
bias,
attrs.training,
attrs.momentum,
attrs.epsilon,
attrs.data_format,
attrs.use_global_stats,
)
np_running_mean = np.zeros(C, dtype=attrs.dtype)
np_running_variance = np.ones(C, dtype=attrs.dtype)
np_weight = np.ones(C, dtype=attrs.dtype) * 2
np_bias = np.ones(C, dtype=attrs.dtype)
res_prim = cal_composite(
np_data, np_running_mean, np_running_variance, np_weight, np_bias
)
vars_name = ["x_grad", "weight_grad", "bias_grad"]
assert len(res_origin) == len(res_prim)
for idx in range(len(res_origin)):
origin_item = res_origin[idx].numpy()
prim_item = res_prim[idx]
assert origin_item.dtype == prim_item.dtype
rtol = 1e-5
atol = 1e-5
if (
not isinstance(
framework._current_expected_place(), core.CPUPlace
)
and attrs.data_format == "NHWC"
):
rtol = 1e-4
atol = 1e-4
if idx in (1, 2):
continue
np.testing.assert_allclose(
origin_item,
prim_item,
rtol=rtol,
atol=atol,
err_msg=f"Check diff failed of output: {vars_name[idx]} with data_format: {attrs.data_format}",
)
def test_backward_prim_static_vjp(self):
core._set_prim_backward_enabled(True)
for i in self.training:
for j in self.dtypes:
for k in self.data_formats:
for m in self.momentum:
attrs.set_training(i)
attrs.set_dtype(j)
attrs.set_data_format(k)
attrs.set_momentum(m)
self.compare_backward()
for s in self.training:
for n in self.shapes:
for t in self.use_global_stats:
attrs.set_training(s)
attrs.set_shape(n)
attrs.set_use_global_stats(t)
self.compare_backward()
core._set_prim_backward_enabled(False)
if __name__ == '__main__':
unittest.main()
...@@ -134,6 +134,7 @@ class TestResnet50Accuracy(unittest.TestCase): ...@@ -134,6 +134,7 @@ class TestResnet50Accuracy(unittest.TestCase):
loop_num = 10 loop_num = 10
feed = self.generate_random_data(loop_num) feed = self.generate_random_data(loop_num)
core._set_prim_backward_enabled(True) core._set_prim_backward_enabled(True)
core._add_skip_comp_ops("batch_norm")
loss_c = self.train(place, loop_num, feed, use_cinn=True) loss_c = self.train(place, loop_num, feed, use_cinn=True)
core._set_prim_backward_enabled(False) core._set_prim_backward_enabled(False)
loss_p = self.train(place, loop_num, feed, use_cinn=True) loss_p = self.train(place, loop_num, feed, use_cinn=True)
......
...@@ -119,16 +119,19 @@ def composite_batchnorm( ...@@ -119,16 +119,19 @@ def composite_batchnorm(
if is_amp: if is_amp:
y = cast(y, "float16") y = cast(y, "float16")
# As the same with op kernel, indeed return inverse std
inv_std = 1.0 / sqrt(batch_var + epsilon)
# add op assign to detach tensor in void unsafe change outside the rule. # add op assign to detach tensor in void unsafe change outside the rule.
batch_mean_ = assign(reshape(batch_mean, run_mean.shape)) batch_mean_ = assign(reshape(batch_mean, run_mean.shape))
batch_var_ = assign(reshape(batch_var, run_var.shape)) inv_std_ = assign(reshape(inv_std, run_var.shape))
run_mean_ = assign(run_mean) run_mean_ = assign(run_mean)
run_var_ = assign(run_var) run_var_ = assign(run_var)
# reserve_space is not needed in composite rule, but still ruturn None to keep same as phi op definition. # reserve_space is not needed in composite rule, but still ruturn None to keep same as phi op definition.
reserve_space = None reserve_space = None
return y, run_mean_, run_var_, batch_mean_, batch_var_, reserve_space return y, run_mean_, run_var_, batch_mean_, inv_std_, reserve_space
@REGISTER_COMPOSITE('layer_norm') @REGISTER_COMPOSITE('layer_norm')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册