diff --git a/paddle/fluid/operators/flatten_op.cc b/paddle/fluid/operators/flatten_op.cc
index 350d40ce8322acdc99f852d912bd0ace683f639f..9f2a122203bf9bed2d8737dc2056b16b4d7b7b8e 100644
--- a/paddle/fluid/operators/flatten_op.cc
+++ b/paddle/fluid/operators/flatten_op.cc
@@ -1,4 +1,4 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -12,6 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
+#include "paddle/fluid/operators/flatten_op.h"
+#include <memory>
+#include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"
@@ -20,18 +24,21 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
-class FlattenOpInferShape : public framework::InferShapeBase {
+class FlattenOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input (X) of Flatten op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output (Output) of Flatten op should not be null.");
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input (X) of Flatten op should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output (Output) of Flatten op should not be null.");
     const auto &axis = ctx->Attrs().Get<int>("axis");
     const auto &in_dims = ctx->GetInputDim("X");
-    PADDLE_ENFORCE(axis >= 0, "The axis should be greater than or equal to 0.");
-    PADDLE_ENFORCE(
-        axis <= in_dims.size(),
+    PADDLE_ENFORCE_GE(axis, 0,
+                      "The axis should be greater than or equal to 0.");
+    PADDLE_ENFORCE_LE(
+        axis, in_dims.size(),
         "The axis should be less than or equal to input tensor's rank.");
 
     const auto &out_dims = GetOutputShape(axis, in_dims);
@@ -58,28 +65,12 @@ class FlattenOpInferShape : public framework::InferShapeBase {
     out_shape[1] = inner;
     return out_shape;
   }
-};
 
-class FlattenOp : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto &axis = Attr<int>("axis");
-    auto in_dims =
-        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
-    const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
-
-    framework::AttributeMap attrs;
-    attrs["shape"] = out_dims;
-    attrs["inplace"] = false;
-    // Invoke Reshape Op
-    auto reshape_op = framework::OpRegistry::CreateOp(
-        "reshape", {{"X", {Input("X")}}, {"Shape", {}}},
-        {{"Out", {Output("Out")}}}, attrs);
-    reshape_op->Run(scope, place);
+ protected:
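+  // The kernel is selected by X's dtype alone; the layout and place are
+  // taken from the current device context.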
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("X")->type(), ctx.device_context());
   }
 };
 
@@ -126,34 +117,21 @@ Case 2:
   }
 };
 
-class FlattenGradInferShape : public framework::InferShapeBase {
+class FlattenGradOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *context) const override {
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *context) const override {
     context->SetOutputDim(framework::GradVarName("X"),
                           context->GetInputDim("X"));
     context->ShareLoD("X", framework::GradVarName("X"));
   }
-};
 
-class FlattenGradOp : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto dx_name = Output(framework::GradVarName("X"));
-    auto dout_name = Input(framework::GradVarName("Out"));
-    auto in_dims =
-        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
-    framework::AttributeMap attrs;
-    attrs["shape"] = framework::vectorize2int(in_dims);
-    attrs["inplace"] = false;
-
-    auto reshape_op = framework::OpRegistry::CreateOp(
-        "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}},
-        attrs);
-    reshape_op->Run(scope, place);
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>("X")->type(), ctx.device_context());
   }
 };
 
@@ -162,13 +140,33 @@ class FlattenGradOp : public framework::OperatorBase {
 // flatten_grad, in this way, the framework can reuse the memory of X
 // immediately the flatten2_op is finished.
 // Considering compatibility issues, we could not fix flatten2_op
-class Flatten2OpInferShape : public FlattenOpInferShape {
+class Flatten2Op : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *ctx) const override {
-    FlattenOpInferShape::operator()(ctx);
-    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
-                   "Output (XShape) of Flatten op should not be null.");
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input (X) of Flatten op should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output (Output) of Flatten op should not be null.");
+    const auto &axis = ctx->Attrs().Get<int>("axis");
     const auto &in_dims = ctx->GetInputDim("X");
+    PADDLE_ENFORCE_GE(axis, 0,
+                      "The axis should be greater than or equal to 0.");
+    PADDLE_ENFORCE_LE(
+        axis, in_dims.size(),
+        "The axis should be less than or equal to input tensor's rank.");
+
+    const auto &out_dims = FlattenOp::GetOutputShape(axis, in_dims);
+    ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
+    if (in_dims[0] == out_dims[0]) {
+      // Only pass LoD when the first dimension of output and Input(X)
+      // are the same.
+      ctx->ShareLoD("X", "Out");
+    }
+
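+    // XShape carries [0, x_dims...] so flatten2_grad can recover the input
+    // shape without keeping X alive (see the comment above this class).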
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
+                      "Output (XShape) of Flatten op should not be null.");
     std::vector<int64_t> xshape_dims(in_dims.size() + 1);
     xshape_dims[0] = 0;
     for (int i = 0; i < in_dims.size(); ++i) {
@@ -179,29 +177,6 @@ class Flatten2OpInferShape : public FlattenOpInferShape {
   }
 };
 
-class Flatten2Op : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto &axis = Attr<int>("axis");
-    auto in_dims =
-        scope.FindVar(Input("X"))->Get<framework::LoDTensor>().dims();
-    const auto &out_dims = FlattenOpInferShape::GetOutputShape(axis, in_dims);
-
-    framework::AttributeMap attrs;
-    attrs["shape"] = out_dims;
-    attrs["inplace"] = false;
-    // Invoke Reshape Op
-    auto reshape_op = framework::OpRegistry::CreateOp(
-        "reshape2", {{"X", {Input("X")}}, {"Shape", {}}},
-        {{"Out", {Output("Out")}}, {"XShape", {Output("XShape")}}}, attrs);
-    reshape_op->Run(scope, place);
-  }
-};
-
 class Flatten2OpMaker : public FlattenOpMaker {
  public:
   void Make() override {
@@ -228,43 +203,27 @@ class Flatten2GradOpMaker : public framework::SingleGradOpDescMaker {
   }
 };
 
-class Flatten2GradInferShape : public framework::InferShapeBase {
+class Flatten2GradOp : public framework::OperatorWithKernel {
  public:
-  void operator()(framework::InferShapeContext *context) const override {
-    PADDLE_ENFORCE(context->HasInput("XShape"),
-                   "Input(XShape) shouldn't be null.");
-    PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) shouldn't be null.");
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *context) const override {
+    PADDLE_ENFORCE_EQ(context->HasInput("XShape"), true,
+                      "Input(XShape) shouldn't be null.");
+    PADDLE_ENFORCE_EQ(context->HasInput(framework::GradVarName("Out")), true,
+                      "Input(Out@GRAD) shouldn't be null.");
     auto xshape_dims = context->GetInputDim("XShape");
     auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
     context->SetOutputDim(framework::GradVarName("X"), x_dims);
     context->ShareLoD("XShape", framework::GradVarName("X"));
   }
-};
 
-class Flatten2GradOp : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto dx_name = Output(framework::GradVarName("X"));
-    auto dout_name = Input(framework::GradVarName("Out"));
-    auto xshape_name = Input("XShape");
-    auto xshape_dims =
-        scope.FindVar(xshape_name)->Get<framework::LoDTensor>().dims();
-    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
-
-    framework::AttributeMap attrs;
-    attrs["shape"] = framework::vectorize2int(x_dims);
-    attrs["inplace"] = false;
-
-    auto reshape_grad_op = framework::OpRegistry::CreateOp(
-        "reshape2_grad",
-        {{"Out@GRAD", {dout_name}}, {"Shape", {}}, {"XShape", {xshape_name}}},
-        {{"X@GRAD", {dx_name}}}, attrs);
-    reshape_grad_op->Run(scope, place);
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
   }
 };
 
@@ -276,18 +235,41 @@ DECLARE_INPLACE_OP_INFERER(FlattenGradInplaceinToOut,
 
 }  // namespace operators
 }  // namespace paddle
 
-USE_OP(reshape);
-
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(flatten, ops::FlattenOp, ops::FlattenOpMaker,
-                  ops::FlattenOpInferShape,
                   paddle::framework::DefaultGradOpDescMaker<true>,
                   ops::FlattenOpInplaceInToOut);
-REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp, ops::FlattenGradInferShape,
+REGISTER_OPERATOR(flatten_grad, ops::FlattenGradOp,
                   ops::FlattenGradInplaceinToOut);
 REGISTER_OPERATOR(flatten2, ops::Flatten2Op, ops::Flatten2OpMaker,
-                  ops::Flatten2OpInferShape, ops::Flatten2GradOpMaker,
-                  ops::FlattenOpInplaceInToOut);
+                  ops::Flatten2GradOpMaker, ops::FlattenOpInplaceInToOut);
 REGISTER_OPERATOR(flatten2_grad, ops::Flatten2GradOp,
-                  ops::Flatten2GradInferShape, ops::FlattenGradInplaceinToOut);
+                  ops::FlattenGradInplaceinToOut);
+
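+// One kernel instantiation is registered per (device context, dtype) pair;
+// at run time the framework picks the instantiation whose dtype matches the
+// OpKernelType returned by GetExpectedKernelType above.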
+REGISTER_OP_CPU_KERNEL(
+    flatten, ops::FlattenKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FlattenKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::FlattenKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::FlattenKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::FlattenKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    flatten_grad,
+    ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::FlattenGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    flatten2, ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, float>,
+    ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, double>,
+    ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::Flatten2Kernel<paddle::platform::CPUDeviceContext, int64_t>);
+REGISTER_OP_CPU_KERNEL(
+    flatten2_grad,
+    ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int8_t>,
+    ops::Flatten2GradKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/flatten_op.cu.cc b/paddle/fluid/operators/flatten_op.cu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ac4ad8e2dc1c09f5ee9f0adfb8b19e0e4ec374a4
--- /dev/null
+++ b/paddle/fluid/operators/flatten_op.cu.cc
@@ -0,0 +1,44 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/flatten_op.h"
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    flatten, ops::FlattenKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FlattenKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FlattenKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::FlattenKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::FlattenKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    flatten_grad,
+    ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::FlattenGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    flatten2, ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, float>,
+    ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int>,
+    ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::Flatten2Kernel<paddle::platform::CUDADeviceContext, int64_t>);
+REGISTER_OP_CUDA_KERNEL(
+    flatten2_grad,
+    ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int8_t>,
+    ops::Flatten2GradKernel<paddle::platform::CUDADeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/flatten_op.h b/paddle/fluid/operators/flatten_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..165832c0e68bdef38f0382ea29f7655a18345805
--- /dev/null
+++ b/paddle/fluid/operators/flatten_op.h
@@ -0,0 +1,116 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/math/blas.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/pooling.h"
+#include "paddle/fluid/platform/device_context.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class FlattenKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto &axes = context.Attr<int>("axis");
+    auto x_dims = in->dims();
+    auto out_dims = framework::make_ddim(GetOutputShape(axes, x_dims));
+
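+    // Flattening never rearranges elements: the output is a plain copy of
+    // the input buffer, and only the shape metadata is rewritten afterwards.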
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::DeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+
+  static std::vector<int32_t> GetOutputShape(const int axis,
+                                             const framework::DDim &in_dims) {
+    int64_t outer = 1, inner = 1;
+    for (int i = 0; i < in_dims.size(); ++i) {
+      if (i < axis) {
+        outer *= in_dims[i];
+      } else {
+        inner *= in_dims[i];
+      }
+    }
+    std::vector<int32_t> out_shape(2);
+    out_shape[0] = outer;
+    out_shape[1] = inner;
+    return out_shape;
+  }
+};
+
+template <typename DeviceContext, typename T>
+class FlattenGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+    auto in_dims = ctx.Input<framework::LoDTensor>("X")->dims();
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    d_x->Resize(in_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Flatten2Kernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto &axes = context.Attr<int>("axis");
+
+    auto *in = context.Input<framework::LoDTensor>("X");
+    auto x_dims = in->dims();
+
+    auto *out = context.Output<framework::LoDTensor>("Out");
+
+    auto out_dims = framework::make_ddim(
+        FlattenKernel<DeviceContext, T>::GetOutputShape(axes, x_dims));
+
+    out->mutable_data(context.GetPlace(), in->type());
+    framework::TensorCopy(
+        *in, context.GetPlace(),
+        context.template device_context<platform::DeviceContext>(), out);
+    out->Resize(out_dims);
+  }
+};
+
+template <typename DeviceContext, typename T>
+class Flatten2GradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *d_x = ctx.Output<framework::LoDTensor>(framework::GradVarName("X"));
+    auto *d_out =
+        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
+
+    auto xshape_dims = ctx.Input<framework::LoDTensor>("XShape")->dims();
+    auto x_dims = framework::slice_ddim(xshape_dims, 1, xshape_dims.size());
+
+    d_x->mutable_data(ctx.GetPlace(), d_out->type());
+    framework::TensorCopySync(*d_out, ctx.GetPlace(), d_x);
+    d_x->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/test_flatten2_op.py b/python/paddle/fluid/tests/unittests/test_flatten2_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..59185855a5f13b82ca26bc26ead73fbe5fb96443
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_flatten2_op.py
@@ -0,0 +1,73 @@
+#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+
+from op_test import OpTest
+
+
+class TestFlattenOp(OpTest):
+    def setUp(self):
+        self.op_type = "flatten2"
+        self.init_test_case()
+        self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
+        self.init_attrs()
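+        # XShape is a shape-carrying helper output; its values are never
+        # compared (test_check_output passes no_check_set=["XShape"]), so
+        # random data of the right shape is sufficient here.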
+        self.outputs = {
+            "Out": self.inputs["X"].reshape(self.new_shape),
+            "XShape": np.random.random(self.in_shape).astype("float32")
+        }
+
+    def test_check_output(self):
+        self.check_output(no_check_set=["XShape"])
+
+    def test_check_grad(self):
+        self.check_grad(["X"], "Out")
+
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 5)
+        self.axis = 1
+        self.new_shape = (3, 20)
+
+    def init_attrs(self):
+        self.attrs = {"axis": self.axis}
+
+
+class TestFlattenOp1(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 3)
+        self.axis = 0
+        self.new_shape = (1, 36)
+
+
+class TestFlattenOpWithDefaultAxis(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 2, 3)
+        self.new_shape = (3, 12)
+
+    def init_attrs(self):
+        self.attrs = {}
+
+
+class TestFlattenOpSixDims(TestFlattenOp):
+    def init_test_case(self):
+        self.in_shape = (3, 2, 3, 2, 4, 4)
+        self.axis = 4
+        self.new_shape = (36, 16)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_flatten_op.py b/python/paddle/fluid/tests/unittests/test_flatten_op.py
index effa2a148eef8b0047b12c676803abb2871e8118..91251147ebc7908893e90467c1305fca89917ed7 100644
--- a/python/paddle/fluid/tests/unittests/test_flatten_op.py
+++ b/python/paddle/fluid/tests/unittests/test_flatten_op.py
@@ -1,4 +1,4 @@
-#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#   Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,17 +22,14 @@ from op_test import OpTest
 
 class TestFlattenOp(OpTest):
     def setUp(self):
-        self.op_type = "flatten2"
+        self.op_type = "flatten"
         self.init_test_case()
         self.inputs = {"X": np.random.random(self.in_shape).astype("float32")}
         self.init_attrs()
-        self.outputs = {
-            "Out": self.inputs["X"].reshape(self.new_shape),
-            "XShape": np.random.random(self.in_shape).astype("float32")
-        }
+        self.outputs = {"Out": self.inputs["X"].reshape(self.new_shape)}
 
     def test_check_output(self):
-        self.check_output(no_check_set=["XShape"])
+        self.check_output()
 
     def test_check_grad(self):
         self.check_grad(["X"], "Out")