Unverified · Commit 4155e625 authored by: lvmengsi, committed by: GitHub

add instance norm (#19500)

* add instance norm op
Parent c7f36e7c
......@@ -133,6 +133,7 @@ paddle.fluid.layers.pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'po
paddle.fluid.layers.adaptive_pool2d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '52343203de40afe29607397e13aaf0d2'))
paddle.fluid.layers.adaptive_pool3d (ArgSpec(args=['input', 'pool_size', 'pool_type', 'require_index', 'name'], varargs=None, keywords=None, defaults=('max', False, None)), ('document', '55db6ae7275fb9678a6814aebab81a9c'))
paddle.fluid.layers.batch_norm (ArgSpec(args=['input', 'act', 'is_test', 'momentum', 'epsilon', 'param_attr', 'bias_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var', 'fuse_with_relu', 'use_global_stats'], varargs=None, keywords=None, defaults=(None, False, 0.9, 1e-05, None, None, 'NCHW', False, None, None, None, False, False, False)), ('document', '404741b5690228c493a2d9f59c6b1122'))
paddle.fluid.layers.instance_norm (ArgSpec(args=['input', 'epsilon', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1e-05, None, None, None)), ('document', 'c124b947a6ac4d01f491275561b9c1ab'))
paddle.fluid.layers.data_norm (ArgSpec(args=['input', 'act', 'epsilon', 'param_attr', 'data_layout', 'in_place', 'name', 'moving_mean_name', 'moving_variance_name', 'do_model_average_for_mean_and_var'], varargs=None, keywords=None, defaults=(None, 1e-05, None, 'NCHW', False, None, None, None, False)), ('document', '2460b30fb87037555208fa8ac6fc1787'))
paddle.fluid.layers.beam_search_decode (ArgSpec(args=['ids', 'scores', 'beam_size', 'end_id', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '83e08f21af41ac8bac37aeab1f86fdd0'))
paddle.fluid.layers.conv2d_transpose (ArgSpec(args=['input', 'num_filters', 'output_size', 'filter_size', 'padding', 'stride', 'dilation', 'groups', 'param_attr', 'bias_attr', 'use_cudnn', 'act', 'name'], varargs=None, keywords=None, defaults=(None, None, 0, 1, 1, None, None, None, True, None, None)), ('document', '6d3b135bb3834d58ef2cb581ead1487c'))
......
......@@ -18,6 +18,7 @@ limitations under the License. */
#include <unordered_map>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/norm_utils.h"
namespace paddle {
namespace operators {
......@@ -96,26 +97,5 @@ class BatchNormGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &ctx) const override;
};
inline void ExtractNCWHD(const framework::DDim &dims,
const DataLayout &data_layout, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
if (dims.size() == 2) {
*C = dims[1];
*H = 1;
*W = 1;
*D = 1;
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3])
: 1;
}
}
} // namespace operators
} // namespace paddle
This diff is collapsed.
This diff is collapsed.
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/norm_utils.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using DataLayout = framework::DataLayout;
template <typename T>
using EigenArrayMap =
Eigen::Map<Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using ConstEigenArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, Eigen::Dynamic>>;
template <typename T>
using EigenVectorArrayMap = Eigen::Map<Eigen::Array<T, Eigen::Dynamic, 1>>;
template <typename T>
using ConstEigenVectorArrayMap =
Eigen::Map<const Eigen::Array<T, Eigen::Dynamic, 1>>;
class InstanceNormOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class InstanceNormGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class InstanceNormDoubleGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override;
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override;
};
class InstanceNormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override;
};
class InstanceNormGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override;
};
class InstanceNormDoubleGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override;
};
class InstanceNormOpInferVarType
: public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", "Y"}};
}
};
template <typename DeviceContext, typename T>
class InstanceNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename DeviceContext, typename T>
class InstanceNormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
template <typename DeviceContext, typename T>
class InstanceNormDoubleGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override;
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using DataLayout = framework::DataLayout;
inline void ExtractNCWHD(const framework::DDim &dims,
const DataLayout &data_layout, int *N, int *C, int *H,
int *W, int *D) {
*N = dims[0];
if (dims.size() == 2) {
*C = dims[1];
*H = 1;
*W = 1;
*D = 1;
} else {
*C = data_layout == DataLayout::kNCHW ? dims[1] : dims[dims.size() - 1];
*H = data_layout == DataLayout::kNCHW ? dims[2] : dims[1];
*W = dims.size() > 3
? (data_layout == DataLayout::kNCHW ? dims[3] : dims[2])
: 1;
*D = dims.size() > 4
? (data_layout == DataLayout::kNCHW ? dims[4] : dims[3])
: 1;
}
}
} // namespace operators
} // namespace paddle
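To make the layout handling concrete, here is a minimal Python sketch of the same dims-to-(N, C, H, W, D) mapping that the C++ `ExtractNCWHD` helper above performs (illustrative analog only; the name `extract_ncwhd` is hypothetical and not part of this commit):

def extract_ncwhd(dims, layout="NCHW"):
    n = dims[0]
    if len(dims) == 2:  # fully-connected style input: treat as N x C
        return n, dims[1], 1, 1, 1
    c = dims[1] if layout == "NCHW" else dims[-1]
    h = dims[2] if layout == "NCHW" else dims[1]
    w = (dims[3] if layout == "NCHW" else dims[2]) if len(dims) > 3 else 1
    d = (dims[4] if layout == "NCHW" else dims[3]) if len(dims) > 4 else 1
    return n, c, h, w, d

print(extract_ncwhd([2, 3, 4, 5]))          # NCHW 4-D -> (2, 3, 4, 5, 1)
print(extract_ncwhd([2, 4, 5, 3], "NHWC"))  # NHWC 4-D -> (2, 3, 4, 5, 1)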
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/operators/batch_norm_op.h"
#include "paddle/fluid/operators/norm_utils.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/nccl_helper.h"
......
......@@ -61,6 +61,7 @@ __all__ = [
'adaptive_pool2d',
'adaptive_pool3d',
'batch_norm',
'instance_norm',
'data_norm',
'beam_search_decode',
'conv2d_transpose',
......@@ -3498,6 +3499,128 @@ def batch_norm(input,
return helper.append_activation(batch_norm_out)
def instance_norm(input,
epsilon=1e-05,
param_attr=None,
bias_attr=None,
name=None):
"""
**Instance Normalization Layer**
Can be used as a normalizer function for conv2d and fully_connected operations.
The required data format for this layer is NCHW `[batch, in_channels, in_height, in_width]`.
Refer to `Instance Normalization: The Missing Ingredient for
Fast Stylization <https://arxiv.org/pdf/1607.08022.pdf>`_
for more details.
:math:`input` is the input features over a mini-batch.
.. math::
\\mu_{\\beta} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW} x_i \\qquad &//\\
\\ mean\\ of\\ one\\ feature\\ map\\ in\\ mini-batch \\\\
\\sigma_{\\beta}^{2} &\\gets \\frac{1}{HW} \\sum_{i=1}^{HW}(x_i - \\
\\mu_{\\beta})^2 \\qquad &//\\ variance\\ of\\ one\\ feature\\ map\\ in\\ mini-batch \\\\
\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\epsilon}} \\qquad &//\\ normalize \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta \\qquad &//\\ scale\\ and\\ shift
Note that :math:`\\mu_{\\beta}` and :math:`\\sigma_{\\beta}^{2}` are computed
per instance and per channel rather than accumulated across mini-batches, so
instance_norm keeps no running (global) statistics, and training and testing
(or inference) have the same behavior:
.. math::
\\hat{x_i} &\\gets \\frac{x_i - \\mu_\\beta} {\\sqrt{\\
\\sigma_{\\beta}^{2} + \\epsilon}} \\\\
y_i &\\gets \\gamma \\hat{x_i} + \\beta
Args:
input(Variable): The input Variable, whose rank can be 2, 3, 4 or 5.
epsilon(float, Default 1e-05): A value added to the denominator for
numerical stability.
param_attr(ParamAttr|None): The parameter attribute for the Parameter `scale`
of instance_norm. If it is set to None or one attribute of ParamAttr, instance_norm
will create ParamAttr as param_attr, and the name of the scale can be set in ParamAttr.
If the Initializer of the param_attr is not set, the parameter is initialized
with Xavier. Default: None.
bias_attr(ParamAttr|None): The parameter attribute for the bias of instance_norm.
If it is set to None or one attribute of ParamAttr, instance_norm
will create ParamAttr as bias_attr, and the name of the bias can be set in ParamAttr.
If the Initializer of the bias_attr is not set, the bias is initialized to zero.
Default: None.
name(string, Default None): A name for this layer (optional). If set to None,
the layer will be named automatically.
Returns:
Variable: A tensor variable which is the result after applying instance normalization on the input.
Examples:
.. code-block:: python
import paddle.fluid as fluid
x = fluid.layers.data(name='x', shape=[3, 7, 3, 7], dtype='float32', append_batch_size=False)
hidden1 = fluid.layers.fc(input=x, size=200, param_attr='fc1.w')
hidden2 = fluid.layers.instance_norm(input=hidden1)
"""
assert bias_attr is not False, "bias_attr should not be False in instance_norm."
helper = LayerHelper('instance_norm', **locals())
dtype = helper.input_dtype()
# use fp32 for instance_norm parameters
if dtype == core.VarDesc.VarType.FP16:
dtype = core.VarDesc.VarType.FP32
input_shape = input.shape
channel_num = input_shape[1]
param_shape = [channel_num]
# create parameter
scale = helper.create_parameter(
attr=helper.param_attr,
shape=param_shape,
dtype=dtype,
default_initializer=Constant(1.0))
bias = helper.create_parameter(
attr=helper.bias_attr,
shape=param_shape,
dtype=dtype,
is_bias=True,
default_initializer=Constant(0.0))
# create output
saved_mean = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
saved_variance = helper.create_variable_for_type_inference(
dtype=dtype, stop_gradient=True)
instance_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="instance_norm",
inputs={
"X": input,
"Scale": scale,
"Bias": bias,
},
outputs={
"Y": instance_norm_out,
"SavedMean": saved_mean,
"SavedVariance": saved_variance
},
attrs={"epsilon": epsilon, })
return instance_norm_out
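As a cross-check of the equations in the docstring above, a minimal NumPy sketch of the forward computation (illustrative only, assuming NCHW input; `instance_norm_ref` is a hypothetical name, not part of this commit):

import numpy as np

def instance_norm_ref(x, scale, bias, epsilon=1e-5):
    # x: (N, C, H, W); statistics are computed per (sample, channel) map
    mean = x.mean(axis=(2, 3), keepdims=True)
    var = x.var(axis=(2, 3), keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + epsilon)
    # scale/bias are per-channel, broadcast over N, H and W
    return scale.reshape(1, -1, 1, 1) * x_hat + bias.reshape(1, -1, 1, 1)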
def data_norm(input,
act=None,
epsilon=1e-05,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.op import Operator
from op_test import OpTest
def _reference_instance_norm_naive(x, scale, bias, epsilon, mean, var):
# NumPy reference for the instance_norm forward pass: normalize each
# (sample, channel) feature map with the given per-instance statistics,
# then apply the per-channel scale and bias.
x_shape = x.shape
if len(x_shape) == 2:
x = np.reshape(x, (x.shape[0], x.shape[1], 1, 1))
n, c, h, w = x.shape
mean_tile = np.reshape(mean, (n, c, 1, 1))
mean_tile = np.tile(mean_tile, (1, 1, h, w))
var_tile = np.reshape(var, (n, c, 1, 1))
var_tile = np.tile(var_tile, (1, 1, h, w))
x_norm = (x - mean_tile) / np.sqrt(var_tile + epsilon).astype('float32')
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
bias_tile = np.reshape(bias, (1, c, 1, 1))
bias_tile = np.tile(bias_tile, (n, 1, h, w))
y = scale_tile * x_norm + bias_tile
if len(x_shape) == 2:
y = np.reshape(y, x_shape)
return y, mean, var
def _reference_instance_norm_grad(x, d_y, scale, mean, var, epsilon):
# Note: the `var` argument is the saved inverse std, i.e. 1 / sqrt(var + epsilon).
# d_scale = sum(d_y * (x - mean) / sqrt(var + epsilon))
# d_bias = sum(d_y)
# d_x = scale / sqrt(var + epsilon) * (d_y - mean(d_y, axis=(2, 3))
#       - (x - mean) / sqrt(var + epsilon) * mean(d_y * (x - mean) / sqrt(var + epsilon), axis=(2, 3)))
n, c, h, w = x.shape
d_bias = np.sum(d_y, axis=(0, 2, 3))
mean_tile = np.reshape(mean, (n, c, 1, 1))
mean_tile = np.tile(mean_tile, (1, 1, h, w))
var_tile = np.reshape(var, (n, c, 1, 1))
var_tile = np.tile(var_tile, (1, 1, h, w))
d_scale = np.sum(d_y * (x - mean_tile) * var_tile, axis=(0, 2, 3))
var_inv = var_tile  # var_tile already holds 1 / sqrt(var + epsilon)
scale_tile = np.reshape(scale, (1, c, 1, 1))
scale_tile = np.tile(scale_tile, (n, 1, h, w))
d_x = scale_tile * var_inv * (d_y - np.mean(
d_y, axis=(2, 3), keepdims=True) - (x - mean_tile) * var_inv * np.mean(
d_y * (x - mean_tile) * var_inv, axis=(2, 3), keepdims=True))
return d_x, d_scale, d_bias
def _cal_mean_variance(x, epsilon, mean_shape):
# per-instance, per-channel statistics over the spatial dims (epsilon unused here)
mean = np.reshape(np.mean(x, axis=(2, 3)), mean_shape)
var = np.reshape(np.var(x, axis=(2, 3)), mean_shape)
return mean, var
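The reference gradient above can be sanity-checked numerically. Below is a sketch (not part of the original test; `_finite_diff_bias_grad` is a hypothetical helper) comparing the analytic d_bias against a central finite difference of the scalar loss L = sum(y * d_y):

def _finite_diff_bias_grad(x, scale, bias, epsilon, d_y, delta=1e-3):
    # Central difference of L = sum(y * d_y) w.r.t. each bias element;
    # should match d_bias = np.sum(d_y, axis=(0, 2, 3)).
    n, c, h, w = x.shape
    mean, var = _cal_mean_variance(x, epsilon, [n * c])
    grad = np.zeros(c, dtype=np.float64)
    for i in range(c):
        b_pos, b_neg = bias.copy(), bias.copy()
        b_pos[i] += delta
        b_neg[i] -= delta
        y_pos, _, _ = _reference_instance_norm_naive(x, scale, b_pos, epsilon, mean, var)
        y_neg, _, _ = _reference_instance_norm_naive(x, scale, b_neg, epsilon, mean, var)
        grad[i] = np.sum((y_pos - y_neg) * d_y) / (2 * delta)
    return grad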
class TestInstanceNormOpTraining(unittest.TestCase):
def setUp(self):
self.epsilon = 1e-5
self.init_test_case()
def init_test_case(self):
self.use_global_stats = False
self.no_grad_set = set()
self.fetch_list = [
'y', 'saved_mean', 'saved_variance', 'x@GRAD', 'scale@GRAD',
'bias@GRAD'
]
def __assert_close(self, tensor, np_array, msg, atol=1e-4):
self.assertTrue(np.allclose(np.array(tensor), np_array, atol=atol), msg)
def set_global_mean_var(self, mean_shape, x):
mean, variance = _cal_mean_variance(x, self.epsilon, mean_shape)
return mean, variance
def test_forward_backward(self):
def test_with_place(place, shape):
epsilon = self.epsilon
n, c, h, w = shape[0], shape[1], shape[2], shape[3]
scale_shape = [c]
mean_shape = [n * c]
np.random.seed()
x = np.random.random_sample(shape).astype(np.float32)
scale = np.random.random_sample(scale_shape).astype(np.float32)
bias = np.random.random_sample(scale_shape).astype(np.float32)
mean, variance = self.set_global_mean_var(mean_shape, x)
d_y = np.random.random_sample(shape).astype(np.float32)
y, saved_mean, variance_tmp = _reference_instance_norm_naive(
x, scale, bias, epsilon, mean, variance)
saved_variance = 1 / np.sqrt(variance_tmp + epsilon)
d_x, d_scale, d_bias = _reference_instance_norm_grad(
x, d_y, scale, saved_mean, saved_variance, epsilon)
var_dict = locals()
var_dict['y@GRAD'] = d_y
var_dict['x@GRAD'] = d_x
var_dict['scale@GRAD'] = d_scale
var_dict['bias@GRAD'] = d_bias
var_names = [
'x', 'scale', 'bias', 'y', 'saved_mean', 'saved_variance'
]
ground_truth = {name: var_dict[name] for name in var_names}
program = fluid.Program()
with fluid.program_guard(program):
block = program.global_block()
for name in ground_truth:
block.create_var(
name=name,
dtype='float32',
shape=ground_truth[name].shape)
in_op = block.append_op(
type="instance_norm",
inputs={
"X": block.var("x"),
"Scale": block.var("scale"),
"Bias": block.var("bias"),
},
outputs={
"Y": block.var("y"),
"SavedMean": block.var("saved_mean"),
"SavedVariance": block.var("saved_variance")
},
attrs={"epsilon": epsilon, })
block.create_var(name="y@GRAD", dtype='float32', shape=y.shape)
grad_op_desc_list, op_grad_to_var = core.get_grad_op_desc(
in_op.desc, self.no_grad_set, [])
grad_op_desc = grad_op_desc_list[0]
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(grad_op_desc)
for var_name in grad_op_desc.output_arg_names():
block.desc.var(var_name.encode("ascii"))
grad_op_desc.infer_var_type(block.desc)
grad_op_desc.infer_shape(block.desc)
for arg in grad_op_desc.output_arg_names():
grad_var = block.desc.find_var(arg.encode("ascii"))
grad_var.set_dtype(core.VarDesc.VarType.FP32)
exe = fluid.Executor(place)
out = exe.run(program,
feed={
name: var_dict[name]
for name in ['x', 'scale', 'bias', 'y@GRAD']
},
fetch_list=self.fetch_list)
for idx, name in enumerate(self.fetch_list):
self.__assert_close(var_dict[name], out[idx], name)
print("op test forward passed: ", str(place))
places = [core.CPUPlace()]
if core.is_compiled_with_cuda() and core.op_support_gpu(
"instance_norm"):
places.append(core.CUDAPlace(0))
for place in places:
test_with_place(place, [2, 3, 4, 5])
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
from decorator_helper import prog_scope
class TestInstanceNormDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
prog = fluid.Program()
with fluid.program_guard(prog):
np.random.seed()
shape = [2, 3, 4, 5]
dtype = "float32"
eps = 0.005
atol = 1e-4
x = layers.create_parameter(dtype=dtype, shape=shape, name='x')
z = fluid.layers.instance_norm(input=x)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], z, x_init=x_arr, atol=atol, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
if __name__ == "__main__":
unittest.main()