From d3f219aa9911015bd8c4a1316b85620a07eb9f49 Mon Sep 17 00:00:00 2001
From: Yu Yang
Date: Mon, 21 Aug 2017 18:09:17 +0800
Subject: [PATCH] Change IdentityOp to ScaleOp

---
 paddle/framework/CMakeLists.txt | 2 +-
 paddle/framework/pybind.cc | 3 +-
 paddle/framework/tensor.h | 5 +-
 paddle/operators/CMakeLists.txt | 2 +-
 paddle/operators/identity_op.cc | 71 ------------
 paddle/operators/net_op.cc | 9 +-
 paddle/operators/scale_op.cc | 102 ++++++++++++++++++
 .../operators/{identity_op.cu => scale_op.cu} | 5 +-
 .../operators/{identity_op.h => scale_op.h} | 16 ++-
 .../paddle/v2/framework/tests/CMakeLists.txt | 2 +-
 .../v2/framework/tests/gradient_checker.py | 7 +-
 ...ty_op.py => test_scale_and_identity_op.py} | 19 ++++
 12 files changed, 158 insertions(+), 85 deletions(-)
 delete mode 100644 paddle/operators/identity_op.cc
 create mode 100644 paddle/operators/scale_op.cc
 rename paddle/operators/{identity_op.cu => scale_op.cu} (81%)
 rename paddle/operators/{identity_op.h => scale_op.h} (66%)
 rename python/paddle/v2/framework/tests/{test_identity_op.py => test_scale_and_identity_op.py} (51%)

diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt
index f249512f472..5df14ae78d4 100644
--- a/paddle/framework/CMakeLists.txt
+++ b/paddle/framework/CMakeLists.txt
@@ -56,5 +56,5 @@ cc_library(paddle_pybind SHARED
     uniform_random_op
     gaussian_random_op
     fill_zeros_like_op
-    identity_op)
+    scale_op)
 endif(WITH_PYTHON)
diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc
index ddb244623fd..3aaf0de1506 100644
--- a/paddle/framework/pybind.cc
+++ b/paddle/framework/pybind.cc
@@ -42,7 +42,8 @@ USE_OP(fill_zeros_like);
 USE_OP_ITSELF(recurrent_op);
 USE_OP(gaussian_random);
 USE_OP(uniform_random);
-USE_OP(identity);
+USE_OP(scale);
+USE_OP_ITSELF(identity);
 
 namespace paddle {
 namespace framework {
diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index b8c779f4e5f..643f8754917 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -105,7 +105,10 @@ class Tensor {
   template
   inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
 
-  platform::Place place() const { return holder_->place(); }
+  platform::Place place() const {
+    PADDLE_ENFORCE_NOT_NULL(holder_, "Tensor get place() must contain holder");
+    return holder_->place();
+  }
 
  private:
   template
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 20e562c7d34..0ba598823b9 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -68,4 +68,4 @@ op_library(recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
     DEPS framework_proto tensor op_registry operator net_op)
 
 op_library(uniform_random_op SRCS uniform_random_op.cc uniform_random_op.cu)
-op_library(identity_op SRCS identity_op.cc identity_op.cu DEPS net_op)
+op_library(scale_op SRCS scale_op.cc scale_op.cu DEPS net_op)
diff --git a/paddle/operators/identity_op.cc b/paddle/operators/identity_op.cc
deleted file mode 100644
index cac44020bcf..00000000000
--- a/paddle/operators/identity_op.cc
+++ /dev/null
@@ -1,71 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-   http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License. */
-
-#include "paddle/operators/identity_op.h"
-#include "paddle/operators/net_op.h"
-
-namespace paddle {
-namespace operators {
-
-class IdentityOp : public framework::OperatorWithKernel {
- public:
-  IdentityOp(const std::string &type, const VarNameMap &inputs,
-             const VarNameMap &outputs, const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
-
- protected:
-  void InferShape(const framework::InferShapeContext &ctx) const override {
-    auto *in = ctx.Input("X");
-    auto *out = ctx.Output("Out");
-    out->Resize(in->dims());
-  }
-};
-
-class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  IdentityOpMaker(framework::OpProto *proto,
-                  framework::OpAttrChecker *op_checker)
-      : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput("X", "The input tensor of identity operator.").NotInGradient();
-    AddOutput("Out", "The output tensor of identity operator.").NotInGradient();
-    AddComment(R"DOC(Identity operator
-
-The equation is: Out = X
-)DOC");
-  }
-};
-
-// Identity Op's gradient is identity op, too.
-// Grad(Out=identity_op(X)) => Grad(X) = identity_op(Grad(Out))
-class IdentityGradOp : public NetOp {
- public:
-  IdentityGradOp(const std::string &type, const VarNameMap &inputs,
-                 const VarNameMap &outputs,
-                 const framework::AttributeMap &attrs)
-      : NetOp(type, inputs, outputs, attrs) {
-    AddOp(framework::OpRegistry::CreateOp(
-        "identity", {{"X", {Input(framework::GradVarName("Out"))}}},
-        {{"Out", {Output(framework::GradVarName("X"))}}}, {}));
-    CompleteAddOp(false);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP(identity, ops::IdentityOp, ops::IdentityOpMaker, identity_grad,
-            ops::IdentityGradOp);
-REGISTER_OP_CPU_KERNEL(identity, ops::IdentityKernel);
diff --git a/paddle/operators/net_op.cc b/paddle/operators/net_op.cc
index a7d71051109..7e3779ed2e9 100644
--- a/paddle/operators/net_op.cc
+++ b/paddle/operators/net_op.cc
@@ -68,10 +68,15 @@ std::string NetOp::DebugString() const {
 bool NetOp::IsNetOp() const { return true; }
 
 std::vector NetOp::OutputVars(bool has_intermediate) const {
+  std::vector all;
+  for (auto& pair : this->outputs_) {
+    for (auto& var_name : pair.second) {
+      all.push_back(var_name);
+    }
+  }
   if (has_intermediate) {
-    return this->outputs_.at(kAll);
+    return all;
   }
-  auto& all = this->outputs_.at(kAll);
   std::vector ret_val;
   for (auto& each : all) {
     if (!Contains(intermediate_outputs_, each)) {
diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc
new file mode 100644
index 00000000000..3b18ff078e7
--- /dev/null
+++ b/paddle/operators/scale_op.cc
@@ -0,0 +1,102 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/scale_op.h"
+#include "paddle/operators/net_op.h"
+
+namespace paddle {
+namespace operators {
+
+class ScaleOp : public framework::OperatorWithKernel {
+ public:
+  ScaleOp(const std::string &type, const VarNameMap &inputs,
+          const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    auto *in = ctx.Input("X");
+    auto *out = ctx.Output("Out");
+    out->Resize(in->dims());
+  }
+};
+
+template
+class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  ScaleOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input tensor of scale operator.").NotInGradient();
+    AddOutput("Out", "The output tensor of scale operator.").NotInGradient();
+    AddComment(R"DOC(Scale operator
+
+The equation is: Out = scale*X
+)DOC");
+    AddAttr("scale", "scale of scale operator.").SetDefault(1.0);
+  }
+};
+
+// Scale Op's gradient is scale op, too.
+// Grad(Out=scale(X)) => Grad(X) = scale(Grad(Out))
+template
+class ScaleGradOp : public NetOp {
+ public:
+  ScaleGradOp(const std::string &type, const VarNameMap &inputs,
+              const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    AddOp(framework::OpRegistry::CreateOp(
+        "scale", {{"X", {Input(framework::GradVarName("Out"))}}},
+        {{"Out", {Output(framework::GradVarName("X"))}}},
+        {{"scale", GetAttr("scale")}}));
+    CompleteAddOp(false);
+  }
+};
+
+// identity is an alias of the scale op. This is also an example of creating an
+// alias operator.
+template
+class IdentityOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  IdentityOpMaker(framework::OpProto *proto,
+                  framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "input tensor of identity op");
+    AddOutput("Out", "output tensor of identity op");
+    AddComment("identity operator. Just an alias of the scale op with scale = 1.0");
+  }
+};
+
+template
+class IdentityOp : public NetOp {
+ public:
+  IdentityOp(const std::string &type, const VarNameMap &inputs,
+             const VarNameMap &outputs, const framework::AttributeMap &attrs)
+      : NetOp(type, inputs, outputs, attrs) {
+    AddOp(framework::OpRegistry::CreateOp(
+        "scale", {{"X", {Input("X")}}}, {{"Out", {Output("Out")}}},
+        {{"scale", static_cast(1)}}));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP(scale, ops::ScaleOp, ops::ScaleOpMaker, scale_grad,
+            ops::ScaleGradOp);
+REGISTER_OP_CPU_KERNEL(scale,
+                       ops::ScaleKernel);
+REGISTER_OP_WITHOUT_GRADIENT(identity, ops::IdentityOp,
+                             ops::IdentityOpMaker);
diff --git a/paddle/operators/identity_op.cu b/paddle/operators/scale_op.cu
similarity index 81%
rename from paddle/operators/identity_op.cu
rename to paddle/operators/scale_op.cu
index 3053104bbec..63efbe0da8a 100644
--- a/paddle/operators/identity_op.cu
+++ b/paddle/operators/scale_op.cu
@@ -12,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/operators/identity_op.h"
+#include "paddle/operators/scale_op.h"
 
-REGISTER_OP_GPU_KERNEL(identity, paddle::operators::IdentityKernel);
+REGISTER_OP_GPU_KERNEL(
+    scale, paddle::operators::ScaleKernel);
diff --git a/paddle/operators/identity_op.h b/paddle/operators/scale_op.h
similarity index 66%
rename from paddle/operators/identity_op.h
rename to paddle/operators/scale_op.h
index 14a832257b1..aea64f1b042 100644
--- a/paddle/operators/identity_op.h
+++ b/paddle/operators/scale_op.h
@@ -14,17 +14,25 @@
 
 #pragma once
 
+#include "paddle/framework/eigen.h"
 #include "paddle/framework/op_registry.h"
-#include "paddle/memory/memcpy.h"
+
 namespace paddle {
 namespace operators {
 
-template
-class IdentityKernel : public framework::OpKernel {
+template
+class ScaleKernel : public framework::OpKernel {
  public:
   virtual void Compute(const framework::ExecutionContext& context) const {
     auto* tensor = context.Output("Out");
     auto* in = context.Input("X");
-    tensor->CopyFrom(*in, in->place());
+    tensor->mutable_data(in->place());
+
+    auto scale = static_cast(context.op_.GetAttr("scale"));
+
+    auto eigen_out = framework::EigenVector::Flatten(*tensor);
+    auto eigen_in = framework::EigenVector::Flatten(*in);
+    auto& dev = context.GetEigenDevice();
+    eigen_out.device(dev) = scale * eigen_in;
   }
 };
diff --git a/python/paddle/v2/framework/tests/CMakeLists.txt b/python/paddle/v2/framework/tests/CMakeLists.txt
index cf7baa55569..0e8811bfe76 100644
--- a/python/paddle/v2/framework/tests/CMakeLists.txt
+++ b/python/paddle/v2/framework/tests/CMakeLists.txt
@@ -27,4 +27,4 @@ py_test(test_uniform_random_op SRCS test_uniform_random_op.py)
 py_test(test_recurrent_op SRCS test_recurrent_op.py)
 py_test(test_sgd_op SRCS test_sgd_op.py)
 py_test(test_gradient_checker SRCS test_gradient_checker.py)
-py_test(test_identity_op SRCS test_identity_op.py)
+py_test(test_scale_and_identity_op SRCS test_scale_and_identity_op.py)
diff --git a/python/paddle/v2/framework/tests/gradient_checker.py b/python/paddle/v2/framework/tests/gradient_checker.py
index 8b8e2f444be..c22c6f8831b 100644
--- a/python/paddle/v2/framework/tests/gradient_checker.py
+++ b/python/paddle/v2/framework/tests/gradient_checker.py
@@ -160,8 +160,13 @@ class GradientChecker(unittest.TestCase):
             grad_tensor.set(data, place)
 
         # run backward op
-        for name in backward_op.outputs():
+        backward_outs = backward_op.outputs()
+        backward_names = [
+            item for key in backward_outs for item in backward_outs[key]
+        ]
+        for name in backward_names:
             scope.new_var(name)
+        backward_op.infer_shape(scope)
         backward_op.run(scope, ctx)
diff --git a/python/paddle/v2/framework/tests/test_identity_op.py b/python/paddle/v2/framework/tests/test_scale_and_identity_op.py
similarity index 51%
rename from python/paddle/v2/framework/tests/test_identity_op.py
rename to python/paddle/v2/framework/tests/test_scale_and_identity_op.py
index 181d9c0c216..69b301c376e 100644
--- a/python/paddle/v2/framework/tests/test_identity_op.py
+++ b/python/paddle/v2/framework/tests/test_scale_and_identity_op.py
@@ -2,6 +2,7 @@ import unittest
 from op_test_util import OpTestMeta
 from gradient_checker import GradientChecker, create_op
 import numpy as np
+from paddle.v2.framework.op import Operator
 
 
 class IdentityTest(unittest.TestCase):
@@ -20,5 +21,23 @@ class IdentityGradOpTest(GradientChecker):
         self.check_grad(op, inputs, set("X"), "Out")
 
 
+class ScaleTest(unittest.TestCase):
+    __metaclass__ = OpTestMeta
+
+    def setUp(self):
+        self.type = "scale"
+        self.inputs = {'X': np.random.random((32, 784)).astype("float32")}
+        self.attrs = {'scale': -2.3}
+        self.outputs = {'Out': self.inputs['X'] * self.attrs['scale']}
+
+
+class ScaleGradTest(GradientChecker):
+    def test_normal(self):
+        op = Operator("scale", X="X", Out="Out", scale=3.2)
+        self.check_grad(op,
+                        {"X": np.random.random((10, 10)).astype("float32")},
+                        set("X"), "Out")
+
+
 if __name__ == '__main__':
     unittest.main()
--
GitLab
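For readers following the kernel and gradient logic in the patch, the standalone NumPy sketch below (illustrative only, not part of the patch; every function and variable name here is made up) mirrors what the new ScaleKernel and ScaleGradOp compute: the forward pass is Out = scale * X, and the gradient of a scale op is another scale op applied to the upstream gradient, Grad(X) = scale * Grad(Out). It then checks the analytic gradient against a finite-difference estimate, in the spirit of GradientChecker.

import numpy as np


def scale_forward(x, scale):
    # Mirrors ScaleKernel: Out = scale * X, element-wise.
    return scale * x


def scale_backward(dout, scale):
    # Mirrors ScaleGradOp: the gradient of a scale op is another scale op
    # applied to the upstream gradient, Grad(X) = scale * Grad(Out).
    return scale * dout


if __name__ == "__main__":
    rng = np.random.RandomState(0)
    x = rng.rand(10, 10)
    scale = -2.3
    dout = np.ones_like(x)  # upstream gradient of a sum() loss

    analytic = scale_backward(dout, scale)

    # Finite-difference check on one element, similar in spirit to the
    # numeric check that GradientChecker performs on the registered op.
    eps = 1e-4
    x_pos, x_neg = x.copy(), x.copy()
    x_pos[0, 0] += eps
    x_neg[0, 0] -= eps
    numeric = (scale_forward(x_pos, scale) -
               scale_forward(x_neg, scale)).sum() / (2 * eps)

    assert np.allclose(numeric, analytic[0, 0], rtol=1e-4)
    print("scale gradient check passed:", numeric, analytic[0, 0])

Defining scale_grad as another scale op keeps the backward pass a composition of existing forward ops, which is exactly what the NetOp-based ScaleGradOp in the patch does.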