From 950cc60d2b2e6ab9c05f82df3f2d3f3179541209 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 22 Aug 2017 15:29:38 +0800 Subject: [PATCH] Add minus --- paddle/framework/CMakeLists.txt | 3 +- paddle/framework/pybind.cc | 1 + paddle/operators/CMakeLists.txt | 1 + paddle/operators/minus_op.cc | 84 +++++++++++++++++++++++++++++++++ paddle/operators/minus_op.cu | 18 +++++++ paddle/operators/minus_op.h | 39 +++++++++++++++ 6 files changed, 145 insertions(+), 1 deletion(-) create mode 100644 paddle/operators/minus_op.cc create mode 100644 paddle/operators/minus_op.cu create mode 100644 paddle/operators/minus_op.h diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 5df14ae78d..c9cf45e9d7 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -56,5 +56,6 @@ cc_library(paddle_pybind SHARED uniform_random_op gaussian_random_op fill_zeros_like_op - scale_op) + scale_op + minus_op) endif(WITH_PYTHON) diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index 3aaf0de150..b4b7921d33 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -44,6 +44,7 @@ USE_OP(gaussian_random); USE_OP(uniform_random); USE_OP(scale); USE_OP_ITSELF(identity); +USE_OP(minus); namespace paddle { namespace framework { diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 0ba598823b..61f7a4070f 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -69,3 +69,4 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu) op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op) +op_library(minus_op SRCS minus_op.cc minus_op.cu DEPS scale_op) diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc new file mode 100644 index 0000000000..c660ab5d32 --- /dev/null +++ b/paddle/operators/minus_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/minus_op.h" +#include "paddle/operators/net_op.h" + +namespace paddle { +namespace operators { + +class MinusOp : public framework::OperatorWithKernel { + public: + MinusOp(const std::string &type, const VarNameMap &inputs, + const VarNameMap &outputs, const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + auto *left_tensor = ctx.Input("X"); + auto *right_tensor = ctx.Input("Y"); + + PADDLE_ENFORCE_EQ( + framework::product(left_tensor->dims()), + framework::product(right_tensor->dims()), + "Minus operator must take two tensor with same num of elements"); + ctx.Output("Out")->Resize(left_tensor->dims()); + } +}; + +class MinusOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MinusOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The left tensor of minus operator.").NotInGradient(); + AddInput("Y", "The right tensor of minus operator.").NotInGradient(); + AddOutput("Out", "The output tensor of minus operator.").NotInGradient(); + + AddComment(R"DOC(Minus Operator + +Equation: Out = X - Y +)DOC"); + } +}; +template +class MinusGradOp : public NetOp { + public: + MinusGradOp(const std::string &type, const VarNameMap &inputs, + const VarNameMap &outputs, const framework::AttributeMap &attrs) + : NetOp(type, inputs, outputs, attrs) { + auto out_grad = Input(framework::GradVarName("Out")); + auto x_grad = Output(framework::GradVarName("X")); + auto y_grad = Output(framework::GradVarName("Y")); + + // x_grad = out_grad + AddOp(framework::OpRegistry::CreateOp("identity", {{"X", {out_grad}}}, + {{"Out", {x_grad}}}, {})); + + framework::AttributeMap scale_attr; + scale_attr["scale"] = static_cast(-1); + AddOp(framework::OpRegistry::CreateOp("scale", {{"X", {out_grad}}}, + {{"Out", {y_grad}}}, scale_attr)); + } +}; + +} // namespace operators +} // namespace paddle + +USE_OP(scale); +USE_OP_ITSELF(identity); +namespace ops = paddle::operators; +REGISTER_OP(minus, ops::MinusOp, ops::MinusOpMaker, minus_grad, + ops::MinusGradOp); +REGISTER_OP_CPU_KERNEL(minus, + ops::MinusKernel); diff --git a/paddle/operators/minus_op.cu b/paddle/operators/minus_op.cu new file mode 100644 index 0000000000..a8375cc630 --- /dev/null +++ b/paddle/operators/minus_op.cu @@ -0,0 +1,18 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/minus_op.h" + +REGISTER_OP_GPU_KERNEL( + minus, paddle::operators::MinusKernel); diff --git a/paddle/operators/minus_op.h b/paddle/operators/minus_op.h new file mode 100644 index 0000000000..6310a4fd51 --- /dev/null +++ b/paddle/operators/minus_op.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class MinusKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* left_tensor = context.Input("X"); + auto* right_tensor = context.Input("Y"); + auto* out_tensor = context.Output("Out"); + + out_tensor->mutable_data(context.GetPlace()); + auto& dev = context.GetEigenDevice(); + framework::EigenVector::Flatten(*out_tensor).device(dev) = + framework::EigenVector::Flatten(*left_tensor) - + framework::EigenVector::Flatten(*right_tensor); + } +}; + +} // namespace operators +} // namespace paddle -- GitLab