// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/operators/stack_op.h" #include #include namespace plat = paddle::platform; namespace ops = paddle::operators; namespace paddle { namespace operators { class StackOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_GT(ctx->Inputs("X").size(), 0, platform::errors::InvalidArgument( "Number of Inputs(X) must be larger than 0, but" " received value is:%d.", ctx->Inputs("X").size())); PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true, platform::errors::InvalidArgument( "Output(Y) of stack_op should not be null.")); auto input_dims = ctx->GetInputsDim("X"); for (size_t i = 1; i < input_dims.size(); ++i) { PADDLE_ENFORCE_EQ(input_dims[i], input_dims[0], platform::errors::InvalidArgument( "Dims of all Inputs(X) must be the same, but" " received input %d dim is:%d not equal to input 0" " dim:%d.", i, input_dims[i], input_dims[0])); } // Only lod of X[0] would be shared with Y ctx->ShareLoD("X", /*->*/ "Y"); int axis = ctx->Attrs().Get("axis"); int rank = input_dims[0].size(); PADDLE_ENFORCE_GE( axis, -(rank + 1), platform::errors::InvalidArgument( "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, " "but received axis is:%d.", rank, axis)); PADDLE_ENFORCE_LT( axis, rank + 1, platform::errors::InvalidArgument( "Attr(axis) must be inside [-(rank+1), rank+1), where rank = %d, " "but received axis is:%d", rank, axis)); if (axis < 0) axis += (rank + 1); auto vec = framework::vectorize(input_dims[0]); vec.insert(vec.begin() + axis, input_dims.size()); ctx->SetOutputDim("Y", framework::make_ddim(vec)); } }; class StackOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "The input of stack op.").AsDuplicable(); AddOutput("Y", "The output of stack op."); AddAttr("axis", "The axis along which all of the Inputs(X) should be stacked.") .SetDefault(0); AddComment(R"DOC( Stack Operator. Stack all of the Inputs(X) into one tensor along Attr(axis). The dims of all Inputs(X) must be the same. )DOC"); } }; class StackOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_EQ( ctx->HasInput(framework::GradVarName("Y")), true, platform::errors::InvalidArgument("Input(Y@Grad) not exist.")); int axis = ctx->Attrs().Get("axis"); auto dy_dim = ctx->GetInputDim(framework::GradVarName("Y")); int rank = dy_dim.size(); PADDLE_ENFORCE_GE( axis, -rank, platform::errors::InvalidArgument( "Attr(axis) must be inside [-rank, rank), where rank = %d, " "but received axis is:%d.", rank, axis)); PADDLE_ENFORCE_LT( axis, rank, platform::errors::InvalidArgument( "Attr(axis) must be inside [-rank, rank), where rank = %d, " "but received axis is:%d.", rank, axis)); if (axis < 0) axis += rank; PADDLE_ENFORCE_EQ( ctx->Outputs(framework::GradVarName("X")).size(), static_cast(dy_dim[axis]), platform::errors::InvalidArgument( "Number of Outputs(X@Grad) is equal to dy dim at axis, but" " received outputs size is:%d, dy dims is:%d.", ctx->Outputs(framework::GradVarName("X")).size(), static_cast(dy_dim[axis]))); auto vec = framework::vectorize(dy_dim); vec.erase(vec.begin() + axis); ctx->SetOutputsDim( framework::GradVarName("X"), std::vector(dy_dim[axis], framework::make_ddim(vec))); } }; template class StackGradOpMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; protected: std::unique_ptr Apply() const override { std::unique_ptr op(new T()); op->SetType("stack_grad"); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X", false)); op->SetAttrMap(this->Attrs()); return op; } }; } // namespace operators } // namespace paddle REGISTER_OPERATOR(stack, ops::StackOp, ops::StackOpMaker, ops::StackGradOpMaker, ops::StackGradOpMaker); REGISTER_OPERATOR(stack_grad, ops::StackOpGrad); REGISTER_OP_CPU_KERNEL(stack, ops::StackKernel, ops::StackKernel, ops::StackKernel, ops::StackKernel); REGISTER_OP_CPU_KERNEL(stack_grad, ops::StackGradKernel, ops::StackGradKernel, ops::StackGradKernel, ops::StackGradKernel);