From 298e74da1e6a1b8512cf2f7f0093e5659fca51f7 Mon Sep 17 00:00:00 2001 From: chenweihang Date: Thu, 28 Jun 2018 09:09:24 +0000 Subject: [PATCH] add squeeze op c++ part, compile success --- paddle/fluid/operators/squeeze_op.cc | 155 +++++++++++++++++++++++++++ paddle/fluid/operators/squeeze_op.cu | 30 ++++++ paddle/fluid/operators/squeeze_op.h | 73 +++++++++++++ 3 files changed, 258 insertions(+) create mode 100644 paddle/fluid/operators/squeeze_op.cc create mode 100644 paddle/fluid/operators/squeeze_op.cu create mode 100644 paddle/fluid/operators/squeeze_op.h diff --git a/paddle/fluid/operators/squeeze_op.cc b/paddle/fluid/operators/squeeze_op.cc new file mode 100644 index 0000000000..8f453b059f --- /dev/null +++ b/paddle/fluid/operators/squeeze_op.cc @@ -0,0 +1,155 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/squeeze_op.h" +#include +#include + +namespace paddle { +namespace operators { + +using framework::OpKernelType; +using framework::Tensor; + +class SqueezeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SqueezeOp should not be null."); + + const auto& x_dims = ctx->GetInputDim("X"); + // TODO(chenweihang): need check input tensor dims (<9). + + const auto& axes = ctx->Attrs().Get>("axes"); + // TODO(chenweihang): need check axes is valid. + // PADDLE_ENFORCE(); + for (int a : axes) { + PADDLE_ENFORCE_LT(a, x_dims.size(), + "The axis must be less than input tensor's rank."); + } + + auto out_dims = GetOutputShape(axes, x_dims); + ctx->SetOutputDim("Out", out_dims); + // TODO(chenweihang): need other check. + } + + static framework::DDim GetOutputShape(const std::vector squeeze_dims, + const framework::DDim& in_dims) { + int num_squeeze_dims = squeeze_dims.size(); + int cnt_squeezed_dims = 0; + bool should_squeeze[9] = {false}; + + // Determines number of dimensions of output tensor after squeeze. + // Mark and count the dimensions need to be squeezed + if (num_squeeze_dims == 0) { + for (int idx = 0; idx < in_dims.size(); ++idx) { + if (in_dims[idx] == 1) { + should_squeeze[idx] = true; + ++cnt_squeezed_dims; + } + } + } else { + for (int idx = 0; idx < num_squeeze_dims; ++idx) { + int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size() + : squeeze_dims[idx]; + // TODO(chenweihang): shoude use PADALE_ENFORCE ? or if. + PADDLE_ENFORCE_GE(current, 0, "Invalid axis is given."); + PADDLE_ENFORCE_LT(current, in_dims.size(), "Invalid axis is given."); + PADDLE_ENFORCE_EQ(in_dims[current], 1, "Invalid axis is given."); + + if (!(should_squeeze[current])) ++cnt_squeezed_dims; + should_squeeze[current] = true; + } + } + + // Make output dimensions + std::vector output_shape(in_dims.size() - cnt_squeezed_dims, 0); + for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) { + if (!should_squeeze[in_idx]) { + output_shape[out_idx++] = in_dims[in_idx]; + } + } + + return framework::make_ddim(output_shape); + } +}; + +class SqueezeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(Tensor), Tensors with at least max(dims) dimensions."); + AddOutput("Out", "(Tensor), Reshaped tensor with same data as input."); + AddAttr>("axes", + "List of positive integers," + " indicate the dimensions to squeeze."); + AddAttr("inplace", + "(default: false) Change the source tensor's shape without " + "memory copy. When Attr(inplace) is set true, the output " + "tensor shares memory with Input(X), otherwise, a new output " + "tensor is created, and its data are copied from Input(x).") + .SetDefault(false); + AddComment(R"DOC( + Squeeze Operator. + + Remove single-dimensional entries from the shape of a tensor. + Takes a parameter axes with a list of axes to squeeze. + If axes is not provided, all the single dimensions will be removed from the shape. + If an axis is selected with shape entry not equal to one, an error is raised. + )DOC"); + } +}; + +class SqueezeGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), + "Input(X) of SqueezeOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Output(Out@GRAD/) of SqueezeOp should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType( + framework::ToDataType(ctx.Input("X")->type()), + ctx.device_context()); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(squeeze, ops::SqueezeOp, ops::SqueezeOpMaker, + paddle::framework::DefaultGradOpDescMaker); +REGISTER_OPERATOR(squeeze_grad, ops::SqueezeGradOp); +REGISTER_OP_CPU_KERNEL( + squeeze, ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel); +REGISTER_OP_CPU_KERNEL( + squeeze_grad, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel); diff --git a/paddle/fluid/operators/squeeze_op.cu b/paddle/fluid/operators/squeeze_op.cu new file mode 100644 index 0000000000..1096907daa --- /dev/null +++ b/paddle/fluid/operators/squeeze_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#define EIGEN_USE_GPU + +#include "paddle/fluid/operators/squeeze_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + squeeze, ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel, + ops::SqueezeKernel); +REGISTER_OP_CUDA_KERNEL( + squeeze_grad, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel, + ops::SqueezeGradKernel); diff --git a/paddle/fluid/operators/squeeze_op.h b/paddle/fluid/operators/squeeze_op.h new file mode 100644 index 0000000000..ce6f40e7a4 --- /dev/null +++ b/paddle/fluid/operators/squeeze_op.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class SqueezeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *out = ctx.Output("Out"); + auto *in = ctx.Input("X"); + + framework::DDim out_dims = out->dims(); + + // TODO(chenweihang): Where is this attr be add. + bool inplace = ctx.Attr("inplace"); + out->Resize(out_dims); + if (!inplace) { + out->mutable_data(ctx.GetPlace()); + framework::TensorCopySync(*in, ctx.GetPlace(), out); + out->Resize(out_dims); + } else { + out->ShareDataWith(*in); + out->Resize(out_dims); + } + } +}; + +template +class SqueezeGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto *d_out = ctx.Input(framework::GradVarName("Out")); + auto *d_x = ctx.Output(framework::GradVarName("X")); + + d_x->mutable_data(ctx.GetPlace()); + bool inplace = ctx.Attr("inplace"); + + auto in_dims = d_x->dims(); + if (!inplace) { + framework::TensorCopy(*d_out, ctx.GetPlace(), ctx.device_context(), d_x); + ctx.device_context().Wait(); + d_x->Resize(in_dims); + } else { + d_x->ShareDataWith(*d_out); + d_x->Resize(in_dims); + } + } +}; + +} // namespace operators +} // namespace paddle -- GitLab