diff --git a/paddle/fluid/operators/unstack_op.cc b/paddle/fluid/operators/unstack_op.cc
index 4ff3249cc333231a0624cd5aab9603a6a75f4480..204aa1fa6709485b7f277270cd4cc8e32b757515 100644
--- a/paddle/fluid/operators/unstack_op.cc
+++ b/paddle/fluid/operators/unstack_op.cc
@@ -1,26 +1,140 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #include "paddle/fluid/operators/unstack_op.h"
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/for_range.h"
+
+namespace paddle {
+namespace operators {
+
+// Forward op: splits Input(X) into `num` tensors along Attr(axis).
+class UnStackOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) must exist.");
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    int num = ctx->Attrs().Get<int>("num");
+    auto x_dim = ctx->GetInputDim("X");
+    int rank = x_dim.size();
+    PADDLE_ENFORCE_GE(
+        axis, -rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
+        rank);
+    PADDLE_ENFORCE_LT(
+        axis, rank, "Attr(axis) must be inside [-rank, rank), where rank = %d",
+        rank);
+    if (axis < 0) axis += rank;
+
+    PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
+                      "Number of Outputs(Y) is wrong");
+    // x_dim[axis] may be -1 (unknown at compile time); only check when known.
+    if (x_dim[axis] > 0) {
+      PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
+    }
+    auto vec = framework::vectorize(x_dim);
+    vec.erase(vec.begin() + axis);
+    ctx->SetOutputsDim("Y", std::vector<framework::DDim>(  // NOLINT
+                                x_dim[axis], framework::make_ddim(vec)));
+  }
+};
+
+class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "The input of unstack op.");
+    AddOutput("Y", "The output of unstack op.").AsDuplicable();
+    AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
+        .SetDefault(0);
+    AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
+    AddComment(R"DOC(
+      UnStack Operator.
+
+      UnStack Input(X) into several tensors along Attr(axis).
+    )DOC");
+  }
+};
+
+// Generates unstack_grad from the forward op's outputs' gradients.
+class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("unstack_grad");
+    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+// Grad op: stacks Inputs(Y@Grad) back into X@Grad along Attr(axis).
+class UnStackGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
+                      "Number of Inputs(Y@Grad) must be larger than 0");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
+                      "Output(X@Grad) must exist.");
+
+    auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
+    for (size_t i = 1; i < input_dims.size(); ++i) {
+      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
+                        "Dims of all Inputs(Y@Grad) must be the same");
+    }
+
+    int axis = ctx->Attrs().Get<int>("axis");
+    // rank of each Y@Grad; X@Grad has rank + 1 dims, hence the wider range.
+    int rank = input_dims[0].size();
+    PADDLE_ENFORCE_GE(
+        axis, -(rank + 1),
+        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+    PADDLE_ENFORCE_LT(
+        axis, rank + 1,
+        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
+    if (axis < 0) axis += (rank + 1);
+
+    auto vec = framework::vectorize(input_dims[0]);
+    vec.insert(vec.begin() + axis, input_dims.size());
+    ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
 
 namespace plat = paddle::platform;
 namespace ops = paddle::operators;
 
-USE_OP(stack);
-
 REGISTER_OPERATOR(unstack, ops::UnStackOp, ops::UnStackOpMaker,
-                  ops::UnStackOpInferShape, ops::UnStackGradOpDescMaker);
+                  ops::UnStackGradOpDescMaker);
+
+REGISTER_OPERATOR(unstack_grad,
+                  ops::UnStackGradOp);
+
+REGISTER_OP_CPU_KERNEL(unstack,
+                       ops::UnStackKernel<plat::CPUDeviceContext, float>,
+                       ops::UnStackKernel<plat::CPUDeviceContext, double>,
+                       ops::UnStackKernel<plat::CPUDeviceContext, int>,
+                       ops::UnStackKernel<plat::CPUDeviceContext, int64_t>);
 
-REGISTER_OPERATOR(unstack_grad, ops::UnStackGradOp,
-                  ops::UnStackOpGradInferShape);
+REGISTER_OP_CPU_KERNEL(unstack_grad,
+                       ops::UnStackGradKernel<plat::CPUDeviceContext, float>,
+                       ops::UnStackGradKernel<plat::CPUDeviceContext, double>,
+                       ops::UnStackGradKernel<plat::CPUDeviceContext, int>,
+                       ops::UnStackGradKernel<plat::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/unstack_op.cu b/paddle/fluid/operators/unstack_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b591898a4d7aa3918e41118f1f1b3137f4638a18
--- /dev/null
+++ b/paddle/fluid/operators/unstack_op.cu
@@ -0,0 +1,32 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/unstack_op.h"
+
+namespace plat = paddle::platform;
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    unstack, ops::UnStackKernel<plat::CUDADeviceContext, float>,
+    ops::UnStackKernel<plat::CUDADeviceContext, double>,
+    ops::UnStackKernel<plat::CUDADeviceContext, int>,
+    ops::UnStackKernel<plat::CUDADeviceContext, int64_t>,
+    ops::UnStackKernel<plat::CUDADeviceContext, plat::float16>);
+
+REGISTER_OP_CUDA_KERNEL(
+    unstack_grad, ops::UnStackGradKernel<plat::CUDADeviceContext, float>,
+    ops::UnStackGradKernel<plat::CUDADeviceContext, double>,
+    ops::UnStackGradKernel<plat::CUDADeviceContext, int>,
+    ops::UnStackGradKernel<plat::CUDADeviceContext, int64_t>,
+    ops::UnStackGradKernel<plat::CUDADeviceContext, plat::float16>);
diff --git a/paddle/fluid/operators/unstack_op.h b/paddle/fluid/operators/unstack_op.h
index 6247a1f1d94273d1095f3212fb88e53f2da687a0..6344ea16f81cddb1c8f4f07f28fd318f40296427 100644
--- a/paddle/fluid/operators/unstack_op.h
+++ b/paddle/fluid/operators/unstack_op.h
@@ -1,134 +1,173 @@
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
 
 #pragma once
 
 #include <memory>
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/for_range.h"
+
+#ifdef __NVCC__
+#include <thrust/device_vector.h>
+#include "paddle/fluid/framework/array.h"
+#endif
 
 namespace paddle {
 namespace operators {
 
-class UnStackOpInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must exist.");
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    int num = ctx->Attrs().Get<int>("num");
-    auto x_dim = ctx->GetInputDim("X");
-    int rank = x_dim.size();
-    PADDLE_ENFORCE(axis >= -rank && axis < rank,
-                   "Attr(axis) must be inside [-rank, rank), where rank = %d",
-                   rank);
-    if (axis < 0) axis += rank;
-
-    PADDLE_ENFORCE_EQ(ctx->Outputs("Y").size(), static_cast<size_t>(num),
-                      "Number of Outputs(Y) is wrong");
-    if (x_dim[axis] > 0) {
-      PADDLE_ENFORCE_EQ(num, x_dim[axis], "Number of Outputs(Y) is wrong");
-    }
-    auto vec = framework::vectorize(x_dim);
-    vec.erase(vec.begin() + axis);
-    ctx->SetOutputsDim("Y", std::vector<framework::DDim>(  // NOLINT
-                                x_dim[axis], framework::make_ddim(vec)));
-  }
-};
-
-class UnStackOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput("X", "The input of unstack op.");
-    AddOutput("Y", "The output of unstack op.").AsDuplicable();
-    AddAttr<int>("axis", "The axis along which Input(X) should be unstacked.")
-        .SetDefault(0);
-    AddAttr<int>("num", "The number of outputs(Y).").GreaterThan(0);
-    AddComment(R"DOC(
-      UnStack Operator.
-
-      UnStack Input(X) into several tensors along Attr(axis).
-    )DOC");
-  }
-};
-
-class UnStackOp : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto stack_grad_op = framework::OpRegistry::CreateOp(
-        "stack_grad", {{framework::GradVarName("Y"), {Input("X")}}},
-        {{framework::GradVarName("X"), Outputs("Y")}}, Attrs());
-    stack_grad_op->Run(scope, place);
-  }
-};
-
-class UnStackOpGradInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE_GT(ctx->Inputs(framework::GradVarName("Y")).size(), 0,
-                      "Number of Inputs(Y@Grad) must be larger than 0");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@Grad) must exist.");
-
-    auto input_dims = ctx->GetInputsDim(framework::GradVarName("Y"));
-    for (size_t i = 1; i < input_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0],
-                        "Dims of all Inputs(Y@Grad) must be the same");
-    }
-
-    int axis = ctx->Attrs().Get<int>("axis");
-    int rank = input_dims[0].size();
-    PADDLE_ENFORCE(
-        axis >= -(rank + 1) && axis < rank + 1,
-        "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d", rank);
-    if (axis < 0) axis += (rank + 1);
-
-    auto vec = framework::vectorize(input_dims[0]);
-    vec.insert(vec.begin() + axis, input_dims.size());
-    ctx->SetOutputDim(framework::GradVarName("X"), framework::make_ddim(vec));
-  }
-};
-
-class UnStackGradOpDescMaker : public framework::SingleGradOpDescMaker {
- public:
-  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
-
- protected:
-  std::unique_ptr<framework::OpDesc> Apply() const override {
-    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
-    op->SetType("unstack_grad");
-    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
-    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    op->SetAttrMap(Attrs());
-    return op;
-  }
-};
-
-class UnStackGradOp : public framework::OperatorBase {
- public:
-  using OperatorBase::OperatorBase;
-
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto stack_op = framework::OpRegistry::CreateOp(
-        "stack", {{"X", Inputs(framework::GradVarName("Y"))}},
-        {{"Y", {Output(framework::GradVarName("X"))}}}, Attrs());
-    stack_op->Run(scope, place);
+// Writes n interleaved input slices into one stacked output buffer:
+// y[i * n * post + j * post + k] = x[j][i * post + k].
+template <typename VecXType, typename T>
+struct StackFunctor {
+  HOSTDEVICE StackFunctor(const VecXType &x, T *y, int n, int post)
+      : x_(x), y_(y), n_(n), post_(post) {}
+
+  HOSTDEVICE void operator()(int idx) {
+    int i = idx / (n_ * post_);
+    int which_x = idx / post_ - i * n_;
+    int x_index = i * post_ + idx % post_;
+    y_[idx] = x_[which_x][x_index];
+  }
+
+ private:
+  VecXType x_;
+  T *y_;
+  int n_;
+  int post_;
+};
+
+// Inverse of StackFunctor: scatters one stacked buffer into n outputs.
+template <typename VecDxType, typename T>
+struct StackGradFunctor {
+  HOSTDEVICE StackGradFunctor(const VecDxType &dx, const T *dy, int n, int post)
+      : dx_(dx), dy_(dy), n_(n), post_(post) {}
+
+  HOSTDEVICE void operator()(int idx) {
+    int i = idx / (n_ * post_);
+    int which_x = idx / post_ - i * n_;
+    int x_index = i * post_ + idx % post_;
+    dx_[which_x][x_index] = dy_[idx];
+  }
+
+ private:
+  VecDxType dx_;
+  const T *dy_;
+  int n_;
+  int post_;
+};
+
+template <typename DeviceContext, typename VecXType, typename T>
+static inline void StackFunctorForRange(const DeviceContext &ctx,
+                                        const VecXType &x, T *y, int total_num,
+                                        int n, int post) {
+  platform::ForRange<DeviceContext> for_range(ctx, total_num);
+  for_range(StackFunctor<VecXType, T>(x, y, n, post));
+}
+
+template <typename DeviceContext, typename VecDxType, typename T>
+static inline void StackGradFunctorForRange(const DeviceContext &ctx,
+                                            const VecDxType &dx, const T *dy,
+                                            int total_num, int n, int post) {
+  platform::ForRange<DeviceContext> for_range(ctx, total_num);
+  for_range(StackGradFunctor<VecDxType, T>(dx, dy, n, post));
+}
+
+// Grad of unstack is a stack: gathers Inputs(Y@Grad) into Output(X@Grad).
+template <typename DeviceContext, typename T>
+class UnStackGradKernel : public framework::OpKernel<T> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto x = ctx.MultiInput<Tensor>(framework::GradVarName("Y"));
+    auto *y = ctx.Output<Tensor>(framework::GradVarName("X"));
+
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += (x[0]->dims().size() + 1);
+
+    int n = static_cast<int>(x.size());
+    auto *y_data = y->mutable_data<T>(ctx.GetPlace());
+    std::vector<const T *> x_datas(n);
+    for (int i = 0; i < n; i++) x_datas[i] = x[i]->data<T>();
+
+    int pre = 1;
+    int post = 1;
+    auto &dim = x[0]->dims();
+    for (auto i = 0; i < axis; ++i) pre *= dim[i];
+    for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
+
+#ifdef __NVCC__
+    int total_num = pre * n * post;
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+
+    thrust::device_vector<const T *> device_x_vec(x_datas);
+    auto x_data_arr = device_x_vec.data().get();
+
+    StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
+
+    // Wait() must be called because device_x_vec may be destructed before
+    // kernel ends
+    dev_ctx.Wait();
+#else
+    auto x_data_arr = x_datas.data();
+
+    size_t x_offset = 0;
+    size_t y_offset = 0;
+    for (int i = 0; i < pre; i++) {
+      for (int j = 0; j < n; j++) {
+        std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
+                    post * sizeof(T));
+        y_offset += post;
+      }
+      x_offset += post;
+    }
+#endif
+  }
+};
+
+// Forward unstack: scatters Input(X) into the Outputs(Y) along axis.
+template <typename DeviceContext, typename T>
+class UnStackKernel : public framework::OpKernel<T> {
+  using Tensor = framework::LoDTensor;
+
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    auto *dy = ctx.Input<Tensor>("X");
+    auto dx = ctx.MultiOutput<Tensor>("Y");
+    int axis = ctx.Attr<int>("axis");
+    if (axis < 0) axis += dy->dims().size();
+
+    int n = dy->dims()[axis];
+    std::vector<T *> dx_datas(n);  // NOLINT
+    for (int i = 0; i < n; i++) {
+      dx_datas[i] = dx[i]->mutable_data<T>(ctx.GetPlace());
+    }
+    auto dy_data = dy->data<T>();
+
+    int pre = 1;
+    for (int i = 0; i < axis; ++i) pre *= dy->dims()[i];
+    int total_num = dy->numel();
+    int post = total_num / (n * pre);
+
+    auto &dev_ctx = ctx.template device_context<DeviceContext>();
+#ifdef __NVCC__
+    thrust::device_vector<T *> device_dx_vec(dx_datas);
+    auto dx_data_arr = device_dx_vec.data().get();
+#else
+    auto dx_data_arr = dx_datas.data();
+#endif
+    StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
+#ifdef __NVCC__
+    // Wait() must be called because device_dx_vec may be destructed before
+    // kernel ends
+    dev_ctx.Wait();
+#endif
   }
 };