Unverified · Commit 77cf305f authored by: H hong, committed by: GitHub

Add batch norm yaml (#41386)

* update

* fix bug
Parent 19cb0d18
@@ -312,8 +312,8 @@ class InplaceABNGradKernel : public framework::OpKernel<T> {
phi::BatchNormGradRawKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt,
mean_opt, variance_opt, momentum, epsilon, data_layout, is_test,
*y, *scale, *bias, mean_opt, variance_opt, *saved_mean, *saved_variance,
space_opt, *d_y, momentum, epsilon, data_layout, is_test,
use_global_stats, trainable_statistics, fuse_with_relu, true, d_x,
scale_grad, bias_grad);
}
......
@@ -140,10 +140,10 @@ class InplaceABNGradKernel
phi::BatchNormGradRawKernel<T>(
static_cast<const typename framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*d_y, *y, *scale, *bias, *saved_mean, *saved_variance, space_opt,
mean_opt, variance_opt, momentum, epsilon, data_layout, is_test,
use_global_stats, trainable_statistics, fuse_with_relu, true, d_x,
scale_grad, bias_grad);
*y, *scale, *bias, mean_opt, variance_opt, *saved_mean,
*saved_variance, space_opt, *d_y, momentum, epsilon, data_layout,
is_test, use_global_stats, trainable_statistics, fuse_with_relu, true,
d_x, scale_grad, bias_grad);
}
}
};
......
@@ -167,6 +167,135 @@ std::vector<Tensor> split_impl(const Tensor& x,
return out;
}
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu) {
Backend kernel_backend = Backend::UNDEFINED;
DataLayout kernel_layout = DataLayout::UNDEFINED;
DataType kernel_data_type = DataType::UNDEFINED;
kernel_data_type = ParseDataType(x);
if (kernel_backend == Backend::UNDEFINED ||
kernel_layout == DataLayout::UNDEFINED ||
kernel_data_type == DataType::UNDEFINED) {
auto kernel_key_set = ParseKernelKeyByInputArgs(x);
auto kernel_key = kernel_key_set.GetHighestPriorityKernelKey();
if (kernel_backend == Backend::UNDEFINED) {
kernel_backend = kernel_key.backend();
}
if (kernel_layout == DataLayout::UNDEFINED) {
kernel_layout = kernel_key.layout();
}
if (kernel_data_type == DataType::UNDEFINED) {
kernel_data_type = kernel_key.dtype();
}
}
const auto& kernel = phi::KernelFactory::Instance().SelectKernelOrThrowError(
"batch_norm", {kernel_backend, kernel_layout, kernel_data_type});
VLOG(6) << "batch_norm API kernel key: [" << kernel_backend << ", "
<< kernel_layout << ", " << kernel_data_type << "]";
VLOG(6) << "batch_norm API kernel: " << kernel;
auto* dev_ctx = GetDeviceContextByBackend(kernel_backend);
auto input_x = PrepareData(x, kernel.InputAt(0), {});
auto input_scale = PrepareData(scale, kernel.InputAt(1), {});
auto input_bias = PrepareData(bias, kernel.InputAt(2), {});
auto input_mean = PrepareData(mean, kernel.InputAt(3), {});
auto input_variance = PrepareData(variance, kernel.InputAt(4), {});
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> api_output;
auto kernel_out_0 = SetKernelOutput(kernel_backend, &std::get<0>(api_output));
std::get<1>(api_output).set_impl(mean.impl());
std::get<2>(api_output).set_impl(variance.impl());
auto kernel_out_1 = SetKernelOutput(kernel_backend, &std::get<1>(api_output));
auto kernel_out_2 = SetKernelOutput(kernel_backend, &std::get<2>(api_output));
auto kernel_out_3 = SetKernelOutput(kernel_backend, &std::get<3>(api_output));
auto kernel_out_4 = SetKernelOutput(kernel_backend, &std::get<4>(api_output));
auto kernel_out_5 = SetKernelOutput(kernel_backend, &std::get<5>(api_output));
phi::MetaTensor meta_out_0(kernel_out_0);
phi::MetaTensor meta_out_1(kernel_out_1);
phi::MetaTensor meta_out_2(kernel_out_2);
phi::MetaTensor meta_out_3(kernel_out_3);
phi::MetaTensor meta_out_4(kernel_out_4);
phi::MetaTensor meta_out_5(kernel_out_5);
phi::BatchNormInferMeta(MakeMetaTensor(*input_x),
MakeMetaTensor(*input_scale),
MakeMetaTensor(*input_bias),
MakeMetaTensor(*input_mean),
MakeMetaTensor(*input_variance),
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
&meta_out_0,
&meta_out_1,
&meta_out_2,
&meta_out_3,
&meta_out_4,
&meta_out_5);
using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
const phi::DenseTensor&,
float,
float,
const std::string&,
bool,
bool,
bool,
bool,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*,
phi::DenseTensor*);
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
{
(*kernel_fn)(*dev_ctx,
*input_x,
*input_scale,
*input_bias,
*input_mean,
*input_variance,
momentum,
epsilon,
data_layout,
is_test,
use_global_stats,
trainable_statistics,
fuse_with_relu,
kernel_out_0,
kernel_out_1,
kernel_out_2,
kernel_out_3,
kernel_out_4,
kernel_out_5);
}
return api_output;
}
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis) {
......
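The batch_norm_impl wrapper added above follows the usual phi API dispatch pattern: resolve a kernel key from the inputs, prepare the five dense-tensor arguments, run BatchNormInferMeta over the six outputs, then call the variadic kernel. Outputs 1 and 2 are bound to the impls of the incoming mean and variance tensors, so the running statistics are updated in place. Below is a minimal Python sketch (not part of this commit; shapes and attribute values are illustrative) of how the resulting eager op is called and unpacked, mirroring the dygraph branches added later in this diff:

import paddle
from paddle import _C_ops

x = paddle.randn([4, 10, 4, 4])
weight, bias = paddle.ones([10]), paddle.zeros([10])
running_mean, running_var = paddle.zeros([10]), paddle.ones([10])

# Argument order matches batch_norm_impl: x, scale, bias, mean, variance,
# momentum, epsilon, data_layout, is_test, use_global_stats,
# trainable_statistics, fuse_with_relu.
out, mean_out, variance_out, saved_mean, saved_variance, reserve_space = \
    _C_ops.final_state_batch_norm(
        x, weight, bias, running_mean, running_var,
        0.9, 1e-5, "NCHW", False, False, False, False)
# mean_out and variance_out share storage with running_mean and running_var,
# so the running statistics are updated in place.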
@@ -31,6 +31,20 @@ std::vector<Tensor> split_impl(const Tensor& x,
const IntArray& num_or_sections,
const Scalar& axis);
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor> batch_norm_impl(
const Tensor& x,
const Tensor& scale,
const Tensor& bias,
const Tensor& mean,
const Tensor& variance,
float momentum,
float epsilon,
const std::string& data_layout,
bool is_test,
bool use_global_stats,
bool trainable_statistics,
bool fuse_with_relu);
std::vector<Tensor> concat_grad_impl(const std::vector<Tensor>& x,
const Tensor& out_grad,
const Scalar& axis);
......
@@ -21,15 +21,15 @@ namespace phi {
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
@@ -44,15 +44,15 @@ void BatchNormGradRawKernel(const Context& dev_ctx,
template <typename T, typename Context>
void BatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
......
@@ -37,15 +37,16 @@ using ConstEigenVectorArrayMap =
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context& ctx,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout_str,
@@ -122,8 +123,8 @@ void BatchNormGradRawKernel(const Context& ctx,
ctx.template Alloc<T>(d_x);
}
const T* mean_data = saved_mean.data<T>();
const T* inv_var_data = saved_variance.data<T>();
const T* mean_data = nullptr;
const T* inv_var_data = nullptr;
DenseTensor inv_var_tensor;
if (use_global_stats) {
const auto* running_mean = mean.get_ptr();
@@ -136,6 +137,9 @@ void BatchNormGradRawKernel(const Context& ctx,
inv_var_tmp = (var_arr + epsilon).sqrt().inverse();
inv_var_data = running_inv_var_data;
} else {
mean_data = saved_mean.data<T>();
inv_var_data = saved_variance.data<T>();
}
ConstEigenVectorArrayMap<T> scale_arr(scale.data<T>(), C);
@@ -293,15 +297,15 @@ void BatchNormGradRawKernel(const Context& ctx,
template <typename T, typename Context>
void BatchNormGradKernel(const Context& dev_ctx,
const DenseTensor& y_grad,
const DenseTensor& x,
const DenseTensor& scale,
const DenseTensor& bias,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& saved_mean,
const DenseTensor& saved_variance,
paddle::optional<const DenseTensor&> reserve_space,
paddle::optional<const DenseTensor&> mean,
paddle::optional<const DenseTensor&> variance,
const DenseTensor& y_grad,
float momentum,
float epsilon,
const std::string& data_layout,
@@ -313,15 +317,15 @@ void BatchNormGradKernel(const Context& dev_ctx,
DenseTensor* scale_grad,
DenseTensor* bias_grad) {
BatchNormGradRawKernel<T, Context>(dev_ctx,
y_grad,
x,
scale,
bias,
mean,
variance,
saved_mean,
saved_variance,
reserve_space,
mean,
variance,
y_grad,
momentum,
epsilon,
data_layout,
......
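The CPU hunks above defer reading saved_mean and saved_variance until the kernel knows use_global_stats is false; with global stats the inverse standard deviation is derived from the running variance instead. A rough NumPy sketch of that selection (illustrative only, not the kernel's actual code):

import numpy as np

def pick_backward_stats(use_global_stats, running_mean, running_var,
                        saved_mean, saved_inv_std, epsilon):
    # use_global_stats: recompute the inverse std from the running variance.
    if use_global_stats:
        return running_mean, 1.0 / np.sqrt(running_var + epsilon)
    # Otherwise reuse the mean and inverse std saved by the forward pass.
    return saved_mean, saved_inv_std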
@@ -306,15 +306,15 @@ static __global__ LAUNCH_BOUNDS(BlockDim) void BNBackwardData(
template <typename T, typename Context>
void BatchNormGradRawKernel(const Context &ctx,
const DenseTensor &y_grad,
const DenseTensor &x,
const DenseTensor &scale,
const DenseTensor &bias,
paddle::optional<const DenseTensor &> mean,
paddle::optional<const DenseTensor &> variance,
const DenseTensor &saved_mean,
const DenseTensor &saved_variance,
paddle::optional<const DenseTensor &> reserve_space,
paddle::optional<const DenseTensor &> mean,
paddle::optional<const DenseTensor &> variance,
const DenseTensor &y_grad,
float momentum,
float epsilon_f,
const std::string &data_layout_str,
@@ -863,15 +863,15 @@ void BatchNormGradRawKernel(const Context &ctx,
template <typename T, typename Context>
void BatchNormGradKernel(const Context &dev_ctx,
const DenseTensor &y_grad,
const DenseTensor &x,
const DenseTensor &scale,
const DenseTensor &bias,
paddle::optional<const DenseTensor &> mean,
paddle::optional<const DenseTensor &> variance,
const DenseTensor &saved_mean,
const DenseTensor &saved_variance,
paddle::optional<const DenseTensor &> reserve_space,
paddle::optional<const DenseTensor &> mean,
paddle::optional<const DenseTensor &> variance,
const DenseTensor &y_grad,
float momentum,
float epsilon,
const std::string &data_layout,
@@ -883,15 +883,15 @@ void BatchNormGradKernel(const Context &dev_ctx,
DenseTensor *scale_grad,
DenseTensor *bias_grad) {
BatchNormGradRawKernel<T, Context>(dev_ctx,
y_grad,
x,
scale,
bias,
mean,
variance,
saved_mean,
saved_variance,
reserve_space,
mean,
variance,
y_grad,
momentum,
epsilon,
data_layout,
......
@@ -59,15 +59,17 @@ KernelSignature BatchNormGradOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature(
"batch_norm_grad",
{GradVarName("Y"),
"X",
"Scale",
"Bias",
"SavedMean",
"SavedVariance",
"ReserveSpace",
"Mean",
"Variance"},
{
"X",
"Scale",
"Bias",
"Mean",
"Variance",
"SavedMean",
"SavedVariance",
"ReserveSpace",
GradVarName("Y"),
},
{"momentum",
"epsilon",
"data_layout",
......
@@ -1339,15 +1339,22 @@ class BatchNorm(layers.Layer):
variance_out = self._variance
if _non_static_mode():
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout",
self._data_layout, "use_mkldnn", self._use_mkldnn,
"fuse_with_relu", self._fuse_with_relu, "use_global_stats",
self._use_global_stats, 'trainable_statistics',
self._trainable_statistics)
batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs)
if in_dygraph_mode():
batch_norm_out, t1, t2, t3, t4, _ = _C_ops.final_state_batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
self._momentum, self._epsilon, self._data_layout,
not self.training, self._use_global_stats,
self._trainable_statistics, False)
else:
attrs = ("momentum", self._momentum, "epsilon", self._epsilon,
"is_test", not self.training, "data_layout",
self._data_layout, "use_mkldnn", self._use_mkldnn,
"fuse_with_relu", self._fuse_with_relu,
"use_global_stats", self._use_global_stats,
'trainable_statistics', self._trainable_statistics)
batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
input, self.weight, self.bias, self._mean, self._variance,
mean_out, variance_out, *attrs)
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn)
......
@@ -81,6 +81,40 @@ class TestBatchNorm(unittest.TestCase):
self.assertRaises(ValueError, error2d_dataformat)
self.assertRaises(ValueError, error3d_dataformat)
def test_eager_api(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
shape = [4, 10, 4, 4]
def compute_v1(x):
with fluid.dygraph.guard(p):
bn = fluid.dygraph.BatchNorm(shape[1])
#bn = paddle.nn.BatchNorm2D(shape[1])
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y = bn(x1)
y.backward()
return y.numpy(), x1.gradient()
def compute_v2(x):
with fluid.dygraph.guard(p):
with _test_eager_guard():
print("v2")
bn = paddle.nn.BatchNorm2D(shape[1])
x1 = paddle.to_tensor(x)
x1.stop_gradient = False
y = bn(x1)
y.backward()
return y.numpy(), x1.gradient()
x = np.random.randn(*shape).astype("float32")
y1, g1 = compute_v1(x)
y2, g2 = compute_v2(x)
self.assertTrue(np.allclose(g1, g2))
self.assertTrue(np.allclose(y1, y2))
def test_dygraph(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
......
@@ -186,15 +186,24 @@ def batch_norm(x,
else:
trainable_statistics = not use_global_stats
if in_dynamic_mode():
if in_dygraph_mode():
batch_norm_out, _, _, _, _, _ = _C_ops.final_state_batch_norm(
x, weight, bias, running_mean, running_var, momentum, epsilon,
data_format, not training, use_global_stats, trainable_statistics,
False)
return batch_norm_out
if _in_legacy_dygraph():
# for dygraph need tuple
attrs = ("momentum", momentum, "epsilon", epsilon, "is_test",
not training, "data_layout", data_format, "use_mkldnn", False,
"fuse_with_relu", False, "use_global_stats", use_global_stats,
"trainable_statistics", trainable_statistics)
batch_norm_out, _, _, _, _, _ = _C_ops.batch_norm(
x, weight, bias, running_mean, running_var, mean_out, variance_out,
*attrs)
return dygraph_utils._append_activation_in_dygraph(
batch_norm_out, act=None)
......
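With the change above, paddle.nn.functional.batch_norm takes the new final_state path when in_dygraph_mode() is true and falls back to the legacy attrs-tuple call under _in_legacy_dygraph(). A hedged usage sketch of the functional API in inference mode (the keyword layout of the public signature is assumed here, not taken from this diff):

import paddle
import paddle.nn.functional as F

x = paddle.randn([4, 10, 4, 4])
running_mean, running_var = paddle.zeros([10]), paddle.ones([10])
weight, bias = paddle.ones([10]), paddle.zeros([10])

# training=False exercises the is_test branch of the batch_norm kernel.
y = F.batch_norm(x, running_mean, running_var, weight, bias,
                 training=False, momentum=0.9, epsilon=1e-5,
                 data_format="NCHW")
print(y.shape)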
@@ -207,6 +207,13 @@
kernel :
func : auc
# batch_norm
- api : batch_norm
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
invoke : batch_norm_impl(x, scale, bias, mean, variance, momentum, epsilon, data_layout, is_test, use_global_stats, trainable_statistics, fuse_with_relu)
backward : batch_norm_grad
- api : bce_loss
args : (Tensor input, Tensor label)
output : Tensor
......
@@ -118,6 +118,18 @@
kernel :
func : atanh_grad
- backward_api : batch_norm_grad
forward : batch_norm (Tensor x, Tensor scale, Tensor bias, Tensor mean, Tensor variance, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu) -> Tensor(out), Tensor(mean_out), Tensor(variance_out), Tensor(saved_mean), Tensor(saved_variance), Tensor(reserve_space)
args : (Tensor x, Tensor scale, Tensor bias, Tensor mean_out, Tensor variance_out, Tensor saved_mean, Tensor saved_variance, Tensor reserve_space, Tensor out_grad, float momentum, float epsilon, str data_layout, bool is_test, bool use_global_stats, bool trainable_statistics, bool fuse_with_relu)
output : Tensor(x_grad), Tensor(scale_grad), Tensor(bias_grad)
infer_meta :
func : GeneralTernaryGradInferMeta
param : [x, scale, bias]
kernel :
func : batch_norm_grad
data_type : out_grad
optional : mean_out, variance_out, reserve_space
- backward_api : bce_loss_grad
forward : bce_loss (Tensor input, Tensor label) -> Tensor(out)
args : (Tensor input, Tensor label, Tensor out_grad)
......
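The backward entry above feeds the forward's saved_mean, saved_variance and reserve_space outputs back into batch_norm_grad, with mean_out, variance_out and reserve_space marked optional. A short end-to-end sketch in the spirit of the new test_eager_api (values illustrative); the backward() call is what dispatches to batch_norm_grad:

import numpy as np
import paddle

x = paddle.to_tensor(np.random.randn(4, 10, 4, 4).astype("float32"))
x.stop_gradient = False
bn = paddle.nn.BatchNorm2D(10)

y = bn(x)      # forward: batch_norm -> out, mean_out, variance_out, saved stats, reserve_space
y.backward()   # backward: batch_norm_grad uses x, scale, bias, the saved stats and out_grad
print(x.gradient().shape)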