diff --git a/paddle/fluid/operators/flip_op.cc b/paddle/fluid/operators/flip_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..afb8cd4af416598cd82ff691aaa012757eba470b
--- /dev/null
+++ b/paddle/fluid/operators/flip_op.cc
@@ -0,0 +1,149 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/flip_op.h"
+
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+using framework::OpKernelType;
+using framework::Tensor;
+
+class FlipOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    PADDLE_ENFORCE_EQ(
+        ctx->HasInput("X"), true,
+        platform::errors::NotFound("Input(X) of FlipOp should not be null."));
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      platform::errors::NotFound(
+                          "Output(Out) of FlipOp should not be null."));
+    auto x_dims = ctx->GetInputDim("X");
+    auto flip_dims = ctx->Attrs().Get<std::vector<int>>("dims");
+    size_t flip_dims_size = flip_dims.size();
+
+    // check if dims axis within range
+    auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
+    PADDLE_ENFORCE_LT(*min_max_d.first, x_dims.size(),
+                      platform::errors::InvalidArgument(
+                          "min(dims) should be less than the input tensor X's "
+                          "dimensions of FlipOp. But received min(dims) = %d, "
+                          "X's dimensions = %d, X's shape = [%s]",
+                          *min_max_d.first, x_dims.size(), x_dims));
+    PADDLE_ENFORCE_GE(
+        *min_max_d.first, x_dims.size() * -1,
+        platform::errors::InvalidArgument(
+            "min(dims) should be greater than or equal to the input tensor X's "
+            "dimensions of FlipOp times -1. But received min(dims) = %d, X's "
+            "dimensions = %d, X's shape = [%s]",
+            *min_max_d.first, x_dims.size() * -1, x_dims));
+    PADDLE_ENFORCE_LT(*min_max_d.second, x_dims.size(),
+                      platform::errors::InvalidArgument(
+                          "max(dims) should be less than the input tensor X's "
+                          "dimensions of FlipOp. But received max(dims) = %d, "
+                          "X's dimensions = %d, X's shape = [%s]",
+                          *min_max_d.second, x_dims.size(), x_dims));
+    PADDLE_ENFORCE_GE(
+        *min_max_d.second, x_dims.size() * -1,
+        platform::errors::InvalidArgument(
+            "max(dims) should be greater than or equal to the input tensor X's "
+            "dimensions of FlipOp times -1. "
+            "But received max(dims) = %d, X's "
+            "dimensions = %d, X's shape = [%s]",
+            *min_max_d.second, x_dims.size() * -1, x_dims));
+
+    // check duplicates in dims
+    flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
+                    flip_dims.end());
+    PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
+                      platform::errors::InvalidArgument(
+                          "dims has duplicates, original flip dims size=%d, "
+                          "but unique flip dims size=%d.)",
+                          flip_dims_size, flip_dims.size()));
+
+    VLOG(3) << "flip operator x.shape=" << x_dims;
+
+    std::vector<int64_t> output_dims(x_dims.size());
+    for (int i = 0; i < x_dims.size(); ++i) {
+      output_dims[i] = x_dims[i];
+    }
+    ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
+    ctx->ShareLoD("X", "Out");
+  }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const {
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+    int customized_type_value =
+        framework::OpKernelType::kDefaultCustomizedTypeValue;
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library, customized_type_value);
+  }
+};
+
+class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor), The input tensor of flip op.");
+    AddOutput("Out", "(Tensor), The output tensor of flip op.");
+    AddAttr<std::vector<int>>("dims", "The axes to flip on.");
+    AddComment(R"DOC(
+          Flip Operator.
+          Reverse the order of an n-D tensor along the given axes in dims.
+        )DOC");
+  }
+};
+
+class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
+      const override {
+    return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
+  }
+};
+
+template <typename T>
+class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("flip");
+    retv->SetInput("X", this->OutputGrad("Out"));
+    retv->SetOutput("Out", this->InputGrad("X"));
+    retv->SetAttrMap(this->Attrs());
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
+                  ops::FlipOpGradMaker<paddle::framework::OpDesc>,
+                  ops::FlipOpGradMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(
+    flip, ops::FlipKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
+    ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>);
diff --git a/paddle/fluid/operators/flip_op.cu b/paddle/fluid/operators/flip_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..898aeabb193e464bc8b17f9fb81708c7137de55d
--- /dev/null
+++ b/paddle/fluid/operators/flip_op.cu
@@ -0,0 +1,166 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/flip_op.h"
+
+#include <vector>
+#include "paddle/fluid/memory/malloc.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+using CUDADeviceContext = paddle::platform::CUDADeviceContext;
+
+template <typename T>
+__global__ void kernel_pointwise_flip_apply(const int N, const T* in_data,
+                                            T* out_data, int dim0, int stride0,
+                                            int dim1, int flip_dim) {
+  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
+       idx += gridDim.x * blockDim.x) {
+    int dst_offset = 0;
+    if (flip_dim == 0) {
+      // flip 1st dim
+      dst_offset = (dim0 - 1 - idx / stride0) * stride0 + idx % stride0;
+    } else {
+      // flip last dim
+      dst_offset = idx / stride0 * stride0 + (dim1 - 1 - idx % stride0);
+    }
+    out_data[dst_offset] = in_data[idx];
+  }
+}
+
+template <typename T>
+__global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
+                                 int64_t* x_shape, int64_t* x_stride,
+                                 int* flip_dims, int flip_dims_size,
+                                 int total_dims) {
+  int idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx >= N) {
+    return;
+  }
+
+  int cur_indices = idx, rem = 0, dst_offset = 0;
+  for (int i = 0; i < total_dims; ++i) {
+    int64_t temp = cur_indices;
+    cur_indices = cur_indices / x_stride[i];
+    rem = temp - cur_indices * x_stride[i];
+    // flip the indices if it is in flip_dims
+    for (int j = 0; j < flip_dims_size; ++j) {
+      if (i == flip_dims[j]) {
+        cur_indices = x_shape[i] - 1 - cur_indices;
+      }
+    }
+    dst_offset += cur_indices * x_stride[i];
+    cur_indices = rem;
+  }
+  out_data[idx] = in_data[dst_offset];
+}
+
+template <typename T>
+class FlipKernel<platform::CUDADeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
+    auto cplace = platform::CPUPlace();
+    auto& dev_ctx = ctx.template device_context<CUDADeviceContext>();
+
+    const Tensor* x = ctx.Input<Tensor>("X");
+    Tensor* out = ctx.Output<Tensor>("Out");
+    auto* in_data = x->data<T>();
+    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
+    auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
+
+    const int flip_dims_size = static_cast<int>(flip_dims.size());
+    auto x_dims = x->dims();
+    const int total_dims = x_dims.size();
+    const int N = x->numel();
+
+    int block_size = 512;
+    dim3 dim_block(block_size);
+    dim3 dim_grid((N + block_size - 1) / block_size);
+
+    for (size_t i = 0; i < flip_dims.size(); ++i) {
+      if (flip_dims[i] < 0) {
+        flip_dims[i] += total_dims;
+      }
+    }
+
+    auto x_stride = framework::stride(x_dims);
+    std::vector<int64_t> x_dims_v = framework::vectorize(x_dims);
+    std::vector<int64_t> x_stride_v = framework::vectorize(x_stride);
+
+    // wrap high-dims to 2-dims
+    if (flip_dims_size == 1 &&
+        (flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
+      int dim0 = 1, dim1 = 1;
+      int stride0 = 1;
+      if (flip_dims[0] == 0) {
+        dim0 = x_dims_v[0];
+        stride0 = x_stride_v[0];
+        for (size_t i = 1; i < total_dims; ++i) {
+          dim1 *= x_dims_v[i];
+        }
+      } else {
+        dim1 = x_dims_v[total_dims - 1];
+        for (size_t i = 0; i < total_dims - 1; ++i) {
+          dim0 *= x_dims_v[i];
+        }
+        stride0 *= x_dims_v[total_dims - 1];
+      }
+      kernel_pointwise_flip_apply<
+          T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
+          N, in_data, out_data, dim0, stride0, dim1, flip_dims[0]);
+    }
+
+    int bytes = total_dims * sizeof(int64_t);
+    auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
+    int64_t* x_strides_array_gpu =
+        reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
+    memory::Copy(gplace, x_strides_array_gpu, cplace, x_stride_v.data(), bytes,
+                 dev_ctx.stream());
+
+    auto x_shape_array_tmp = memory::Alloc(dev_ctx, bytes);
+    int64_t* x_shape_array_gpu =
+        reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
+    memory::Copy(gplace, x_shape_array_gpu, cplace, x_dims_v.data(), bytes,
+                 dev_ctx.stream());
+
+    bytes = flip_dims_size * sizeof(int);
+    auto flip_dims_array_tmp = memory::Alloc(dev_ctx, bytes);
+    int* flip_dims_array_gpu =
+        reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
+    memory::Copy(gplace, flip_dims_array_gpu, cplace, flip_dims.data(), bytes,
+                 dev_ctx.stream());
+
+    flip_cuda_kernel<
+        T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
+        N, in_data, out_data, x_shape_array_gpu, x_strides_array_gpu,
+        flip_dims_array_gpu, flip_dims_size, total_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_CUDA_KERNEL(
+    flip, ops::FlipKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
+    ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>);
diff --git a/paddle/fluid/operators/flip_op.h b/paddle/fluid/operators/flip_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..73d73f5d0f2e06dc4049f4b10ea7a12d63193c40
--- /dev/null
+++ b/paddle/fluid/operators/flip_op.h
@@ -0,0 +1,83 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <bitset>
+#include <cmath>
+#include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+constexpr size_t dim_bitset_size = 64;
+
+template <typename DeviceContext, typename T>
+class FlipKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override;
+};
+
+template <typename T>
+class FlipKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const Tensor* x = ctx.Input<Tensor>("X");
+    Tensor* out = ctx.Output<Tensor>("Out");
+    auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
+
+    auto x_dims = x->dims();
+    const int total_dims = x_dims.size();
+    std::bitset<dim_bitset_size> dim_bitset;
+    for (size_t i = 0; i < flip_dims.size(); ++i) {
+      int dim = flip_dims[i];
+      if (flip_dims[i] < 0) {
+        dim += total_dims;
+      }
+      dim_bitset[dim] = true;
+    }
+    auto x_strides = framework::stride(x_dims);
+    auto numel = x->numel();
+    const T* x_data = x->data<T>();
+    T* out_data = out->mutable_data<T>(ctx.GetPlace());
+#ifdef PADDLE_WITH_MKLML
+#pragma omp parallel for
+#endif
+    for (int64_t i = 0; i < numel; ++i) {
+      int64_t cur_indices = i;
+      int64_t rem = 0;
+      int64_t dst_offset = 0;
+
+      for (int d = 0; d < total_dims; ++d) {
+        int64_t temp = cur_indices;
+        cur_indices = cur_indices / x_strides[d];
+        rem = temp - cur_indices * x_strides[d];
+        dst_offset += dim_bitset[d]
+                          ?
+                              (x_dims[d] - 1 - cur_indices) * x_strides[d]
+                          : cur_indices * x_strides[d];
+        cur_indices = rem;
+      }
+      out_data[i] = x_data[dst_offset];
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index d6f18eebd7b2e0241bccdb913943bd248d2dd878..8d2ed9194d5b7b9251e06d77d8efe940fee8d1eb 100644
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -178,7 +178,7 @@ from .tensor.logic import equal #DEFINE_ALIAS
 # from .tensor.manipulation import unique_with_counts #DEFINE_ALIAS
 # from .tensor.manipulation import unsqueeze #DEFINE_ALIAS
 # from .tensor.manipulation import unstack #DEFINE_ALIAS
-# from .tensor.manipulation import flip #DEFINE_ALIAS
+from .tensor.manipulation import flip #DEFINE_ALIAS
 # from .tensor.manipulation import unbind #DEFINE_ALIAS
 # from .tensor.manipulation import roll #DEFINE_ALIAS
 # from .tensor.search import argmax #DEFINE_ALIAS
diff --git a/python/paddle/fluid/tests/unittests/test_flip.py b/python/paddle/fluid/tests/unittests/test_flip.py
new file mode 100644
index 0000000000000000000000000000000000000000..77e416e5e6a73c63e775e9a1bdb8b56776bee9d2
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_flip.py
@@ -0,0 +1,115 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.core as core
+from paddle.fluid import Program, program_guard
+from op_test import OpTest
+
+
+class TestFlipOp_API(unittest.TestCase):
+    """Test flip api."""
+
+    def test_static_graph(self):
+        startup_program = fluid.Program()
+        train_program = fluid.Program()
+        with fluid.program_guard(train_program, startup_program):
+            dims = [0]
+            input = fluid.data(name='input', dtype='float32', shape=[2, 3])
+            output = paddle.flip(input, dims)
+            place = fluid.CPUPlace()
+            if fluid.core.is_compiled_with_cuda():
+                place = fluid.CUDAPlace(0)
+            exe = fluid.Executor(place)
+            exe.run(startup_program)
+            img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
+            res = exe.run(train_program,
+                          feed={'input': img},
+                          fetch_list=[output])
+            out_np = np.array(res[0])
+            out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
+            self.assertTrue(
+                (out_np == out_ref).all(),
+                msg='flip output is wrong, out =' + str(out_np))
+
+    def test_dygraph(self):
+        img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
+        with fluid.dygraph.guard():
+            inputs = fluid.dygraph.to_variable(img)
+            ret = paddle.flip(inputs, [0])
+            out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
+            self.assertTrue(
+                (ret.numpy() == out_ref).all(),
+                msg='flip output is wrong, out =' + str(ret.numpy()))
+
+
+class TestFlipOp(OpTest):
+    def setUp(self):
+        self.op_type = 'flip'
+        self.init_test_case()
+        self.inputs = {'X': np.random.random(self.in_shape).astype('float64')}
+        self.init_attrs()
+        self.outputs = {'Out': self.calc_ref_res()}
+
+    def init_attrs(self):
+        self.attrs = {"dims": self.dims}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+    def init_test_case(self):
+        self.in_shape = (6, 4, 2, 3)
+        self.dims = [0, 1]
+
+    def calc_ref_res(self):
+        res = self.inputs['X']
+        for axis in self.dims:
+            res = np.flip(res, axis)
+        return res
+
+
+class TestFlipOpAxis1(TestFlipOp):
+    def init_test_case(self):
+        self.in_shape = (2, 4, 4)
+        self.dims = [0]
+
+
+class TestFlipOpAxis2(TestFlipOp):
+    def init_test_case(self):
+        self.in_shape = (4, 4, 6, 3)
+        self.dims = [0, 2]
+
+
+class TestFlipOpAxis3(TestFlipOp):
+    def init_test_case(self):
+        self.in_shape = (4, 3, 1)
+        self.dims = [0, 1, 2]
+
+
+class TestFlipOpAxis4(TestFlipOp):
+    def init_test_case(self):
+        self.in_shape = (6, 4, 2, 2)
+        self.dims = [0, 1, 2, 3]
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 05598cc8600a9e30717e3febba2b19eae1fb2b41..c578ee5386dc152b893c76503ff98e9c597f421e 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -155,7 +155,7 @@ from .logic import equal #DEFINE_ALIAS
 # from .manipulation import unique_with_counts #DEFINE_ALIAS
 # from .manipulation import unsqueeze #DEFINE_ALIAS
 # from .manipulation import unstack #DEFINE_ALIAS
-# from .manipulation import flip #DEFINE_ALIAS
+from .manipulation import flip #DEFINE_ALIAS
 # from .manipulation import unbind #DEFINE_ALIAS
 # from .manipulation import roll #DEFINE_ALIAS
 # from .search import argmax #DEFINE_ALIAS
diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py
index e06af8be2d84fae054e4728925a61ccf476e8684..8961d009bfa127b2d20ea87fbc363cf60d9d4474 100644
--- a/python/paddle/tensor/manipulation.py
+++ b/python/paddle/tensor/manipulation.py
@@ -12,30 +12,89 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import print_function
+
+from ..fluid.layers import core
+from ..fluid.layer_helper import LayerHelper
+from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
+from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
+
 # TODO: define functions to manipulate a tensor
-# __all__ = ['cast',
-#            'concat',
-#            'expand',
-#            'expand_as',
-#            'flatten',
-#            'gather',
-#            'gather_nd',
-#            'reshape',
-#            'reverse',
-#            'scatter',
-#            'scatter_nd_add',
-#            'scatter_nd',
-#            'shard_index',
-#            'slice',
-#            'split',
-#            'squeeze',
-#            'stack',
-#            'strided_slice',
-#            'transpose',
-#            'unique',
-#            'unique_with_counts',
-#            'unsqueeze',
-#            'unstack',
-#            'flip',
-#            'unbind',
-#            'roll']
+__all__ = [
+    # 'cast',
+    # 'concat',
+    # 'expand',
+    # 'expand_as',
+    # 'flatten',
+    # 'gather',
+    # 'gather_nd',
+    # 'reshape',
+    # 'reverse',
+    # 'scatter',
+    # 'scatter_nd_add',
+    # 'scatter_nd',
+    # 'shard_index',
+    # 'slice',
+    # 'split',
+    # 'squeeze',
+    # 'stack',
+    # 'strided_slice',
+    # 'transpose',
+    # 'unique',
+    # 'unique_with_counts',
+    # 'unsqueeze',
+    # 'unstack',
+    'flip',
+    # 'unbind',
+    # 'roll'
+]
+
+
+def flip(input, dims, name=None):
+    """
+
+    Reverse the order of an n-D tensor along the given axes in dims.
+
+    Args:
+        input (Variable): A Tensor (or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the
+            input Tensor should be float32, float64, int32, int64, bool.
+        dims (list): The axes to flip on.
+        name (str, optional): The default value is None. Normally there is no need for the user to set this
+            property. For more information, please refer to :ref:`api_guide_Name` .
+
+    Returns:
+        Variable: Tensor or LoDTensor calculated by the flip layer. The data type is the same as the input.
+
+    Examples:
+        .. code-block:: python
+
+          import paddle
+          import paddle.fluid as fluid
+          import numpy as np
+          input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32')
+          output = paddle.flip(input, dims=[0, 1])
+          exe = fluid.Executor(fluid.CPUPlace())
+          exe.run(fluid.default_startup_program())
+          img = np.arange(12).reshape((3, 2, 2)).astype(np.float32)
+          res = exe.run(fluid.default_main_program(), feed={'x': img}, fetch_list=[output])
+          print(res)  # [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]]
+    """
+    helper = LayerHelper("flip", **locals())
+    check_type(input, 'X', (Variable), 'flip')
+    dtype = helper.input_dtype()
+    check_dtype(dtype, 'X',
+                ['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
+                'flip')
+    check_type(dims, 'dims', (list, tuple), 'flip')
+    assert len(dims) > 0, 'len(dims) must be greater than 0.'
+    if name is None:
+        out = helper.create_variable_for_type_inference(dtype)
+    else:
+        out = helper.create_variable(name=name, dtype=dtype, persistable=False)
+
+    helper.append_op(
+        type="flip",
+        inputs={"X": input},
+        outputs={"Out": out},
+        attrs={"dims": dims})
+    return out
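
Note (not part of the patch): the CPU loop in flip_op.h and flip_cuda_kernel above share the same index arithmetic -- decompose the linear output index by row-major strides, mirror the coordinate on every flipped axis, and reassemble the source offset. A minimal NumPy sketch of that mapping, cross-checked against numpy.flip; the helper name flip_reference is illustrative only:

import numpy as np

def flip_reference(x, dims):
    # Row-major strides in elements, analogous to framework::stride(x_dims).
    shape = x.shape
    strides = [int(np.prod(shape[d + 1:], dtype=np.int64)) for d in range(len(shape))]
    src = x.reshape(-1)
    out = np.empty_like(src)
    for i in range(src.size):          # i is the linear index of the output element
        rem, dst_offset = i, 0
        for d in range(len(shape)):
            coord, rem = divmod(rem, strides[d])   # coordinate along axis d, remainder
            if d in dims:                          # mirror the coordinate on flipped axes
                coord = shape[d] - 1 - coord
            dst_offset += coord * strides[d]
        out[i] = src[dst_offset]
    return out.reshape(shape)

x = np.arange(12).reshape(3, 2, 2)
assert (flip_reference(x, [0, 1]) == np.flip(np.flip(x, 0), 1)).all()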
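
A quick end-to-end check of the new Python API, mirroring test_dygraph in test_flip.py above (assumes a Paddle build that already contains this patch):

import numpy as np
import paddle
import paddle.fluid as fluid

img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
with fluid.dygraph.guard():
    ret = paddle.flip(fluid.dygraph.to_variable(img), [0])
    print(ret.numpy())  # expected: [[4. 5. 6.] [1. 2. 3.]]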