From 70729ad6416eecb8cb7f4e1d648f83e92bb73bdf Mon Sep 17 00:00:00 2001 From: chenweihang Date: Fri, 29 Jun 2018 13:13:05 +0000 Subject: [PATCH] Add Unsqueeze Operator Framework, not finshed --- paddle/fluid/operators/unsqueeze_op.cc | 148 ++++++++++++++++++ paddle/fluid/operators/unsqueeze_op.cu | 30 ++++ paddle/fluid/operators/unsqueeze_op.h | 72 +++++++++ .../tests/unittests/test_unsqueeze_op.py | 98 ++++++++++++ 4 files changed, 348 insertions(+) create mode 100644 paddle/fluid/operators/unsqueeze_op.cc create mode 100644 paddle/fluid/operators/unsqueeze_op.cu create mode 100644 paddle/fluid/operators/unsqueeze_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_unsqueeze_op.py diff --git a/paddle/fluid/operators/unsqueeze_op.cc b/paddle/fluid/operators/unsqueeze_op.cc new file mode 100644 index 00000000000..8d2a1866854 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cc @@ -0,0 +1,148 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/fluid/operators/unsqueeze_op.h" +#include +#include + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class UnsqueezeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of UnsqueezeOp should not be null."); + + const auto& x_dims = ctx->GetInputDim("X"); + const auto& axes = ctx->Attrs().Get>("axes"); + // Check output tensor dims (<9). + PADDLE_ENFORCE_LE(x_dims.size() + axes.size(), 9, + "Invalid dimnesions, dynamic dimensions must have " + "between [1, 9] dimensions."); + // Check the range of unsqueeze aixs. + for (int a : axes) { + PADDLE_ENFORCE_LT(a, static_cast(x_dims.size() + axes.size()), + "The axis must be less than output tensor's rank."); + } + + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + } + + static framework::DDim GetOutputShape(const std::vector unsqueeze_dims, + const framework::DDim& in_dims) { + int out_dims_size = in_dims.size() + unsqueeze_dims.size(); + bool should_unsqueeze[9] = {false}; + + // Determines the dimensions should be unsqueezed in output tensor after. + for (unsigned int idx = 0; idx < unsqueeze_dims.size(); ++idx) { + int current = unsqueeze_dims[idx] < 0 + ? unsqueeze_dims[idx] + out_dims_size + : unsqueeze_dims[idx]; + // Check current index. 
+ PADDLE_ENFORCE_GE(current, 0, + "Invaild axis, negative axis is out of range."); + should_unsqueeze[idx] = true; + } + + // Make output dimensions + std::vector output_shape(out_dims_size, 0); + for (int in_idx = 0, out_idx = 0; out_idx < out_dims_size; ++out_idx) { + if (!should_unsqueeze[out_idx]) { + output_shape[out_idx] = in_dims[in_idx++]; + } else { + output_shape[out_idx] = 1; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class UnsqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor). The input tensor of unsqueeze operator."); + AddOutput("Out", "(Tensor). The output tensor of unsqueeze operator."); + AddAttr>("axes", + "(std::vector). List of positive integers," + " indicate the dimensions to be inserted"); + AddAttr( + "inplace", + "(default: false) Unsqueeze the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Unsqueeze Operator. + + Insert single-dimensional entries to the shape of a tensor. + Takes one required argument axes, a list of dimensions that will be inserted. + Dimension indices in axes are as seen in the output tensor. 
+ + For example: + Given a tensor such that tensor with shape [3, 4, 5], + then Unsqueeze(tensor, axes=[0, 4]) has shape [1, 3, 4, 5, 1] + )DOC"); + } +}; + +class UnsqueezeGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of UnsqueezeGradOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Output(Out@GRAD) of UnsqueezeGradOp should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(unsqueeze, ops::UnsqueezeOp, ops::UnsqueezeOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(unsqueeze_grad, ops::UnsqueezeGradOp); +REGISTER_OP_CPU_KERNEL( + unsqueeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_CPU_KERNEL( + unsqueeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.cu b/paddle/fluid/operators/unsqueeze_op.cu new file mode 100644 index 00000000000..891f6cc5489 --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/fluid/operators/unsqueeze_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + squeeze, ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel, + ops::UnsqueezeKernel); +REGISTER_OP_CUDA_KERNEL( + squeeze_grad, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel, + ops::UnsqueezeGradKernel); diff --git a/paddle/fluid/operators/unsqueeze_op.h b/paddle/fluid/operators/unsqueeze_op.h new file mode 100644 index 00000000000..aa45fb3113e --- /dev/null +++ b/paddle/fluid/operators/unsqueeze_op.h @@ -0,0 +1,72 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class UnsqueezeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *out = ctx.Output("Out"); + auto *in = ctx.Input("X"); + + framework::DDim out_dims = out->dims(); + + bool inplace = ctx.Attr("inplace"); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace()); + framework::TensorCopySync(*in, ctx.GetPlace(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } + } +}; + +template +class UnsqueezeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_out = ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + + d_x->mutable_data(ctx.GetPlace()); + bool inplace = ctx.Attr("inplace"); + + auto in_dims = d_x->dims(); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py new file mode 100644 index 00000000000..273a2c075f3 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_unsqueeze_op.py @@ -0,0 +1,98 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

from op_test import OpTest


# Correct: general case, positive axes only.
# (Renamed from TestSqueezeOp1 — a copy-paste leftover from the squeeze
# op tests; this file tests unsqueeze.)
class TestUnsqueezeOp1(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        # BUGFIX: attribute key was misspelled "inpalce", so the op never
        # received the intended "inplace" attribute.
        self.attrs = {"axes": axes, "inplace": False}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Correct: there is a minus (negative) axis.
class TestUnsqueezeOp2(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, -2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        # BUGFIX: "inpalce" -> "inplace".
        self.attrs = {"axes": axes, "inplace": False}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Correct: inplace path.
class TestUnsqueezeOpInplace1(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, 2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        self.attrs = {"axes": axes, "inplace": True}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


# Correct: inplace path with a minus (negative) axis.
class TestUnsqueezeOpInplace2(OpTest):
    def setUp(self):
        ori_shape = (3, 5)
        axes = (0, -2)
        new_shape = (1, 3, 1, 5)

        self.op_type = "unsqueeze"
        self.inputs = {"X": np.random.random(ori_shape).astype("float32")}
        # BUGFIX: attribute key was misspelled "inpalce"; with the typo the
        # inplace path was never actually exercised by this test.
        self.attrs = {"axes": axes, "inplace": True}
        self.outputs = {"Out": self.inputs["X"].reshape(new_shape)}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(["X"], "Out")


if __name__ == "__main__":
    unittest.main()