From e12b1d179266e13b5e604748ba6cc0a3dafd088f Mon Sep 17 00:00:00 2001
From: Bai Yifan
Date: Fri, 27 Jul 2018 18:20:36 +0800
Subject: [PATCH] Add flatten op (#12341)

* add flatten op
---
 paddle/fluid/operators/CMakeLists.txt              |   2 +
 paddle/fluid/operators/flatten_op.cc               | 169 ++++++++++++++++++
 .../fluid/tests/unittests/test_flatten_op.py       |  68 +++++++
 3 files changed, 239 insertions(+)
 create mode 100644 paddle/fluid/operators/flatten_op.cc
 create mode 100644 python/paddle/fluid/tests/unittests/test_flatten_op.py
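The patch adds a flatten operator whose whole contract is one shape rule: the dimensions before `axis` multiply into the first (outer) output dimension, and the remaining dimensions into the second (inner) one. As a quick orientation before the diffs, here is a minimal NumPy sketch of that rule; the `flatten_shape` helper is ours for illustration, not part of the patch:

    import numpy as np

    def flatten_shape(in_shape, axis):
        # Mirrors FlattenOpInferShape::GetOutputShape below: dims before
        # `axis` fold into the outer size, the remaining dims into the inner.
        outer = int(np.prod(in_shape[:axis], dtype=np.int64))
        inner = int(np.prod(in_shape[axis:], dtype=np.int64))
        return (outer, inner)

    print(flatten_shape((3, 100, 100, 4), 2))  # (300, 400), Case 1 of the op doc
    print(flatten_shape((3, 100, 100, 4), 0))  # (1, 120000), Case 2 of the op doc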
diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
new file mode 100644
--- /dev/null
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -0,0 +1,169 @@
+/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+class FlattenOpInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input (X) of Flatten op should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output (Out) of Flatten op should not be null.");
+    const auto &axis = ctx->Attrs().Get<int>("axis");
+    const auto &in_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE(axis >= 0,
+                   "The axis should be greater than or equal to 0.");
+    PADDLE_ENFORCE(
+        axis <= in_dims.size(),
+        "The axis should be less than or equal to input tensor's rank.");
+
+    const auto &out_dims = GetOutputShape(axis, in_dims);
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+    if (in_dims[0] == out_dims[0]) {
+      // Only pass LoD when the first dimension of output and Input(X)
+      // are the same.
+      ctx->ShareLoD("X", "Out");
+    }
+  }
+
+  static std::vector<int> GetOutputShape(const int axis,
+                                         const framework::DDim &in_dims) {
+    int64_t outer = 1, inner = 1;
+    for (int i = 0; i < in_dims.size(); ++i) {
+      if (i < axis) {
+        outer *= in_dims[i];
+      } else {
+        inner *= in_dims[i];
+      }
+    }
+    std::vector<int> out_shape(2);
+    out_shape[0] = outer;
+    out_shape[1] = inner;
+    return out_shape;
+  }
+};
+
+class FlattenOp : public framework::OperatorBase {
+ public:
+  using OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    auto &axis = Attr<int>("axis");
+    auto in_dims =
+        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
+    const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
+
+    framework::AttributeMap attrs;
+    attrs["shape"] = out_dims;
+    attrs["inplace"] = false;
+    // Invoke Reshape Op
+    auto reshape_op = framework::OpRegistry::CreateOp(
+        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
+        {{"Out", {Output("Out")}}}, attrs);
+    reshape_op->Run(scope, place);
+  }
+};
+
+class FlattenOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) A tensor of rank >= axis.");
+    AddOutput("Out",
+              "A 2D tensor containing the reshaped input. The input "
+              "dimensions up to axis are flattened to the outer dimension "
+              "of the output and the remaining input dimensions are "
+              "flattened into the inner dimension of the output.");
+    AddAttr<int>("axis",
+                 "(int) "
+                 "Indicates up to which input dimensions (exclusive) should "
+                 "be flattened to the outer dimension of the output. The "
+                 "value for axis must be in the range [0, R], where R is "
+                 "the rank of the input tensor. When axis = 0, the shape of "
+                 "the output tensor is (1, (d_0 X d_1 ... d_n)), where the "
+                 "shape of the input tensor is (d_0, d_1, ... d_n).")
+        .SetDefault(1);
+    AddComment(R"DOC(
+Flatten Operator
+
+Flattens the input tensor into a 2D matrix.
+
+Examples:
+Case 1:
+  Given
+    X.shape = (3, 100, 100, 4)
+  and
+    axis = 2
+  We get:
+    Out.shape = (3 * 100, 4 * 100)
+
+Case 2:
+  Given
+    X.shape = (3, 100, 100, 4)
+  and
+    axis = 0
+  We get:
+    Out.shape = (1, 3 * 100 * 100 * 4)
+)DOC");
+  }
+};
+
+class FlattenGradInferShape : public framework::InferShapeBase {
+ public:
+  void operator()(framework::InferShapeContext *context) const override {
+    context->SetOutputDim(framework::GradVarName("X"),
+                          context->GetInputDim("X"));
+    context->ShareLoD("X", framework::GradVarName("X"));
+  }
+};
+
+class FlattenGradOp : public framework::OperatorBase {
+ public:
+  using OperatorBase::OperatorBase;
+
+ private:
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    auto dx_name = Output(framework::GradVarName("X"));
+    auto dout_name = Input(framework::GradVarName("Out"));
+    auto in_dims =
+        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
+    framework::AttributeMap attrs;
+    attrs["shape"] = framework::vectorize2int(in_dims);
+    attrs["inplace"] = false;
+
+    auto reshape_op = framework::OpRegistry::CreateOp(
+        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
+        attrs);
+    reshape_op->Run(scope, place);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+USE_OP(reshape);
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
+                  ops::FlattenOpInferShape,
+                  paddle::framework::DefaultGradOpDescMaker<true>);
+REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape);
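Both operators above are thin wrappers over the already-registered reshape op: the forward pass reshapes X to the (outer, inner) shape computed by GetOutputShape, and the backward pass reshapes dOut straight back to X's original dims, since the gradient of a reshape is itself a reshape. A NumPy sketch of that round trip under those assumptions (the function names are illustrative, not Paddle API):

    import numpy as np

    def flatten_forward(x, axis):
        # What FlattenOp asks the reshape op to do.
        outer = int(np.prod(x.shape[:axis], dtype=np.int64))
        return x.reshape(outer, -1)

    def flatten_grad(dout, x_shape):
        # What FlattenGradOp asks the reshape op to do: restore X's dims.
        return dout.reshape(x_shape)

    x = np.random.rand(3, 2, 2, 5)
    out = flatten_forward(x, axis=1)               # shape (3, 20)
    dx = flatten_grad(np.ones_like(out), x.shape)  # shape (3, 2, 2, 5)
    assert dx.shape == x.shape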
diff --git a/python/paddle/fluid/tests/unittests/test_flatten_op.py b/python/paddle/fluid/tests/unittests/test_flatten_op.py
new file mode 100644
index 0000000000..f8692ce2ea
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_flatten_op.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+
+from op_test import OpTest
+
+
+class TestFlattenOp(OpTest):
+    def setUp(self):
+        self.op_type = "flatten"
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
+        self.init_attrs()
+        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 5)
+        self.axis = 1
+        self.new_shape = (3, 20)
+
+    def init_attrs(self):
+        self.attrs = {"axis": self.axis}
+
+
+class TestFlattenOpWithAxis0(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 3)
+        self.axis = 0
+        self.new_shape = (1, 36)
+
+
+class TestFlattenOpWithDefaultAxis(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 3)
+        self.new_shape = (3, 12)
+
+    def init_attrs(self):
+        self.attrs = {}
+
+
+class TestFlattenOpSixDims(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 3, 2, 4, 4)
+        self.axis = 4
+        self.new_shape = (36, 16)
+
+
+if __name__ == "__main__":
+    unittest.main()
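Each test case derives its expected output with a plain numpy reshape, so extending coverage only takes an input shape, an axis, and the (outer, inner) shape they imply. As one hypothetical example in the same pattern (not part of the patch; it reuses the file's TestFlattenOp base), the axis == rank boundary that the InferShape check explicitly permits could be exercised like this:

    class TestFlattenOpAxisEqualsRank(TestFlattenOp):
        # Hypothetical extra case: axis == rank is allowed because the op
        # checks axis <= in_dims.size(); every dimension then folds into
        # the outer size and the inner size is 1.
        def init_test_case(self):
            self.in_shape = (2, 3, 4)
            self.axis = 3
            self.new_shape = (24, 1)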