// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { class Relu3Op : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("X"); ctx->SetOutputDim("Y", in_dims); } }; class Relu3OpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "The input tensor."); AddOutput("Y", "Output of relu_op"); AddComment(R"DOC( Relu3 Operator. )DOC"); } }; class Relu3GradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim(framework::GradVarName("Y")); ctx->SetOutputDim(framework::GradVarName("X"), in_dims); } }; template class Relu3GradMaker : public framework::SingleGradOpMaker { public: using framework::SingleGradOpMaker::SingleGradOpMaker; void Apply(GradOpPtr op) const override { op->SetType("relu3_grad"); op->SetInput("Y", this->Output("Y")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetAttrMap(this->Attrs()); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); } }; using Tensor = framework::Tensor; template class Relu3Kernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* in_t = ctx.Input("X"); auto* out_t = ctx.Output("Y"); auto x = in_t->data(); auto y = out_t->mutable_data(ctx.GetPlace()); for (int i = 0; i < in_t->numel(); ++i) { y[i] = std::max(static_cast(0.), x[i]); } } }; template class Relu3GradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { auto* dy_t = ctx.Input(framework::GradVarName("Y")); auto* y_t = ctx.Input("Y"); auto* dx_t = ctx.Output(framework::GradVarName("X")); auto dy = dy_t->data(); auto y = y_t->data(); auto dx = dx_t->mutable_data(ctx.GetPlace()); for (int i = 0; i < y_t->numel(); ++i) { dx[i] = dy[i] * (y[i] > static_cast(0) ? 1. : 0.); } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; using CPU = paddle::platform::CPUDeviceContext; REGISTER_OPERATOR(relu3, ops::Relu3Op, ops::Relu3OpMaker, ops::Relu3GradMaker, ops::Relu3GradMaker); REGISTER_OPERATOR(relu3_grad, ops::Relu3GradOp); REGISTER_OP_CPU_KERNEL(relu3, ops::Relu3Kernel, ops::Relu3Kernel); REGISTER_OP_CPU_KERNEL(relu3_grad, ops::Relu3GradKernel, ops::Relu3GradKernel);