未验证 提交 9676ac1c 编写于 作者: W Wilber 提交者: GitHub

Add flip op. (#23255)

* add flip op
上级 d8a21ef6
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flip_op.h"
#include <string>
#include <unordered_map>
#include <vector>
namespace paddle {
namespace operators {
using framework::OpKernelType;
using framework::Tensor;
class FlipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) of FlipOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of FlipOp should not be null."));
auto x_dims = ctx->GetInputDim("X");
auto flip_dims = ctx->Attrs().Get<std::vector<int>>("dims");
size_t flip_dims_size = flip_dims.size();
// check if dims axis within range
auto min_max_d = std::minmax_element(flip_dims.begin(), flip_dims.end());
PADDLE_ENFORCE_LT(*min_max_d.first, x_dims.size(),
platform::errors::InvalidArgument(
"min(dims) should be less than the input tensor X's "
"dimensions of FlipOp. But received min(dims) = %d, "
"X's dimensions = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
*min_max_d.first, x_dims.size() * -1,
platform::errors::InvalidArgument(
"min(dims) should be greater than or equal to the input tensor X's "
"dimensions of FlipOp times -1. But received min(dims) = %d, X's "
"dimensions = %d, X's shape = [%s]",
*min_max_d.first, x_dims.size() * -1, x_dims));
PADDLE_ENFORCE_LT(*min_max_d.second, x_dims.size(),
platform::errors::InvalidArgument(
"max(dims) should be less than the input tensor X's "
"dimensions of FlipOp. But received max(dims) = %d, "
"X's dimensions = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size(), x_dims));
PADDLE_ENFORCE_GE(
*min_max_d.second, x_dims.size() * -1,
platform::errors::InvalidArgument(
"max(dims) should be greater than or equal to the input tensor X's "
"dimensions of FlipOp times -1. But received max(dims) = %d, X's "
"dimensions = %d, X's shape = [%s]",
*min_max_d.second, x_dims.size() * -1, x_dims));
// check duplicates in dims
flip_dims.erase(std::unique(flip_dims.begin(), flip_dims.end()),
flip_dims.end());
PADDLE_ENFORCE_EQ(flip_dims.size(), flip_dims_size,
platform::errors::InvalidArgument(
"dims has duplicates, original flip dims size=%d, "
"but unique flip dims size=%d.)",
flip_dims_size, flip_dims.size()));
VLOG(3) << "flip operator x.shape=" << x_dims;
std::vector<int64_t> output_dims(x_dims.size());
for (int i = 0; i < x_dims.size(); ++i) {
output_dims[i] = x_dims[i];
}
ctx->SetOutputDim("Out", framework::make_ddim(output_dims));
ctx->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library, customized_type_value);
}
};
class FlipOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of flip op.");
AddOutput("Out", "(Tensor), The output tensor of flip op.");
AddAttr<std::vector<int>>("dims", "The axes to flip on.");
AddComment(R"DOC(
Flip Operator.
Reverse the order of a n-D tensor along given axis in dims.
)DOC");
}
};
class FlipOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
protected:
std::unordered_map<std::string, std::string> GetInputOutputWithSameType()
const override {
return std::unordered_map<std::string, std::string>{{"X", /*->*/ "Out"}};
}
};
template <typename T>
class FlipOpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType("flip");
retv->SetInput("X", this->OutputGrad("Out"));
retv->SetOutput("Out", this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
ops::FlipOpGradMaker<paddle::framework::OpDesc>,
ops::FlipOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
flip, ops::FlipKernel<paddle::platform::CPUDeviceContext, float>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/flip_op.h"
#include <vector>
#include "paddle/fluid/memory/malloc.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using CUDADeviceContext = paddle::platform::CUDADeviceContext;
template <typename T>
__global__ void kernel_pointwise_flip_apply(const int N, const T* in_data,
T* out_data, int dim0, int stride0,
int dim1, int flip_dim) {
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < N;
idx += gridDim.x * blockDim.x) {
int dst_offset = 0;
if (flip_dim == 0) {
// flip 1st dim
dst_offset = (dim0 - 1 - idx / stride0) * stride0 + idx % stride0;
} else {
// flip last dim
dst_offset = idx / stride0 * stride0 + (dim1 - 1 - idx % stride0);
}
out_data[dst_offset] = in_data[idx];
}
}
template <typename T>
__global__ void flip_cuda_kernel(const int N, const T* in_data, T* out_data,
int64_t* x_shape, int64_t* x_stride,
int* flip_dims, int flip_dims_size,
int total_dims) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
int cur_indices = idx, rem = 0, dst_offset = 0;
for (int i = 0; i < total_dims; ++i) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_stride[i];
rem = temp - cur_indices * x_stride[i];
// flip the indices if it is in flip_dims
for (int j = 0; j < flip_dims_size; ++j) {
if (i == flip_dims[j]) {
cur_indices = x_shape[i] - 1 - cur_indices;
}
}
dst_offset += cur_indices * x_stride[i];
cur_indices = rem;
}
out_data[idx] = in_data[dst_offset];
}
template <typename T>
class FlipKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
auto cplace = platform::CPUPlace();
auto& dev_ctx = ctx.template device_context<CUDADeviceContext>();
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
auto* in_data = x->data<T>();
auto* out_data = out->mutable_data<T>(ctx.GetPlace());
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
const int flip_dims_size = static_cast<int>(flip_dims.size());
auto x_dims = x->dims();
const int total_dims = x_dims.size();
const int N = x->numel();
int block_size = 512;
dim3 dim_block(block_size);
dim3 dim_grid((N + block_size - 1) / block_size);
for (size_t i = 0; i < flip_dims.size(); ++i) {
if (flip_dims[i] < 0) {
flip_dims[i] += total_dims;
}
}
auto x_stride = framework::stride(x_dims);
std::vector<int64_t> x_dims_v = framework::vectorize(x_dims);
std::vector<int64_t> x_stride_v = framework::vectorize(x_stride);
// wrap high-dims to 2-dims
if (flip_dims_size == 1 &&
(flip_dims[0] == 0 || flip_dims[0] == total_dims - 1)) {
int dim0 = 1, dim1 = 1;
int stride0 = 1;
if (flip_dims[0] == 0) {
dim0 = x_dims_v[0];
stride0 = x_stride_v[0];
for (size_t i = 1; i < total_dims; ++i) {
dim1 *= x_dims_v[i];
}
} else {
dim1 = x_dims_v[total_dims - 1];
for (size_t i = 0; i < total_dims - 1; ++i) {
dim0 *= x_dims_v[i];
}
stride0 *= x_dims_v[total_dims - 1];
}
kernel_pointwise_flip_apply<
T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
N, in_data, out_data, dim0, stride0, dim1, flip_dims[0]);
}
int bytes = total_dims * sizeof(int64_t);
auto x_strides_array_tmp = memory::Alloc(dev_ctx, bytes);
int64_t* x_strides_array_gpu =
reinterpret_cast<int64_t*>(x_strides_array_tmp->ptr());
memory::Copy(gplace, x_strides_array_gpu, cplace, x_stride_v.data(), bytes,
dev_ctx.stream());
auto x_shape_array_tmp = memory::Alloc(dev_ctx, bytes);
int64_t* x_shape_array_gpu =
reinterpret_cast<int64_t*>(x_shape_array_tmp->ptr());
memory::Copy(gplace, x_shape_array_gpu, cplace, x_dims_v.data(), bytes,
dev_ctx.stream());
bytes = flip_dims_size * sizeof(int);
auto flip_dims_array_tmp = memory::Alloc(dev_ctx, bytes);
int* flip_dims_array_gpu =
reinterpret_cast<int*>(flip_dims_array_tmp->ptr());
memory::Copy(gplace, flip_dims_array_gpu, cplace, flip_dims.data(), bytes,
dev_ctx.stream());
flip_cuda_kernel<
T><<<dim_grid, dim_block, 0, ctx.cuda_device_context().stream()>>>(
N, in_data, out_data, x_shape_array_gpu, x_strides_array_gpu,
flip_dims_array_gpu, flip_dims_size, total_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
flip, ops::FlipKernel<paddle::platform::CUDADeviceContext, float>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, double>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CUDADeviceContext, bool>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <bitset>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
constexpr size_t dim_bitset_size = 64;
template <typename DeviceContext, typename T>
class FlipKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override;
};
template <typename T>
class FlipKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
auto flip_dims = ctx.template Attr<std::vector<int>>("dims");
auto x_dims = x->dims();
const int total_dims = x_dims.size();
std::bitset<dim_bitset_size> dim_bitset;
for (size_t i = 0; i < flip_dims.size(); ++i) {
int dim = flip_dims[i];
if (flip_dims[i] < 0) {
dim += total_dims;
}
dim_bitset[dim] = true;
}
auto x_strides = framework::stride(x_dims);
auto numel = x->numel();
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(ctx.GetPlace());
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (int64_t i = 0; i < numel; ++i) {
int64_t cur_indices = i;
int64_t rem = 0;
int64_t dst_offset = 0;
for (int d = 0; d < total_dims; ++d) {
int64_t temp = cur_indices;
cur_indices = cur_indices / x_strides[d];
rem = temp - cur_indices * x_strides[d];
dst_offset += dim_bitset[d]
? (x_dims[d] - 1 - cur_indices) * x_strides[d]
: cur_indices * x_strides[d];
cur_indices = rem;
}
out_data[i] = x_data[dst_offset];
}
}
};
} // namespace operators
} // namespace paddle
......@@ -178,7 +178,7 @@ from .tensor.logic import equal #DEFINE_ALIAS
# from .tensor.manipulation import unique_with_counts #DEFINE_ALIAS
# from .tensor.manipulation import unsqueeze #DEFINE_ALIAS
# from .tensor.manipulation import unstack #DEFINE_ALIAS
# from .tensor.manipulation import flip #DEFINE_ALIAS
from .tensor.manipulation import flip #DEFINE_ALIAS
# from .tensor.manipulation import unbind #DEFINE_ALIAS
# from .tensor.manipulation import roll #DEFINE_ALIAS
# from .tensor.search import argmax #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import Program, program_guard
from op_test import OpTest
class TestFlipOp_API(unittest.TestCase):
"""Test flip api."""
def test_static_graph(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
dims = [0]
input = fluid.data(name='input', dtype='float32', shape=[2, 3])
output = paddle.flip(input, dims)
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
res = exe.run(train_program,
feed={'input': img},
fetch_list=[output])
out_np = np.array(res[0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
self.assertTrue(
(out_np == out_ref).all(),
msg='flip output is wrong, out =' + str(out_np))
def test_dygraph(self):
img = np.array([[1, 2, 3], [4, 5, 6]]).astype(np.float32)
with fluid.dygraph.guard():
inputs = fluid.dygraph.to_variable(img)
ret = paddle.flip(inputs, [0])
out_ref = np.array([[4, 5, 6], [1, 2, 3]]).astype(np.float32)
self.assertTrue(
(ret.numpy() == out_ref).all(),
msg='flip output is wrong, out =' + str(ret.numpy()))
class TestFlipOp(OpTest):
def setUp(self):
self.op_type = 'flip'
self.init_test_case()
self.inputs = {'X': np.random.random(self.in_shape).astype('float64')}
self.init_attrs()
self.outputs = {'Out': self.calc_ref_res()}
def init_attrs(self):
self.attrs = {"dims": self.dims}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Out")
def init_test_case(self):
self.in_shape = (6, 4, 2, 3)
self.dims = [0, 1]
def calc_ref_res(self):
res = self.inputs['X']
for axis in self.dims:
res = np.flip(res, axis)
return res
class TestFlipOpAxis1(TestFlipOp):
def init_test_case(self):
self.in_shape = (2, 4, 4)
self.dims = [0]
class TestFlipOpAxis2(TestFlipOp):
def init_test_case(self):
self.in_shape = (4, 4, 6, 3)
self.dims = [0, 2]
class TestFlipOpAxis3(TestFlipOp):
def init_test_case(self):
self.in_shape = (4, 3, 1)
self.dims = [0, 1, 2]
class TestFlipOpAxis4(TestFlipOp):
def init_test_case(self):
self.in_shape = (6, 4, 2, 2)
self.dims = [0, 1, 2, 3]
if __name__ == "__main__":
unittest.main()
......@@ -155,7 +155,7 @@ from .logic import equal #DEFINE_ALIAS
# from .manipulation import unique_with_counts #DEFINE_ALIAS
# from .manipulation import unsqueeze #DEFINE_ALIAS
# from .manipulation import unstack #DEFINE_ALIAS
# from .manipulation import flip #DEFINE_ALIAS
from .manipulation import flip #DEFINE_ALIAS
# from .manipulation import unbind #DEFINE_ALIAS
# from .manipulation import roll #DEFINE_ALIAS
# from .search import argmax #DEFINE_ALIAS
......
......@@ -12,30 +12,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from ..fluid.layers import core
from ..fluid.layer_helper import LayerHelper
from ..fluid.framework import Variable, OpProtoHolder, in_dygraph_mode, convert_np_dtype_to_dtype_
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
# TODO: define functions to manipulate a tensor
# __all__ = ['cast',
# 'concat',
# 'expand',
# 'expand_as',
# 'flatten',
# 'gather',
# 'gather_nd',
# 'reshape',
# 'reverse',
# 'scatter',
# 'scatter_nd_add',
# 'scatter_nd',
# 'shard_index',
# 'slice',
# 'split',
# 'squeeze',
# 'stack',
# 'strided_slice',
# 'transpose',
# 'unique',
# 'unique_with_counts',
# 'unsqueeze',
# 'unstack',
# 'flip',
# 'unbind',
# 'roll']
__all__ = [
# 'cast',
# 'concat',
# 'expand',
# 'expand_as',
# 'flatten',
# 'gather',
# 'gather_nd',
# 'reshape',
# 'reverse',
# 'scatter',
# 'scatter_nd_add',
# 'scatter_nd',
# 'shard_index',
# 'slice',
# 'split',
# 'squeeze',
# 'stack',
# 'strided_slice',
# 'transpose',
# 'unique',
# 'unique_with_counts',
# 'unsqueeze',
# 'unstack',
'flip',
# 'unbind',
# 'roll'
]
def flip(input, dims, name=None):
"""
Reverse the order of a n-D tensor along given axis in dims.
Args:
input (Variable): A Tensor(or LoDTensor) with shape :math:`[N_1, N_2,..., N_k]` . The data type of the input Tensor
should be float32, float64, int32, int64, bool.
dims (list): The axis to flip on.
name (str, optional): The default value is None. Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name` .
Returns:
Variable: Tensor or LoDTensor calculated by flip layer. The data type is same with input.
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
import numpy as np
input = fluid.data(name="x", shape=[-1, 2, 2], dtype='float32')
output = paddle.flip(input, dims=[0, 1])
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
img = np.arange(12).reshape((3,2,2)).astype(np.float32)
res = exe.run(fluid.default_main_program(), feed={'x':img}, fetch_list=[output])
print(res) # [[[10,11][8, 9]],[[6, 7],[4, 5]] [[2, 3],[0, 1]]]
"""
helper = LayerHelper("flip", **locals())
check_type(input, 'X', (Variable), 'flip')
dtype = helper.input_dtype()
check_dtype(dtype, 'X',
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
'flip')
check_type(dims, 'dims', (list, tuple), 'flip')
assert len(dims) > 0, 'len(dims) must be greater than 0.'
if name is None:
out = helper.create_variable_for_type_inference(dtype)
else:
out = helper.create_variable(name=name, dtype=dtype, persistable=False)
helper.append_op(
type="flip",
inputs={"X": input},
outputs={"Out": out},
attrs={"dims": dims})
return out
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册