// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <memory>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

// Forward operator: the output Y has the same shape as the input X.
class Relu2Op : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto in_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Y", in_dims);
  }
};

// Declares the inputs, outputs and documentation of the forward op.
class Relu2OpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "The input tensor.");
    AddOutput("Y", "Output of relu_op");
    AddComment(R"DOC(
Relu2 Operator.
)DOC");
  }
};

// Backward operator: the gradient of X has the same shape as the gradient of Y.
class Relu2GradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    auto in_dims = ctx->GetInputDim(framework::GradVarName("Y"));
    ctx->SetOutputDim(framework::GradVarName("X"), in_dims);
  }
};

// Builds the OpDesc of the backward op from the forward op.
class Relu2GradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto* op = new framework::OpDesc();
    op->SetType("relu2_grad");
    op->SetInput("Y", Output("Y"));
    op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
    op->SetAttrMap(Attrs());
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    return std::unique_ptr<framework::OpDesc>(op);
  }
};

using Tensor = framework::Tensor;

// CPU kernel for the forward pass: y = max(0, x), element-wise.
template <typename DeviceContext, typename T>
class Relu2Kernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_t = ctx.Input<Tensor>("X");
    auto* out_t = ctx.Output<Tensor>("Y");
    auto x = in_t->data<T>();
    auto y = out_t->mutable_data<T>(ctx.GetPlace());
    for (int i = 0; i < in_t->numel(); ++i) {
      y[i] = std::max(static_cast<T>(0.), x[i]);
    }
  }
};

// CPU kernel for the backward pass: dx = dy where y > 0, else 0.
template <typename DeviceContext, typename T>
class Relu2GradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* dy_t = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* y_t = ctx.Input<Tensor>("Y");
    auto* dx_t = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto dy = dy_t->data<T>();
    auto y = y_t->data<T>();
    auto dx = dx_t->mutable_data<T>(ctx.GetPlace());
    for (int i = 0; i < y_t->numel(); ++i) {
      dx[i] = dy[i] * (y[i] > static_cast<T>(0) ? 1. : 0.);
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;

// Register the forward and backward operators, and their CPU kernels
// for float and double.
REGISTER_OPERATOR(relu2,
                  ops::Relu2Op,
                  ops::Relu2OpMaker,
                  ops::Relu2GradMaker);
REGISTER_OPERATOR(relu2_grad, ops::Relu2GradOp);
REGISTER_OP_CPU_KERNEL(relu2,
                       ops::Relu2Kernel<CPU, float>,
                       ops::Relu2Kernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(relu2_grad,
                       ops::Relu2GradKernel<CPU, float>,
                       ops::Relu2GradKernel<CPU, double>);