提交 b1025cf5 编写于 作者: S sweetsky0901

add norm_op for ssd(cross channel norm)

上级 a87f4963
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/norm_op.h"
namespace paddle {
namespace operators {
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NormOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(
"X",
"(Tensor) The input tensor of norm operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddInput("Scale",
"(Tensor) The input tensor of norm operator. "
"The format of input tensor is C * 1.");
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f);
AddOutput("Out",
"(Tensor) The output tensor of norm operator."
"N * M."
"M = C * H * W");
AddComment(R"DOC(
"Input shape: $(N, C, H, W)$
Sclae shape: $(C, 1)$
Output shape: $(N, C, H, W)$
Where
forward
$$
[\frac {x_{1}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{2}}{\sqrt{\sum{x_{i}^{2}}}} \frac {x_{3}}{\sqrt{\sum{x_{i}^{2}}}} \cdot \cdot \cdot \frac {x_{n}}{\sqrt{\sum{x_{i}^{2}}}}]
$$
backward
$$
\frac{\frac{\mathrm{d}L }{\mathrm{d}y_{1}} - \frac {x_{1}\sum {\frac{\mathrm{d} L}{\mathrm{d} y_{j}}}x_{j}}{\sum x_{j}^{2}} }{\sqrt{\sum{x_{j}^{2}}}}
$$
)DOC");
}
};
class NormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of NormOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of NormOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", in_x_dims);
}
};
class NormOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.device_context());
}
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
norm_grad, ops::NormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/norm_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
norm, ops::NormKernel<paddle::platform::CUDADeviceContext, float>,
ops::NormKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
norm_grad, ops::NormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::NormGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
class NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
auto* out = context.Output<framework::Tensor>("Out");
T epsilon = context.Attr<T>("epsilon");
out->mutable_data<T>(context.GetPlace());
int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1];
int height = in_x->dims()[2];
int width = in_x->dims()[3];
int fea_len = height * width;
auto* place =
context.template device_context<DeviceContext>().eigen_device();
auto x = EigenMatrix<T>::From(
*in_x, framework::make_ddim({batch_size, fea_len * channels}));
// get square
framework::Tensor x_square;
x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
auto x_square_eigen = EigenMatrix<T>::From(
x_square, framework::make_ddim({batch_size, fea_len * channels}));
x_square_eigen.device(*place) = x.square();
auto scale_eigen = EigenVector<T>::Flatten(*scale);
for (int n = 0; n < batch_size; ++n) {
framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
auto in_x_batch_eigen = EigenMatrix<T>::From(
in_x_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
auto x_square_batch_eigen = EigenMatrix<T>::From(
x_square_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor out_batch = out->Slice(n, n + 1);
auto out_batch_eigen = EigenMatrix<T>::From(
out_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto tmp = EigenVector<T>::Flatten(tmp_tensor);
// get colsum and sqrt , inverse
auto dim = Eigen::array<int, 1>({{0}});
tmp.device(*place) = x_square_batch_eigen.sum(dim);
tmp.device(*place) = (tmp + epsilon).sqrt().inverse();
Eigen::array<int, 2> broadcast_dim_col;
broadcast_dim_col[1] = 1;
broadcast_dim_col[0] = channels;
out_batch_eigen.device(*place) =
in_x_batch_eigen * (tmp.broadcast(broadcast_dim_col));
Eigen::array<int, 2> broadcast_dim_row;
broadcast_dim_row[1] = fea_len;
broadcast_dim_row[0] = 1;
out_batch_eigen.device(*place) =
out_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
}
}
};
template <typename DeviceContext, typename T>
class NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
T epsilon = context.Attr<T>("epsilon");
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
in_x_grad->mutable_data<T>(context.GetPlace());
int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1];
int height = in_x->dims()[2];
int width = in_x->dims()[3];
int fea_len = height * width;
auto* place =
context.template device_context<DeviceContext>().eigen_device();
auto scale_eigen = EigenVector<T>::Flatten(*scale);
auto x = EigenMatrix<T>::From(
*in_x, framework::make_ddim({batch_size, fea_len * channels}));
// get square
framework::Tensor x_square;
x_square.mutable_data<T>(in_x->dims(), context.GetPlace());
auto x_square_eigen = EigenMatrix<T>::From(
x_square, framework::make_ddim({batch_size, fea_len * channels}));
x_square_eigen.device(*place) = x.square();
for (int n = 0; n < batch_size; ++n) {
framework::Tensor in_x_batch = in_x->Slice(n, n + 1);
auto in_x_batch_eigen = EigenMatrix<T>::From(
in_x_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor in_g_batch = in_x_grad->Slice(n, n + 1);
auto in_g_batch_eigen = EigenMatrix<T>::From(
in_g_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor x_square_batch = x_square.Slice(n, n + 1);
auto x_square_batch_eigen = EigenMatrix<T>::From(
x_square_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor outg_batch = out_grad->Slice(n, n + 1);
auto outg_batch_eigen = EigenMatrix<T>::From(
outg_batch, framework::make_ddim({channels, fea_len}));
framework::Tensor tmp_tensor;
tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto tmp_eigen = EigenVector<T>::Flatten(tmp_tensor);
auto dim = Eigen::array<int, 1>({{0}});
tmp_eigen.device(*place) = (in_x_batch_eigen * outg_batch_eigen).sum(dim);
framework::Tensor norm_tmp_tensor;
norm_tmp_tensor.mutable_data<T>(framework::make_ddim({1, fea_len}),
context.GetPlace());
auto norm_tmp_eigen = EigenVector<T>::Flatten(norm_tmp_tensor);
norm_tmp_eigen.device(*place) =
(x_square_batch_eigen.sum(dim) + epsilon).sqrt();
Eigen::array<int, 2> broadcast_dim_col;
broadcast_dim_col[1] = 1;
broadcast_dim_col[0] = channels;
in_g_batch_eigen.device(*place) =
in_x_batch_eigen * tmp_eigen.broadcast(broadcast_dim_col);
in_g_batch_eigen.device(*place) =
in_g_batch_eigen /
(norm_tmp_eigen * norm_tmp_eigen).broadcast(broadcast_dim_col);
in_g_batch_eigen.device(*place) = outg_batch_eigen - in_g_batch_eigen;
// outg_batch_eigen + (in_g_batch_eigen * -1);
in_g_batch_eigen.device(*place) =
in_g_batch_eigen / norm_tmp_eigen.broadcast(broadcast_dim_col);
Eigen::array<int, 2> broadcast_dim_row;
broadcast_dim_row[1] = fea_len;
broadcast_dim_row[0] = 1;
in_g_batch_eigen.device(*place) =
in_g_batch_eigen * (scale_eigen.broadcast(broadcast_dim_row));
}
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
def norm(input, scale, epsilon):
s0, s1, s2, s3 = input.shape
x_square = input * input
for i in xrange(s0):
input_batch = input[i:i + 1, :, :, :]
input_batch = input_batch.reshape(s1, s2 * s3)
x_square_batch = x_square[i:i + 1, :, :, :]
x_square_batch = x_square_batch.reshape(s1, s2 * s3)
square_colsum = x_square_batch.sum(axis=0) + epsilon
tmp = pow(square_colsum, 0.5)
tmp = np.reciprocal(tmp)
tmp_tile = np.tile(tmp, s1)
tmp_tile = tmp_tile.reshape(s1, s2 * s3)
scale_tile = np.tile(scale, (1, s2 * s3))
scale_tile = scale_tile.reshape(s1, s2 * s3)
out_batch = input_batch * tmp_tile * scale_tile
out_batch = out_batch.reshape(1, s1, s2, s3)
if i == 0:
out = out_batch
else:
out = np.concatenate((out, out_batch), 0)
out.reshape(s0, s1, s2, s3)
return out
class TestNormOp(OpTest):
def setUp(self):
self.op_type = "norm"
self.init_test_case()
input = np.random.random(self.shape).astype("float32")
scale = np.array([10, 10, 10])
self.inputs = {
'X': input.astype('float32'),
'Scale': scale.astype('float32')
}
self.attrs = {'epsilon': self.epsilon}
output = norm(input, scale, self.epsilon)
self.outputs = {'Out': output.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
def init_test_case(self):
self.shape = [1, 3, 2, 2]
self.epsilon = 1e-6
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册