diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 6a37b5ca433a3baa1388dd4f720d782ca53e4e99..1241ae784e373cc9485537a6887cf19ce9667bb9 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -173,6 +173,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) paddle.fluid.layers.read_file ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index df3e3fcd9c75f03f4d9b0a7c12788f06bfdefd7f..c97225669a572cd62250729a9e4e9f7b674816e4 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -305,6 +305,7 @@ if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) op_library(layer_norm_op DEPS cub) op_library(reduce_mean_op DEPS cub) + op_library(affine_channel_op DEPS cub) else() op_library(conv_op DEPS vol2col im2col) endif() diff --git a/paddle/fluid/operators/affine_channel_op.cc b/paddle/fluid/operators/affine_channel_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..8944a749674c3ba6c83526e4d66f449075716f43 --- /dev/null +++ b/paddle/fluid/operators/affine_channel_op.cc @@ -0,0 +1,255 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class AffineChannelOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "(Tensor) Feature map input can be a 4D tensor with order NCHW " + "or NHWC. It also can be a 2D tensor and C is the second " + "dimension."); + AddInput("Scale", + "(Tensor) 1D input of shape (C), the c-th element " + "is the scale factor of the affine transformation " + "for the c-th channel of the input."); + AddInput("Bias", + "(Tensor) 1D input of shape (C), the c-th element " + "is the bias of the affine transformation for the " + "c-th channel of the input."); + AddAttr( + "data_layout", + "(string, default NCHW) Only used in " + "An optional string from: \"NHWC\", \"NCHW\". " + "Defaults to \"NHWC\". Specify the data format of the output data, " + "the input will be transformed automatically. ") + .SetDefault("AnyLayout"); + AddOutput("Out", "(Tensor) A tensor of the same shape and order with X."); + AddComment(R"DOC( + +Applies a separate affine transformation to each channel of the input. Useful +for replacing spatial batch norm with its equivalent fixed transformation. +The input also can be 2D tensor and applies a affine transformation in second +dimension. + +$$Out = Scale*X + Bias$$ + +)DOC"); + } +}; + +class AffineChannelOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of AffineChannelOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Scale"), + "Input(Scale) of AffineChannelOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Bias"), + "Input(Bias) of AffineChannelOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of AffineChannelOp should not be null."); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", "Out"); + } +}; + +class AffineChannelOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + if (ctx->HasOutput(framework::GradVarName("X"))) { + PADDLE_ENFORCE(ctx->HasInput("Scale"), + "Input(Scale) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), + ctx->GetInputDim(framework::GradVarName("Out"))); + } + if (ctx->HasOutput(framework::GradVarName("Scale"))) { + // Scale@GRAD and Bias@GRAD must exist at the same time. + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Bias")), + "Output(Scale@GRAD) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + ctx->SetOutputDim(framework::GradVarName("Scale"), + ctx->GetInputDim("Scale")); + ctx->SetOutputDim(framework::GradVarName("Bias"), + ctx->GetInputDim("Scale")); + } + } +}; + +template +using EigenArrayMap = + Eigen::Map>; +template +using ConstEigenArrayMap = + Eigen::Map>; +template +using EigenVectorArrayMap = Eigen::Map>; +template +using ConstEigenVectorArrayMap = + Eigen::Map>; + +template +class AffineChannelKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + + auto* y = ctx.Output("Out"); + y->mutable_data(ctx.GetPlace()); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr("data_layout")); + + auto dims = x->dims(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = x->numel() / N / C; + + auto* scale_d = scale->data(); + auto* bias_d = bias->data(); + ConstEigenVectorArrayMap a_e(scale_d, C); + ConstEigenVectorArrayMap b_e(bias_d, C); + + auto* x_d = x->data(); + auto* y_d = y->data(); + if (layout == framework::DataLayout::kNCHW) { + int stride = C * HxW; + for (int i = 0; i < N; i++) { + ConstEigenArrayMap x_e(x_d, HxW, C); + EigenArrayMap y_e(y_d, HxW, C); + y_e = (x_e.rowwise() * a_e.transpose()).rowwise() + b_e.transpose(); + x_d += stride; + y_d += stride; + } + } else { + int num = N * HxW; + ConstEigenArrayMap x_e(x_d, C, num); + EigenArrayMap y_e(y_d, C, num); + y_e = (x_e.colwise() * a_e).colwise() + b_e; + } + } +}; + +template +class AffineChannelGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* dy = ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dscale = + ctx.Output(framework::GradVarName("Scale")); + auto* dbias = ctx.Output(framework::GradVarName("Bias")); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr("data_layout")); + + auto dims = x->dims(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = x->numel() / N / C; + + auto* x_d = x->data(); + auto* dy_d = dy->data(); + auto* scale_d = scale->data(); + ConstEigenVectorArrayMap scale_e(scale_d, C); + + T* dx_d = dx ? dx->mutable_data(ctx.GetPlace()) : nullptr; + T* dscale_d = dscale ? dscale->mutable_data(ctx.GetPlace()) : nullptr; + T* dbias_d = dbias ? dbias->mutable_data(ctx.GetPlace()) : nullptr; + EigenVectorArrayMap dscale_e(dscale_d, C); + EigenVectorArrayMap dbias_e(dbias_d, C); + + if (layout == framework::DataLayout::kNCHW) { + // compute dx + int stride = C * HxW; + if (dx) { + for (int i = 0; i < N; i++) { + ConstEigenArrayMap dy_e(dy_d, HxW, C); + EigenArrayMap dx_e(dx_d, HxW, C); + dx_e = dy_e.rowwise() * scale_e.transpose(); + dy_d += stride; + dx_d += stride; + } + } + // compute dscale and dbias + if (dscale && dbias) { + dy_d = dy->data(); + for (int i = 0; i < N; i++) { + ConstEigenArrayMap x_e(x_d, HxW, C); + ConstEigenArrayMap dy_e(dy_d, HxW, C); + if (i == 0) { + dscale_e = (x_e * dy_e).colwise().sum(); + } else { + dscale_e += (x_e * dy_e).colwise().sum(); + } + if (i == 0) { + dbias_e = dy_e.colwise().sum(); + } else { + dbias_e += dy_e.colwise().sum(); + } + x_d += stride; + dy_d += stride; + } + } + } else { + int num = N * HxW; + ConstEigenArrayMap dy_e(dy_d, C, num); + // compute dx + if (dx) { + EigenArrayMap dx_e(dx_d, C, num); + dx_e = dy_e.colwise() * scale_e; + } + // compute dscale and dbias + if (dscale && dbias) { + ConstEigenArrayMap x_e(x_d, C, num); + dscale_e = (x_e * dy_e).rowwise().sum(); + dbias_e = dy_e.rowwise().sum(); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR(affine_channel, ops::AffineChannelOp, + ops::AffineChannelOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(affine_channel_grad, ops::AffineChannelOpGrad); + +REGISTER_OP_CPU_KERNEL(affine_channel, ops::AffineChannelKernel, + ops::AffineChannelKernel); +REGISTER_OP_CPU_KERNEL(affine_channel_grad, + ops::AffineChannelGradKernel, + ops::AffineChannelGradKernel); diff --git a/paddle/fluid/operators/affine_channel_op.cu b/paddle/fluid/operators/affine_channel_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..2bebdb345ab324eb0a2dafd54c74833dd21bdb6d --- /dev/null +++ b/paddle/fluid/operators/affine_channel_op.cu @@ -0,0 +1,187 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +Indicesou may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "cub/cub.cuh" +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +template +__global__ void KeAffineChannelCUDA(const T* x, const T* scale, const T* bias, + const int C, const int HxW, const int num, + T* y) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + for (int i = gid; i < num; i += stride) { + const int c = layout == framework::DataLayout::kNCHW ? i / HxW % C : i % C; + if (HasBias) { + y[i] = scale[c] * x[i] + bias[c]; + } else { + y[i] = scale[c] * x[i]; + } + } +} + +template +class AffineChannelCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + + auto* y = ctx.Output("Out"); + y->mutable_data(ctx.GetPlace()); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr("data_layout")); + auto& dev_ctx = ctx.template device_context(); + + auto dims = x->dims(); + const int num = x->numel(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = num / N / C; + + const T* x_d = x->data(); + const T* scale_d = scale->data(); + const T* bias_d = bias->data(); + T* y_d = y->data(); + + int block = 1024; + int grid = (num + block - 1) / block; + if (layout == framework::DataLayout::kNCHW) { + KeAffineChannelCUDA<<>>( + x_d, scale_d, bias_d, C, HxW, num, y_d); + } else { + KeAffineChannelCUDA<<>>( + x_d, scale_d, bias_d, C, HxW, num, y_d); + } + } +}; + +template +__global__ void AffineChannelScaleBiasGradientCUDAKernel( + const T* dy, const T* x, const int N, const int C, const int HxW, T* dscale, + T* dbias) { + const int outer_size = C; + const int inner_size = N * HxW; + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage ds_storage; + __shared__ typename BlockReduce::TempStorage db_storage; + + for (int i = blockIdx.x; i < outer_size; i += gridDim.x) { + T ds_sum = 0; + T db_sum = 0; + for (int j = threadIdx.x; j < inner_size; j += blockDim.x) { + const int index = layout == framework::DataLayout::kNCHW + ? (j / HxW * C + i) * HxW + j % HxW + : j * outer_size + i; + ds_sum += dy[index] * x[index]; + db_sum += dy[index]; + } + ds_sum = BlockReduce(ds_storage).Reduce(ds_sum, cub::Sum()); + db_sum = BlockReduce(db_storage).Reduce(db_sum, cub::Sum()); + if (threadIdx.x == 0) { + dscale[i] = ds_sum; + dbias[i] = db_sum; + } + __syncthreads(); + } +} + +template +class AffineChannelGradCUDAKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* x = ctx.Input("X"); + auto* scale = ctx.Input("Scale"); + auto* bias = ctx.Input("Bias"); + auto* dy = ctx.Input(framework::GradVarName("Out")); + + auto* dx = ctx.Output(framework::GradVarName("X")); + auto* dscale = + ctx.Output(framework::GradVarName("Scale")); + auto* dbias = ctx.Output(framework::GradVarName("Bias")); + + const framework::DataLayout layout = + framework::StringToDataLayout(ctx.Attr("data_layout")); + auto& dev_ctx = ctx.template device_context(); + + auto dims = x->dims(); + const int num = x->numel(); + int N = dims[0]; + int C = layout == framework::DataLayout::kNCHW ? dims[1] + : dims[dims.size() - 1]; + int HxW = num / N / C; + + const T* x_d = x->data(); + const T* dy_d = dy->data(); + const T* s_d = scale->data(); + + T* dx_d = dx ? dx->mutable_data(ctx.GetPlace()) : nullptr; + T* ds_d = dscale ? dscale->mutable_data(ctx.GetPlace()) : nullptr; + T* db_d = dbias ? dbias->mutable_data(ctx.GetPlace()) : nullptr; + + const int block = 1024; + int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); + const int max_blocks = std::max(max_threads / block, 1); + int grid1 = (num + block - 1) / block; + int grid2 = std::min(C, max_blocks); + if (layout == framework::DataLayout::kNCHW) { + if (dx) { + KeAffineChannelCUDA<<>>( + dy_d, s_d, nullptr, C, HxW, num, dx_d); + } + if (dscale && dbias) { + AffineChannelScaleBiasGradientCUDAKernel< + T, block, framework::DataLayout::kNCHW><<>>( + dy_d, x_d, N, C, HxW, ds_d, db_d); + } + } else { + if (dx) { + KeAffineChannelCUDA<<>>( + dy_d, s_d, nullptr, C, HxW, num, dx_d); + } + if (dscale && dbias) { + AffineChannelScaleBiasGradientCUDAKernel< + T, block, framework::DataLayout::kNHWC><<>>( + dy_d, x_d, N, C, HxW, ds_d, db_d); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CUDA = paddle::platform::CUDADeviceContext; + +REGISTER_OP_CUDA_KERNEL(affine_channel, + ops::AffineChannelCUDAKernel, + ops::AffineChannelCUDAKernel); +REGISTER_OP_CUDA_KERNEL(affine_channel_grad, + ops::AffineChannelGradCUDAKernel, + ops::AffineChannelGradCUDAKernel); diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 224781e6596e2b2bc6d3084f4780a33ed2903118..249e0c5c3990cd7b9cc33b24651c64942a1f7521 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -153,6 +153,7 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', + 'affine_channel', ] @@ -7268,3 +7269,44 @@ def maxout(x, groups, name=None): attrs={"groups": groups}, outputs={"Out": out}) return out + + +def affine_channel(x, scale=None, bias=None, data_layout='NCHW', name=None): + """ + Applies a separate affine transformation to each channel of the input. + Useful for replacing spatial batch norm with its equivalent fixed + transformation. The input also can be 2D tensor and applies a affine + transformation in second dimension. + + Args: + x (Variable): Feature map input can be a 4D tensor with order NCHW + or NHWC. It also can be a 2D tensor and the affine transformation + is applied in the second dimension. + scale (Variable): 1D input of shape (C), the c-th element is the scale + factor of the affine transformation for the c-th channel of + the input. + bias (Variable): 1D input of shape (C), the c-th element is the bias + of the affine transformation for the c-th channel of the input. + data_layout (string, default NCHW): NCHW or NHWC. If input is 2D + tensor, you can ignore data_layout. + name (str, default None): The name of this layer. + + Returns: + out (Variable): A tensor of the same shape and data layout with x. + """ + helper = LayerHelper("affine_channel", **locals()) + + if name is None: + out = helper.create_tmp_variable(dtype=x.dtype) + else: + out = helper.create_variable( + name=name, dtype=x.dtype, persistable=False) + + helper.append_op( + type="affine_channel", + inputs={"X": x, + 'Scale': scale, + 'Bias': bias}, + attrs={"data_layout": data_layout}, + outputs={"Out": out}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_affine_channel_op.py b/python/paddle/fluid/tests/unittests/test_affine_channel_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2c9a063e6ee75371e0d05e1ff6964753017881a1 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_affine_channel_op.py @@ -0,0 +1,106 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest +import paddle.fluid.core as core + + +def affine_channel(x, scale, bias, layout): + C = x.shape[1] if layout == 'NCHW' else x.shape[-1] + if len(x.shape) == 4: + new_shape = (1, C, 1, 1) if layout == 'NCHW' else (1, 1, 1, C) + else: + new_shape = (1, C) + scale = scale.reshape(new_shape) + bias = bias.reshape(new_shape) + return x * scale + bias + + +class TestAffineChannelOp(OpTest): + def setUp(self): + self.op_type = "affine_channel" + self.init_test_case() + + x = np.random.random(self.shape).astype("float32") + scale = np.random.random(self.C).astype("float32") + bias = np.random.random(self.C).astype("float32") + + y = affine_channel(x, scale, bias, self.layout) + + self.inputs = {'X': x, 'Scale': scale, 'Bias': bias} + self.attrs = {'data_layout': self.layout} + self.outputs = {'Out': y} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X', 'Scale', 'Bias'], 'Out') + + def test_check_grad_stopgrad_dx(self): + self.check_grad(['Scale', 'Bias'], 'Out', no_grad_set=set('X')) + + def test_check_grad_stopgrad_dscale_dbias(self): + self.check_grad(['X'], 'Out', no_grad_set=set(['Scale', 'Bias'])) + + def init_test_case(self): + self.shape = [2, 32, 14, 14] + self.C = 32 + self.layout = 'NCHW' + + +class TestAffineChannelNHWC(TestAffineChannelOp): + def init_test_case(self): + self.shape = [2, 14, 14, 32] + self.C = 32 + self.layout = 'NHWC' + + +class TestAffineChannel2D(TestAffineChannelOp): + def init_test_case(self): + self.shape = [16, 64] + self.C = 64 + self.layout = 'NCHW' + + +class TestAffineChannelNCHWLargeShape(TestAffineChannelOp): + def init_test_case(self): + self.shape = [64, 128, 112, 112] + self.C = 128 + self.layout = 'NCHW' + + # since the gradient check is very slow in large shape, so skip check_grad + def test_check_grad(self): + pass + + def test_check_grad_stopgrad_dx(self): + pass + + def test_check_grad_stopgrad_dscale_dbias(self): + pass + + +class TestAffineChannelNCHWLargeShape(TestAffineChannelNCHWLargeShape): + def init_test_case(self): + self.shape = [64, 112, 112, 512] + self.C = 512 + self.layout = 'NHWC' + + +if __name__ == '__main__': + unittest.main()