diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index ab1d2143330fb8cbfd535758a83bc71de939c4e0..c01d9bc38459952f9f422fc8c1ac5b0fbcb3c4d0 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -265,6 +265,7 @@ op_library(recurrent_op DEPS executor) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(cos_sim_op DEPS cos_sim_functor) op_library(parallel_do_op DEPS executor) +op_library(squeeze_op DEPS reshape_op) if (WITH_GPU) op_library(conv_op DEPS vol2col depthwise_conv im2col) diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc index 639480aba41783a5e830270733e175f502087b8f..26c3ea344921d9df05ba425fbd49d2cf9645f5c7 100644 --- a/paddle/fluid/operators/squeeze_op.cc +++ b/paddle/fluid/operators/squeeze_op.cc @@ -12,33 +12,28 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/squeeze_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { -using framework::OpKernelType; -using framework::Tensor; - -class SqueezeOp : public framework::OperatorWithKernel { +class SqueezeOpInferShape : public framework::InferShapeBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { + void operator()(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SqueezeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of SqueezeOp should not be null."); - const auto& x_dims = ctx->GetInputDim("X"); + const auto &x_dims = ctx->GetInputDim("X"); // Check input tensor dims (<9). PADDLE_ENFORCE(x_dims.size() <= 9, "Invalid dimnesions, dynamic dimensions must have " "between [1, 9] dimensions."); - const auto& axes = ctx->Attrs().Get>("axes"); + const auto &axes = ctx->Attrs().Get>("axes"); for (int a : axes) { PADDLE_ENFORCE_LT(a, x_dims.size(), "The axis must be less than input tensor's rank."); @@ -55,7 +50,7 @@ class SqueezeOp : public framework::OperatorWithKernel { } static framework::DDim GetOutputShape(const std::vector squeeze_dims, - const framework::DDim& in_dims) { + const framework::DDim &in_dims) { int num_squeeze_dims = squeeze_dims.size(); int cnt_squeezed_dims = 0; bool should_squeeze[9] = {false}; @@ -100,6 +95,31 @@ class SqueezeOp : public framework::OperatorWithKernel { } }; +class SqueezeOp : public framework::OperatorBase { + public: + SqueezeOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto &axes = Attr>("axes"); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + auto out_dims = SqueezeOpInferShape::GetOutputShape(axes, x_dims); + + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(out_dims); + attrs["inplace"] = Attr("inplace"); + // Invoke Reshape Op + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {Input("X")}}, {"Shape", {}}}, + {{"Out", {Output("Out")}}}, attrs); + reshape_op->Run(scope, place); + } +}; + class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { @@ -116,67 +136,73 @@ class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { "tensor is created, and its data are copied from Input(x).") .SetDefault(false); AddComment(R"DOC( - Squeeze Operator. - - Remove single-dimensional entries from the shape of a tensor. - Takes a parameter axes with a list of axes to squeeze. - If axes is not provided, all the single dimensions will be removed from the shape. + Squeeze Operator. + + Remove single-dimensional entries from the shape of a tensor. + Takes a parameter axes with a list of axes to squeeze. + If axes is not provided, all the single dimensions will be removed from the shape. If an axis is selected with shape entry not equal to one, an error is raised. - - Examples: - Case 1: - Given - X.shape = (1, 3, 1, 5) - and - axes = [0] - we get: - Out.shape = (3, 1, 5) - - Case 2: - Given - X.shape = (1, 3, 1, 5) - we get: - Out.shape = (3, 5) + + Examples: + Case 1: + Given + X.shape = (1, 3, 1, 5) + and + axes = [0] + we get: + Out.shape = (3, 1, 5) + + Case 2: + Given + X.shape = (1, 3, 1, 5) + we get: + Out.shape = (3, 5) )DOC"); } }; -class SqueezeGradOp : public framework::OperatorWithKernel { +class SqueezeGradInferShape : public framework::InferShapeBase { public: - using framework::OperatorWithKernel::OperatorWithKernel; - - void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of SqueezeGradOp should not be null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), - "Output(Out@GRAD) of SqueezeGradOp should not be null."); - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + void operator()(framework::InferShapeContext *context) const override { + context->SetOutputDim(framework::GradVarName("X"), + context->GetInputDim("X")); + context->ShareLoD("X", framework::GradVarName("X")); } +}; - protected: - framework::OpKernelType GetExpectedKernelType( - const framework::ExecutionContext& ctx) const override { - return framework::OpKernelType( - framework::ToDataType(ctx.Input("X")->type()), - ctx.device_context()); +class SqueezeGradOp : public framework::OperatorBase { + public: + SqueezeGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &place) const override { + auto dx_name = Output(framework::GradVarName("X")); + auto dout_name = Input(framework::GradVarName("Out")); + auto x_dims = scope.FindVar(Input("X"))->Get().dims(); + framework::AttributeMap attrs; + attrs["shape"] = framework::vectorize2int(x_dims); + attrs["inplace"] = Attr("inplace"); + + auto reshape_op = framework::OpRegistry::CreateOp( + "reshape", {{"X", {dout_name}}, {"Shape", {}}}, {{"Out", {dx_name}}}, + attrs); + reshape_op->Run(scope, place); } }; } // namespace operators } // namespace paddle +// Tell linker to use reshape op +USE_OP(reshape); + namespace ops = paddle::operators; REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, + ops::SqueezeOpInferShape, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp); -REGISTER_OP_CPU_KERNEL( - squeeze, ops::SqueezeKernel, - ops::SqueezeKernel, - ops::SqueezeKernel, - ops::SqueezeKernel); -REGISTER_OP_CPU_KERNEL( - squeeze_grad, - ops::SqueezeGradKernel, - ops::SqueezeGradKernel, - ops::SqueezeGradKernel, - ops::SqueezeGradKernel); +REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp, ops::SqueezeGradInferShape); diff --git a/paddle/fluid/operators/squeeze_op.cu b/paddle/fluid/operators/squeeze_op.cu deleted file mode 100644 index 1096907daa5dfca4f12d0d8de6ff6fdca16ca6dd..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/squeeze_op.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#define EIGEN_USE_GPU - -#include "paddle/fluid/operators/squeeze_op.h" - -namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - squeeze, ops::SqueezeKernel, - ops::SqueezeKernel, - ops::SqueezeKernel, - ops::SqueezeKernel); -REGISTER_OP_CUDA_KERNEL( - squeeze_grad, - ops::SqueezeGradKernel, - ops::SqueezeGradKernel, - ops::SqueezeGradKernel, - ops::SqueezeGradKernel); diff --git a/paddle/fluid/operators/squeeze_op.h b/paddle/fluid/operators/squeeze_op.h deleted file mode 100644 index 44ef324c7dc5a702fcc8e7846f3870a94c4aa953..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/squeeze_op.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/operator.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; - -template -class SqueezeKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *out = ctx.Output("Out"); - auto *in = ctx.Input("X"); - - framework::DDim out_dims = out->dims(); - - bool inplace = ctx.Attr("inplace"); - out->Resize(out_dims); - if (!inplace) { - out->mutable_data(ctx.GetPlace()); - framework::TensorCopySync(*in, ctx.GetPlace(), out); - out->Resize(out_dims); - } else { - out->ShareDataWith(*in); - out->Resize(out_dims); - } - } -}; - -template -class SqueezeGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext &ctx) const override { - auto *d_out = ctx.Input(framework::GradVarName("Out")); - auto *d_x = ctx.Output(framework::GradVarName("X")); - - d_x->mutable_data(ctx.GetPlace()); - bool inplace = ctx.Attr("inplace"); - - auto in_dims = d_x->dims(); - if (!inplace) { - framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); - ctx.device_context().Wait(); - d_x->Resize(in_dims); - } else { - d_x->ShareDataWith(*d_out); - d_x->Resize(in_dims); - } - } -}; - -} // namespace operators -} // namespace paddle