From 490ca5f1aeb5bfebd1a9ba4ac3e27518c979ef44 Mon Sep 17 00:00:00 2001
From: zchen0211
Date: Thu, 14 Sep 2017 22:31:12 -0700
Subject: [PATCH] prelu_op

---
 paddle/operators/prelu_op.cc                  | 16 +++++++-------
 paddle/operators/prelu_op.cu                  | 21 -------------------
 paddle/operators/prelu_op.h                   | 17 +++++++--------
 .../v2/framework/tests/test_prelu_op.py       |  5 +++--
 4 files changed, 20 insertions(+), 39 deletions(-)
 delete mode 100644 paddle/operators/prelu_op.cu

diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc
index 831958e3a4..030f320ab9 100644
--- a/paddle/operators/prelu_op.cc
+++ b/paddle/operators/prelu_op.cc
@@ -33,20 +33,20 @@ class PreluOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
+// template <typename AttrType>
 class PreluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   PreluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of prelu operator.").NotInGradient();
-    AddOutput("Out", "The output tensor of prelu operator.").NotInGradient();
+    AddInput("X", "The input tensor of prelu operator.");
+    AddOutput("Out", "The output tensor of prelu operator.");
     AddComment(R"DOC(Prelu operator
 
 The equation is:
 f(x) = alpha * x , for x < 0
 f(x) = x , for x >= 0
 )DOC");
-    AddAttr<AttrType>("alpha", "The scaling factor alpha of prelu.")
+    AddAttr<float>("alpha", "The scaling factor alpha of prelu.")
         .SetDefault(0.0);
   }
 };
@@ -58,8 +58,10 @@ class PreluGradOp : public framework::OperatorWithKernel {
 
  protected:
   void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto X_grad = ctx.Output(framework::GradVarName("X"));
-    auto X = ctx.Input("X");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    auto *X_grad =
+        ctx.Output(framework::GradVarName("X"));
+    auto *X = ctx.Input("X");
 
     X_grad->Resize(X->dims());
   }
@@ -70,7 +72,7 @@
 
 namespace ops = paddle::operators;
 
-REGISTER_OP(prelu, ops::PreluOp, ops::PreluOpMaker<float>, prelu_grad,
+REGISTER_OP(prelu, ops::PreluOp, ops::PreluOpMaker, prelu_grad,
             ops::PreluGradOp);
 REGISTER_OP_CPU_KERNEL(prelu,
                        ops::PreluKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/prelu_op.cu b/paddle/operators/prelu_op.cu
deleted file mode 100644
index 314dcba375..0000000000
--- a/paddle/operators/prelu_op.cu
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/operators/prelu_op.h"
-
-REGISTER_OP_GPU_KERNEL(
-    prelu, paddle::operators::PreluKernel<paddle::platform::GPUPlace, float>);
-REGISTER_OP_GPU_KERNEL(
-    prelu_grad,
-    paddle::operators::PreluGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/prelu_op.h b/paddle/operators/prelu_op.h
index 0bb6f61e30..a1e719e314 100644
--- a/paddle/operators/prelu_op.h
+++ b/paddle/operators/prelu_op.h
@@ -24,7 +24,7 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
 
-template <typename Place, typename T, typename AttrType>
+template <typename Place, typename T>
 class PreluKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
@@ -33,30 +33,29 @@ class PreluKernel : public framework::OpKernel {
 
     Out->mutable_data<T>(context.GetPlace());
 
-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+    auto alpha = static_cast<T>(context.Attr<float>("alpha"));
 
     auto X_vec = EigenVector<T>::Flatten(*X);
     auto Out_vec = EigenVector<T>::Flatten(*Out);
 
-    auto place = context.GetEigenDevice<Place>();
-
-    Out_vec.device(place) = X_vec.cwiseMax(0.f) + X_vec.cwiseMin(0.f) * alpha;
+    // auto place = context.GetEigenDevice<Place>();
+    // Out_vec.device(place)
+    Out_vec = X_vec.cwiseMax(0.f) + X_vec.cwiseMin(0.f) * alpha;
   }
 };
 
-template <typename Place, typename T, typename AttrType>
+template <typename Place, typename T>
 class PreluGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
     auto* dO = context.Input<Tensor>(framework::GradVarName("Out"));
 
-    auto* Out = context.Output<Tensor>("Out");
+    auto* Out = context.Input<Tensor>("Out");
 
-    auto alpha = static_cast<T>(context.Attr<AttrType>("alpha"));
+    auto alpha = static_cast<T>(context.Attr<float>("alpha"));
 
     dX->mutable_data<T>(context.GetPlace());
-
     for (int i = 0; i < dX->numel(); ++i) {
       if (Out->data<T>()[i] > 0) {
         dX->data<T>()[i] = dO->data<T>()[i];
diff --git a/python/paddle/v2/framework/tests/test_prelu_op.py b/python/paddle/v2/framework/tests/test_prelu_op.py
index c207940d1f..39b6f673fd 100644
--- a/python/paddle/v2/framework/tests/test_prelu_op.py
+++ b/python/paddle/v2/framework/tests/test_prelu_op.py
@@ -6,11 +6,12 @@ from op_test import OpTest
 class PreluTest(OpTest):
     def setUp(self):
         self.op_type = "prelu"
-        self.inputs = {'X': np.random.random((10, 10)).astype("float32")}
+        self.inputs = {'X': np.random.normal(size=(3, 5)).astype("float32")}
         self.attrs = {'alpha': 0.1}
         out_np = np.maximum(self.inputs['X'], 0.)
         out_np = out_np + np.minimum(self.inputs['X'], 0.) * self.attrs['alpha']
-        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
+        assert out_np is not self.inputs['X']
+        self.outputs = {'Out': out_np}
 
     def test_check_output(self):
         self.check_output()
-- 
GitLab
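
For reference, a minimal NumPy sketch of the rules this patch implements, assuming the scalar `alpha` attribute used above; the helper names are illustrative only and do not appear in the patch. The forward rule matches the test's `out_np`, and the backward rule mirrors the element-wise loop in `PreluGradKernel` (dX[i] = dOut[i] if Out[i] > 0, else dOut[i] * alpha):

    import numpy as np

    def prelu_forward(x, alpha):
        # f(x) = x for x >= 0, alpha * x for x < 0
        return np.maximum(x, 0.) + np.minimum(x, 0.) * alpha

    def prelu_backward(out, d_out, alpha):
        # dX[i] = dOut[i] where Out[i] > 0, else dOut[i] * alpha
        return np.where(out > 0, d_out, d_out * alpha)

    x = np.random.normal(size=(3, 5)).astype("float32")
    out = prelu_forward(x, 0.1)
    d_x = prelu_backward(out, np.ones_like(out), 0.1)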