diff --git a/cmake/util.cmake b/cmake/util.cmake
index e814cad36f2a8ce95a2dc9fabc35cb39506d4cd7..ac911052eb970c5a3e485e3178dd788b1517ca30 100644
--- a/cmake/util.cmake
+++ b/cmake/util.cmake
@@ -25,7 +25,7 @@ function(target_circle_link_libraries TARGET_NAME)
       endif()
     endforeach()
     if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")
-      if(IOS AND NOT IOS_ENABLE_BITCODE)
+      if(NOT IOS_ENABLE_BITCODE)
        list(APPEND LIBS "-undefined dynamic_lookup")
       endif()
     endif()
diff --git a/paddle/framework/ddim.cc b/paddle/framework/ddim.cc
index fc3d508553c0e966978b28d58127bdbff10d45f1..a3357867530c110df16a5f3ec8c799735206cc71 100644
--- a/paddle/framework/ddim.cc
+++ b/paddle/framework/ddim.cc
@@ -292,5 +292,13 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
 
 DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
 
+DDim stride(const DDim& ddim) {
+  std::vector<int64_t> strides(ddim.size());
+  strides[ddim.size() - 1] = 1;
+  for (int i = ddim.size() - 2; i >= 0; --i) {
+    strides[i] = strides[i + 1] * ddim[i + 1];
+  }
+  return framework::make_ddim(strides);
+}
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/framework/ddim.h b/paddle/framework/ddim.h
index ca29e7e8c7776de6adf3e3b0e8f11f0d4d8487c3..4a871bb0a91ed4050847509cc3f24218bcd57142 100644
--- a/paddle/framework/ddim.h
+++ b/paddle/framework/ddim.h
@@ -121,6 +121,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
 
 DDim flatten_to_1d(const DDim& src);
 
+DDim stride(const DDim& ddim);
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt
index 90c7171419888612a48d929ed85039b16384a573..f8b0bce6815ff17a60ef64b0eec34a7cc9d16e72 100644
--- a/paddle/operators/CMakeLists.txt
+++ b/paddle/operators/CMakeLists.txt
@@ -55,6 +55,13 @@ function(op_library TARGET)
     set(pybind_flag 1)
   endif()
 
+  # activation_op contains several operators
+  if ("${TARGET}" STREQUAL "activation_op")
+    set(pybind_flag 1)
+    # It is enough to add just one of its operators to the pybind file
+    file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
+  endif()
+
   # pybind USE_NO_KERNEL_OP
   file(READ ${TARGET}.cc TARGET_CONTENT)
   string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..cc55767cef9552475321bcb8c06d74a8d91dc99b
--- /dev/null
+++ b/paddle/operators/activation_op.cc
@@ -0,0 +1,306 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+       http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License.
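The framework::stride helper added to ddim.cc above returns the row-major strides of a DDim, with the last dimension getting stride 1. A minimal usage sketch, not part of this patch (the wrapper function name is illustrative):

#include <cassert>
#include "paddle/framework/ddim.h"

// Illustrative only: shows the values framework::stride() produces.
void StrideExample() {
  namespace fw = paddle::framework;
  fw::DDim dims = fw::make_ddim({2, 3, 4});
  fw::DDim strides = fw::stride(dims);
  // Row-major strides: {3 * 4, 4, 1}.
  assert(strides[0] == 12);
  assert(strides[1] == 4);
  assert(strides[2] == 1);
}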
*/ + +#include "paddle/operators/activation_op.h" + +namespace paddle { +namespace operators { + +class ActivationOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + ctx.Output("Y")->Resize( + ctx.Input("X")->dims()); + } +}; + +class ActivationOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + ctx.Output(framework::GradVarName("X")) + ->Resize(ctx.Input("Y")->dims()); + } +}; + +class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SigmoidOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Sigmoid operator"); + AddOutput("Y", "Output of Sigmoid operator"); + AddComment("Sigmoid activation operator, sigmoid = 1 / (1 + exp(-x))"); + } +}; + +class ExpOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Exp operator"); + AddOutput("Y", "Output of Exp operator"); + AddComment("Exp activation operator, exp(x) = e^x"); + } +}; + +class ReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Relu operator"); + AddOutput("Y", "Output of Relu operator"); + AddComment("Relu activation operator, relu(x) = max(x, 0)"); + } +}; + +class TanhOpMaker : public framework::OpProtoAndCheckerMaker { + public: + TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Tanh operator"); + AddOutput("Y", "Output of Tanh operator"); + AddComment( + "Tanh activation operator, tanh = (exp(x) - exp(-x)) / (exp(x) + " + "exp(-x))"); + } +}; + +class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Sqrt operator"); + AddOutput("Y", "Output of Sqrt operator"); + AddComment("Sqrt activation operator, sqrt(x) = x^(1/2)"); + } +}; + +class AbsOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AbsOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Abs operator"); + AddOutput("Y", "Output of Abs operator"); + AddComment("Abs activation operator, abs(x) = |x|"); + } +}; + +class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker { + public: + ReciprocalOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Reciprocal operator"); + AddOutput("Y", "Output of Reciprocal operator"); + AddComment("Reciprocal activation operator, reciprocal(x) = 1 / x"); + } +}; + +class LogOpMaker : public framework::OpProtoAndCheckerMaker { + public: + LogOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Log operator"); + AddOutput("Y", 
"Output of Log operator"); + AddComment("Log activation operator, log(x) = natural logarithm of x"); + } +}; + +class SquareOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SquareOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Square operator"); + AddOutput("Y", "Output of Square operator"); + AddComment("Square activation operator, square(x) = x^2"); + } +}; + +template +class BReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of BRelu operator"); + AddOutput("Y", "Output of BRelu operator"); + AddComment("BRelu activation operator, brelu = max(min(x, t_min), t_max)"); + AddAttr("t_min", "The min marginal value of BRelu") + .SetDefault(static_cast(0)); + AddAttr("t_max", "The max marginal value of BRelu") + .SetDefault(static_cast(24)); + } +}; + +template +class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { + public: + SoftReluOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of SoftRelu operator"); + AddOutput("Y", "Output of SoftRelu operator"); + AddComment( + "SoftRelu activation operator, soft_relu = log(1 + exp(max(min(x, " + "threshold), threshold)))"); + AddAttr("threshold", "The threshold value of SoftRelu") + .SetDefault(static_cast(40)); + } +}; + +template +class PowOpMaker : public framework::OpProtoAndCheckerMaker { + public: + PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of Pow operator"); + AddOutput("Y", "Output of Pow operator"); + AddComment("Pow activation operator, pow(x, factor) = x^factor"); + AddAttr("factor", "The exponential factor of Pow") + .SetDefault(static_cast(1)); + } +}; + +template +class STanhOpMaker : public framework::OpProtoAndCheckerMaker { + public: + STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "Input of STanh operator"); + AddOutput("Y", "Output of STanh operator"); + AddComment("STanh activation operator, stanh = b * tanh(a * x)"); + AddAttr("scale_a", "The scale parameter of a for the input") + .SetDefault(static_cast(2 / 3)); + AddAttr("scale_b", "The scale parameter of b for the input") + .SetDefault(static_cast(1.7159)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(sigmoid, + ops::ActivationKernel>); +REGISTER_OP_CPU_KERNEL( + sigmoid_grad, ops::ActivationGradKernel>); + +REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + exp, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL(exp_grad, + ops::ActivationGradKernel); + +REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(relu, + ops::ActivationKernel>); +REGISTER_OP_CPU_KERNEL( + relu_grad, ops::ActivationGradKernel>); + +REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + tanh, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + tanh_grad, 
ops::ActivationGradKernel>); + +REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + sqrt, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + sqrt_grad, ops::ActivationGradKernel>); + +REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + abs, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL(abs_grad, + ops::ActivationGradKernel); + +REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker, + reciprocal_grad, ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(reciprocal, + ops::ActivationKernel>); +REGISTER_OP_CPU_KERNEL( + reciprocal_grad, + ops::ActivationGradKernel>); + +REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL( + log, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + log_grad, ops::ActivationGradKernel>); + +REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(square, + ops::ActivationKernel); +REGISTER_OP_CPU_KERNEL( + square_grad, ops::ActivationGradKernel>); + +REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(brelu, + ops::BReluKernel); +REGISTER_OP_CPU_KERNEL(brelu_grad, + ops::BReluGradKernel); + +REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, + soft_relu_grad, ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(soft_relu, + ops::SoftReluKernel); +REGISTER_OP_CPU_KERNEL( + soft_relu_grad, ops::SoftReluGradKernel); + +REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel); +REGISTER_OP_CPU_KERNEL(pow_grad, + ops::PowGradKernel); + +REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad, + ops::ActivationOpGrad); +REGISTER_OP_CPU_KERNEL(stanh, + ops::STanhKernel); +REGISTER_OP_CPU_KERNEL(stanh_grad, + ops::STanhGradKernel); diff --git a/paddle/operators/activation_op.cu b/paddle/operators/activation_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..feed1302b292a546f88fa35457c86aa2cfdaa307 --- /dev/null +++ b/paddle/operators/activation_op.cu @@ -0,0 +1,100 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#define EIGEN_USE_GPU +#include "paddle/operators/activation_op.h" + +namespace ops = paddle::operators; + +REGISTER_OP_GPU_KERNEL(sigmoid, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + sigmoid_grad, ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL( + exp, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL(exp_grad, + ops::ActivationGradKernel); +REGISTER_OP_GPU_KERNEL(relu, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + relu_grad, ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL( + tanh, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL( + tanh_grad, ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL( + sqrt, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL( + sqrt_grad, ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL( + abs, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL(abs_grad, + ops::ActivationGradKernel); + +REGISTER_OP_GPU_KERNEL(reciprocal, + ops::ActivationKernel>); +REGISTER_OP_GPU_KERNEL( + reciprocal_grad, + ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL( + log, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL( + log_grad, ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL(square, + ops::ActivationKernel); +REGISTER_OP_GPU_KERNEL( + square_grad, ops::ActivationGradKernel>); + +REGISTER_OP_GPU_KERNEL(brelu, + ops::BReluKernel); +REGISTER_OP_GPU_KERNEL(brelu_grad, + ops::BReluGradKernel); + +REGISTER_OP_GPU_KERNEL(soft_relu, + ops::SoftReluKernel); +REGISTER_OP_GPU_KERNEL( + soft_relu_grad, ops::SoftReluGradKernel); + +REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel); +REGISTER_OP_GPU_KERNEL(pow_grad, + ops::PowGradKernel); + +REGISTER_OP_GPU_KERNEL(stanh, + ops::STanhKernel); +REGISTER_OP_GPU_KERNEL(stanh_grad, + ops::STanhGradKernel); diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h new file mode 100644 index 0000000000000000000000000000000000000000..15f8afb4ba45cc989fe7576b82b8bf853b1df7de --- /dev/null +++ b/paddle/operators/activation_op.h @@ -0,0 +1,353 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
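activation_op.h, whose body follows, expresses each activation as a Functor / GradFunctor pair that the generic ActivationKernel and ActivationGradKernel templates apply elementwise through Eigen. A rough sketch of how one more activation could follow the same pattern (a hypothetical softsign, not part of this diff; the registration template arguments are a guess, since they were stripped from this copy of the patch):

// Hypothetical softsign(x) = x / (1 + |x|), written in the style of the
// functors defined below. Not part of this diff.
template <typename T>
struct SoftsignFunctor {
  template <typename Device, typename X, typename Y>
  void operator()(Device d, X x, Y y) {
    y.device(d) = x * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()));
  }
};

template <typename T>
struct SoftsignGradFunctor {
  template <typename Device, typename X, typename Y, typename dY, typename dX>
  void operator()(Device d, X x, Y y, dY dy, dX dx) {
    // d softsign / dx = 1 / (1 + |x|)^2
    dx.device(d) = dy * (static_cast<T>(1) /
                         ((static_cast<T>(1) + x.abs()) *
                          (static_cast<T>(1) + x.abs())));
  }
};

// The CPU registration would then mirror the sigmoid one in activation_op.cc,
// e.g. (template arguments assumed, not confirmed by this copy of the diff):
// REGISTER_OP_CPU_KERNEL(softsign,
//     ops::ActivationKernel<paddle::platform::CPUPlace, float,
//                           ops::SoftsignFunctor<float>>);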
*/ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class ActivationKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + Functor functor; + functor(place, x, y); + } +}; + +template +class ActivationGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + Functor functor; + functor(place, x, y, dy, dx); + } +}; + +// sigmoid(x) = 1 / (1 + exp(-x)) +template +struct SigmoidFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); + } +}; + +template +struct SigmoidGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * y * (static_cast(1) - y); + } +}; + +// exp(x) = e^x +struct ExpFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.exp(); + } +}; + +struct ExpGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * y; + } +}; + +// relu(x) = max(x, 0) +template +struct ReluFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.cwiseMax(static_cast(0)); + } +}; + +template +struct ReluGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (x > static_cast(0)).template cast(); + } +}; + +// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) +struct TanhFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.tanh(); + } +}; + +template +struct TanhGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (static_cast(1) - y * y); + } +}; + +// sqrt(x) = x^(1/2) +struct SqrtFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.sqrt(); + } +}; + +template +struct SqrtGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + const Y y_conj = Eigen::numext::conj(y); + dx.device(d) = static_cast(0.5) * dy / y_conj; + } +}; + +// abs(x) = |x| +struct AbsFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.abs(); + } +}; + +struct AbsGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * x.sign(); + } +}; + +// reciprocal(x) = 1 / x +template +struct ReciprocalFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = static_cast(1) / x; + } +}; + +template +struct ReciprocalGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * static_cast(-1) * y * y; + } +}; + +// log(x) = natural logarithm of x +struct LogFunctor { + template + void operator()(Device d, X x, Y y) { + 
y.device(d) = x.log(); + } +}; + +template +struct LogGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * (static_cast(1) / x); + } +}; + +// square(x) = x^2 +struct SquareFunctor { + template + void operator()(Device d, X x, Y y) { + y.device(d) = x.square(); + } +}; + +template +struct SquareGradFunctor { + template + void operator()(Device d, X x, Y y, dY dy, dX dx) { + dx.device(d) = dy * static_cast(2) * x; + } +}; + +template +class BReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto t_min = static_cast(context.Attr("t_min")); + auto t_max = static_cast(context.Attr("t_max")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max); + } +}; + +template +class BReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto t_min = static_cast(context.Attr("t_min")); + auto t_max = static_cast(context.Attr("t_max")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast(); + } +}; + +template +class SoftReluKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto threshold = static_cast(context.Attr("threshold")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval(); + y.device(place) = (static_cast(1) + temp.exp()).log(); + } +}; + +template +class SoftReluGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Input("Y"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto threshold = static_cast(context.Attr("threshold")); + dX->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto dy = framework::EigenVector::Flatten(*dY); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + auto temp = ((x > -threshold) * (x < threshold)).template cast().eval(); + dx.device(place) = dy * (static_cast(1) - (-y).exp()) * temp; + } +}; + +template +class PowKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto factor = static_cast(context.Attr("factor")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto 
place = context.GetEigenDevice(); + y.device(place) = x.pow(factor); + } +}; + +template +class PowGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto factor = static_cast(context.Attr("factor")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + dx.device(place) = dy * factor * x.pow(factor - static_cast(1)); + } +}; + +template +class STanhKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* Y = context.Output("Y"); + auto scale_a = static_cast(context.Attr("scale_a")); + auto scale_b = static_cast(context.Attr("scale_b")); + Y->mutable_data(context.GetPlace()); + + auto x = framework::EigenVector::Flatten(*X); + auto y = framework::EigenVector::Flatten(*Y); + auto place = context.GetEigenDevice(); + y.device(place) = scale_b * (scale_a * x).tanh(); + } +}; + +template +class STanhGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* X = context.Input("X"); + auto* dY = context.Input(framework::GradVarName("Y")); + auto* dX = context.Output(framework::GradVarName("X")); + auto scale_a = static_cast(context.Attr("scale_a")); + auto scale_b = static_cast(context.Attr("scale_b")); + dX->mutable_data(context.GetPlace()); + + auto dy = framework::EigenVector::Flatten(*dY); + auto x = framework::EigenVector::Flatten(*X); + auto dx = framework::EigenVector::Flatten(*dX); + auto place = context.GetEigenDevice(); + + auto temp = (scale_a * x).tanh() * (scale_a * x).tanh(); + dx.device(place) = dy * scale_a * scale_b * (static_cast(1) - temp); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..7ed21f336f69e494f3c4039c609c83407a80cd8c --- /dev/null +++ b/paddle/operators/crop_op.cc @@ -0,0 +1,139 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
*/ + +#include "paddle/operators/crop_op.h" +#include + +namespace paddle { +namespace operators { + +using framework::Tensor; +using framework::LoDTensor; + +class CropOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of CropOp should not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), + "Output(Out) of CropOp should not be null."); + auto x_dim = ctx.Input("X")->dims(); + auto *y = ctx.Input("Y"); + auto *out = ctx.Output("Out"); + if (y == nullptr) { + auto shape = Attr>("shape"); + PADDLE_ENFORCE_EQ( + int64_t(shape.size()), x_dim.size(), + "Shape size should be equal to dimention size of input tensor."); + std::vector tensor_shape(shape.size()); + for (size_t i = 0; i < shape.size(); ++i) { + tensor_shape[i] = static_cast(shape[i]); + } + out->Resize(framework::make_ddim(tensor_shape)); + } else { + PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()), + "Tensor rank of both CropOp's " + "inputs must be same."); + out->Resize(y->dims()); + } + } +}; + +class CropOpMaker : public framework::OpProtoAndCheckerMaker { + public: + CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input of pad op. " + "The input should be a k-D tensor(k > 0 and k < 7)"); + AddInput("Y", + "The input used as reference for cropping" + " with the same dimension as X. "); + AddOutput("Out", + "The output of crop op " + "with the same dimension as X."); + AddAttr>("offsets", + "A list describing offsets to be cropped." + "The size of offsets list should be as same as " + "dimension size of input X."); + AddAttr>("shape", + "A list describing the shape of output." + "The size of shape list should be as same as " + "dimension size of input X.") + .SetDefault(std::vector()); + AddComment(R"DOC( +Crop Operator. +Crop input into output, as specified by offsets and shape. + +There are two ways to set shape: +1. referenc input: crop input X as shape as reference input. + The dimension of reference input should + be as same as input X. +2. shape list: crop input X by shape described by a list. + The size of shape list should be as same as + dimension size of input X. + +The input should be a k-D tensor(k > 0 and k < 7). 
As an example: + +Given: + + X = [[0, 1, 2, 0, 0] + [0, 3, 4, 0, 0] + [0, 0, 0, 0, 0]] + +and + + offsets = [0, 1] + +and + + shape = [2, 2] + +then we get + + Out = [[1, 2], + [3, 4]] + +)DOC"); + } +}; + +class CropOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + auto x_dims = ctx.Input("X")->dims(); + auto *x_grad = ctx.Output(framework::GradVarName("X")); + if (x_grad != nullptr) { + x_grad->Resize(x_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad); +REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel); +REGISTER_OP_CPU_KERNEL(crop_grad, + ops::CropGradKernel); diff --git a/paddle/operators/sigmoid_op.cu b/paddle/operators/crop_op.cu similarity index 72% rename from paddle/operators/sigmoid_op.cu rename to paddle/operators/crop_op.cu index 1a50dfe14a7b9e2614aadb7729de9f9e461e9905..f8ee18a1d6e894cbb2d71dd4b6b459abeb076817 100644 --- a/paddle/operators/sigmoid_op.cu +++ b/paddle/operators/crop_op.cu @@ -13,11 +13,9 @@ limitations under the License. */ #define EIGEN_USE_GPU -#include "paddle/operators/sigmoid_op.h" +#include "paddle/operators/crop_op.h" namespace ops = paddle::operators; - -REGISTER_OP_GPU_KERNEL(sigmoid, - ops::SigmoidKernel); -REGISTER_OP_GPU_KERNEL( - sigmoid_grad, ops::SigmoidGradKernel); +REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel); +REGISTER_OP_GPU_KERNEL(crop_grad, + ops::CropGradKernel); diff --git a/paddle/operators/crop_op.h b/paddle/operators/crop_op.h new file mode 100644 index 0000000000000000000000000000000000000000..2f40c059033ec649b29f6ecdee4fcedd128a63a6 --- /dev/null +++ b/paddle/operators/crop_op.h @@ -0,0 +1,104 @@ +/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
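crop_op.h, which follows, computes the flat offset of the crop origin from framework::stride(x->dims()) and the "offsets" attribute, then copies the cropped block with StridedMemcpy. A standalone sketch of that offset arithmetic (hypothetical helper, not part of the diff), using the 3 x 5 example from the comment above, where offsets = [0, 1] give origin 0 * 5 + 1 * 1 = 1:

#include <cstdint>
#include <vector>

// Illustrative only: mirrors the offset loop in CropKernel::Compute below.
int64_t CropOrigin(const std::vector<int64_t>& x_strides,
                   const std::vector<int>& offsets) {
  int64_t offset = 0;
  for (size_t i = 0; i < offsets.size(); ++i) {
    offset += x_strides[i] * offsets[i];
  }
  return offset;  // for strides {5, 1} and offsets {0, 1} this is 1
}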
*/ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/strided_memcpy.h" + +namespace paddle { +namespace operators { // Internal + +template +using EigenTensor = framework::EigenTensor; +using framework::Tensor; + +template +class CropKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + const T* x_data = x->data(); + T* out_data = out->mutable_data(context.GetPlace()); + auto x_stride = framework::stride(x->dims()); + auto out_stride = framework::stride(out->dims()); + auto offsets = context.Attr>("offsets"); + PADDLE_ENFORCE_EQ( + x->dims().size(), offsets.size(), + "Offsets size should be equal to dimension size of input tensor."); + int64_t offset = 0; + for (int i = 0; i < offsets.size(); ++i) { + offset += (x_stride[i] * offsets[i]); + } + StridedMemcpy(context.device_context(), x_data + offset, x_stride, + out->dims(), out_stride, out_data); + } +}; + +template +void CropGradFunction(const framework::ExecutionContext& context) { + auto* d_x = context.Output(framework::GradVarName("X")); + if (d_x != nullptr) { + auto* d_out = context.Input(framework::GradVarName("Out")); + d_x->mutable_data(context.GetPlace()); + auto offsets = context.Attr>("offsets"); + Eigen::array, D> paddings; + for (int i = 0; i < D; ++i) { + paddings[i].first = offsets[i]; + paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i]; + } + auto d_x_tensor = EigenTensor::From(*d_x); + auto d_out_tensor = EigenTensor::From(*d_out); + d_x_tensor.device(context.GetEigenDevice()) = + d_out_tensor.pad(paddings, 0); + } +} + +template +class CropGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + size_t rank = + context.Input(framework::GradVarName("Out"))->dims().size(); + switch (rank) { + case 1: + CropGradFunction(context); + break; + case 2: + CropGradFunction(context); + break; + case 3: + CropGradFunction(context); + break; + case 4: + CropGradFunction(context); + break; + case 5: + CropGradFunction(context); + break; + case 6: + CropGradFunction(context); + break; + default: + PADDLE_THROW( + "CropOp only support tensors with no more than 6 dimensions."); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h index 08b7df1dfead72fe8de8e89fa633c7bfc7bdbf33..5c9e81732aa72211c2021382cf9a907880c53c17 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/gemm_conv2d_op.h @@ -75,9 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel { framework::DDim output_matrix_shape = {output_channels, output_height * output_width}; - auto* device_context = - const_cast(context.device_context_); - // convolution operator: im2col + gemm int in_step = input_channels / groups; int out_step = output_channels / groups; @@ -87,14 +84,14 @@ class GemmConv2DKernel : public framework::OpKernel { for (int g = 0; g < groups; g++) { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], - device_context); + im2col(context.device_context(), in_slice, col, strides[0], strides[1], + paddings[0], paddings[1]); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 
1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, T(1.0), - &out_slice, T(0.0), device_context); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } } } @@ -160,9 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel { filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); - auto* device_context = - const_cast(context.device_context_); - // convolution backward input operator: gemm + col2im // convolution backward weight operator: im2col + gemm int in_step = input_channels / groups; @@ -184,14 +178,15 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, true, out_grad_slice, false, - T(1.0), &col_matrix, T(0.0), device_context); + math::matmul(context.device_context(), filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); // col2im Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); - col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], - paddings[1], device_context); + col2im(context.device_context(), in_grad_slice, col, strides[0], + strides[1], paddings[0], paddings[1]); } } } @@ -212,15 +207,15 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor out_grad_slice = out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); - im2col(in_slice, col, strides[0], strides[1], paddings[0], - paddings[1], device_context); + im2col(context.device_context(), in_slice, col, strides[0], + strides[1], paddings[0], paddings[1]); // gemm Tensor filter_grad_slice = filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(out_grad_slice, false, col_matrix, true, - T(1.0), &filter_grad_slice, T(1.0), - device_context); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } } } diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 5727c1cab16c1379ffe77f5594c057e93a042785..c08a3380f042886cd400df0d840e61856274619c 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -27,9 +27,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -79,9 +80,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -137,9 +138,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const 
framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -197,9 +199,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; diff --git a/paddle/operators/math/im2col.cu b/paddle/operators/math/im2col.cu index 9bff7bee3c95093852305d392af0949b831e5665..01f60bfe70f844fdcfd5aa481c27d9f12ec51305 100644 --- a/paddle/operators/math/im2col.cu +++ b/paddle/operators/math/im2col.cu @@ -64,9 +64,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -84,9 +85,9 @@ class Im2ColFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + im2col<<(context) + .stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width, col.data()); @@ -149,9 +150,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); @@ -174,9 +175,9 @@ class Col2ImFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + col2im<<(context) + .stream()>>>( num_kernels, col.data(), input_height + 2 * padding_height, input_width + 2 * padding_width, input_channels, filter_height, filter_width, stride_height, stride_width, padding_height, @@ -235,9 +236,10 @@ template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -268,9 +270,9 @@ class Im2ColFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + im2colOCF<<(context) + .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); @@ 
-318,9 +320,9 @@ template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context) { + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width) { PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(col.dims().size() == 5); int input_channels = im.dims()[0]; @@ -351,9 +353,9 @@ class Col2ImFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + col2imOCF<<(context) + .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 8958c5457cc2c3034c34ca82fb2e98cc06be63c5..7b717e1603c94cd77c74cb0d86f1d23e2692f9d8 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 }; template class Im2ColFunctor { public: - void operator()(const framework::Tensor& im, framework::Tensor& col, + void operator()(const platform::DeviceContext& context, + const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context); + int padding_width); }; template class Col2ImFunctor { public: - void operator()(framework::Tensor& im, const framework::Tensor& col, - int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context); + void operator()(const platform::DeviceContext& context, framework::Tensor& im, + const framework::Tensor& col, int stride_height, + int stride_width, int padding_height, int padding_width); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 4f380388b108dc173d847f027ba5c9db387a87f8..f0b8c885918afe7f80edc465c6d9be7c11ac066f 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -78,8 +78,8 @@ void testIm2col() { PADDLE_THROW("no GPU support"); #endif // PADDLE_ONLY_CPU } - im2col(input, output_cfo, stride, stride, padding, padding, context); - im2col_ocf(input, output_ocf, stride, stride, padding, padding, context); + im2col(*context, input, output_cfo, stride, stride, padding, padding); + im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding); float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..4bba4200728ebf7e7810ed935f6fdf51c96cbc7a --- /dev/null +++ b/paddle/operators/rank_loss_op.cc @@ -0,0 +1,126 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/rank_loss_op.h" + +namespace paddle { +namespace operators { + +class RankLossOp : public framework::OperatorWithKernel { + public: + RankLossOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + // input check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), + "Input(Left) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), + "Input(Right) shouldn't be null"); + auto label_dims = ctx.Input("Label")->dims(); + auto left_dims = ctx.Input("Left")->dims(); + auto right_dims = ctx.Input("Right")->dims(); + PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims), + "All inputs must have the same size"); + PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1), + "All inputs must be row vector with size batch_size x 1."); + ctx.Output("Out")->Resize(label_dims); + } +}; + +class RankLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + RankLossOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Label", + "The label indicating A ranked higher than B or not, row vector."); + AddInput("Left", "The output of RankNet for doc A, vector."); + AddInput("Right", "The output of RankNet for doc B, vetor"); + AddOutput("Out", "The output loss of RankLoss operator, vector."); + AddComment(R"DOC(RankLoss operator + +Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with +one training sample consisting of a pair of doc A and B, and the label P +indicating that A is ranked higher than B or not: + +P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of +the input pair. + +The RankLoss operator contains three inputs: Left (o_i), Right (o_j) and Label +(P_{i,j}), which represent the output of RankNet for two docs and the label +respectively, and yields the rank loss C_{i,j} by following the expression + +\f[ + C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\ + o_{i,j} = o_i - o_j \\ + \tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \} +\f] + +The operator can take inputs of one sample or in batch. + +[1]. Chris Burges, Tal Shaked, Erin Renshaw, et al. Learning to + Rank using Gradient Descent. 
+ http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf +)DOC"); + } +}; + +class RankLossGradOp : public framework::OperatorWithKernel { + public: + RankLossGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"), + "Input(Left) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"), + "Input(Right) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + auto dims = ctx.Input("Left")->dims(); + auto *left_grad = + ctx.Output(framework::GradVarName("Left")); + auto *right_grad = + ctx.Output(framework::GradVarName("Right")); + if (left_grad) { + left_grad->Resize(dims); + } + if (right_grad) { + right_grad->Resize(dims); + } + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(rank_loss, ops::RankLossOp, ops::RankLossOpMaker, rank_loss_grad, + ops::RankLossGradOp); +REGISTER_OP_CPU_KERNEL(rank_loss, + ops::RankLossKernel); +REGISTER_OP_CPU_KERNEL( + rank_loss_grad, ops::RankLossGradKernel); diff --git a/paddle/operators/rank_loss_op.cu b/paddle/operators/rank_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..779588ff36c792b8925a535d60f1cfbbe3c66d86 --- /dev/null +++ b/paddle/operators/rank_loss_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/rank_loss_op.h" + +REGISTER_OP_GPU_KERNEL( + rank_loss, + paddle::operators::RankLossKernel); +REGISTER_OP_GPU_KERNEL( + rank_loss_grad, + paddle::operators::RankLossGradKernel); diff --git a/paddle/operators/rank_loss_op.h b/paddle/operators/rank_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..9776d123fe4b0cb0cd16a15770fcf42a966fa011 --- /dev/null +++ b/paddle/operators/rank_loss_op.h @@ -0,0 +1,80 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
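rank_loss_op.h, which follows, evaluates the loss as log(1 + exp(o_i - o_j)) - P * (o_i - o_j), the same expression as in the operator comment above. A scalar sketch of the forward computation (illustrative only, not part of the diff):

#include <cmath>

// C_{i,j} = log(1 + exp(o_i - o_j)) - P * (o_i - o_j)
float RankLoss(float p, float o_i, float o_j) {
  float o = o_i - o_j;
  return std::log(1.0f + std::exp(o)) - p * o;
}
// e.g. p = 1 (doc A ranked above doc B) and o = 2 gives roughly 0.127.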
*/ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +class RankLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + auto* label_t = ctx.Input("Label"); + auto* left_t = ctx.Input("Left"); + auto* right_t = ctx.Input("Right"); + out_t->mutable_data(ctx.GetPlace()); + + auto out = framework::EigenVector::Flatten(*out_t); + auto label = framework::EigenVector::Flatten(*label_t); + auto left = framework::EigenVector::Flatten(*left_t); + auto right = framework::EigenVector::Flatten(*right_t); + + auto& dev = ctx.GetEigenDevice(); + out.device(dev) = + (1. + (left - right).exp()).log() - label * (left - right); + } +}; + +template +class RankLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_left_t = + ctx.Output(framework::GradVarName("Left")); + auto* d_right_t = + ctx.Output(framework::GradVarName("Right")); + + auto* d_out_t = ctx.Input(framework::GradVarName("Out")); + auto* label_t = ctx.Input("Label"); + auto* left_t = ctx.Input("Left"); + auto* right_t = ctx.Input("Right"); + + auto& dev = ctx.GetEigenDevice(); + auto d_out = framework::EigenVector::Flatten(*d_out_t); + auto label = framework::EigenVector::Flatten(*label_t); + auto left = framework::EigenVector::Flatten(*left_t); + auto right = framework::EigenVector::Flatten(*right_t); + + // compute d_left + if (d_left_t) { + d_left_t->mutable_data(ctx.GetPlace()); + auto d_left = framework::EigenVector::Flatten(*d_left_t); + d_left.device(dev) = d_out * (1. / (1. + (right - left).exp()) - label); + } + // compute d_right + if (d_right_t) { + d_right_t->mutable_data(ctx.GetPlace()); + auto d_right = framework::EigenVector::Flatten(*d_right_t); + d_right.device(dev) = + -d_out * (1.0 / (1. + (right - left).exp()) - label); + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/sigmoid_op.cc b/paddle/operators/sigmoid_op.cc deleted file mode 100644 index 992b19965e0ca9ce7dba1b8b3c5b7780af06eb45..0000000000000000000000000000000000000000 --- a/paddle/operators/sigmoid_op.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
*/ - -#include "paddle/operators/sigmoid_op.h" - -namespace paddle { -namespace operators { - -class SigmoidOp : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), - "Input(X) of SigmoidOp should not be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"), - "Output(Y) of SigmoidOp should not be null."); - - ctx.Output("Y")->Resize( - ctx.Input("X")->dims()); - } -}; - -class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { - public: - SigmoidOpMaker(framework::OpProto *proto, - framework::OpAttrChecker *op_checker) - : OpProtoAndCheckerMaker(proto, op_checker) { - AddInput("X", "sigmoid input"); - AddOutput("Y", "sigmoid output"); - AddComment("Sigmoid function"); - } -}; - -class SigmoidOpGrad : public framework::OperatorWithKernel { - public: - using framework::OperatorWithKernel::OperatorWithKernel; - - protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - ctx.Output(framework::GradVarName("X")) - ->Resize(ctx.Input("Y")->dims()); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad, - ops::SigmoidOpGrad); -REGISTER_OP_CPU_KERNEL(sigmoid, - ops::SigmoidKernel); -REGISTER_OP_CPU_KERNEL( - sigmoid_grad, ops::SigmoidGradKernel); diff --git a/paddle/operators/sigmoid_op.h b/paddle/operators/sigmoid_op.h deleted file mode 100644 index b01a9b3f23283471f8846325075719ba0e75ed35..0000000000000000000000000000000000000000 --- a/paddle/operators/sigmoid_op.h +++ /dev/null @@ -1,62 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. */ - -#pragma once -#include "paddle/framework/eigen.h" -#include "paddle/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenVector = framework::EigenVector; - -template -class SigmoidKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto input = context.Input("X"); - auto output = context.Output("Y"); - output->mutable_data(context.GetPlace()); - - // The clipping is used in Paddle's raw implenmention - auto X = EigenVector::Flatten(*input); - auto Y = EigenVector::Flatten(*output); - auto place = context.GetEigenDevice(); - - Y.device(place) = 1. / (1. 
diff --git a/python/paddle/v2/framework/tests/test_activation_op.py b/python/paddle/v2/framework/tests/test_activation_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6d2be17758b7f6604d2db74fe466fb30695bd5
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_activation_op.py
@@ -0,0 +1,223 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestExp(OpTest):
+    def setUp(self):
+        self.op_type = "exp"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.exp(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestSigmoid(OpTest):
+    def setUp(self):
+        self.op_type = "sigmoid"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.008)
+
+
+class TestTanh(OpTest):
+    def setUp(self):
+        self.op_type = "tanh"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.tanh(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestSqrt(OpTest):
+    def setUp(self):
+        self.op_type = "sqrt"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.sqrt(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestAbs(OpTest):
+    def setUp(self):
+        self.op_type = "abs"
+        x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
+        # Because we set delta = 0.005 when calculating the numeric gradient,
+        # if x is too small, such as 0.002, x_neg will be -0.003 and
+        # x_pos will be 0.007, so the numeric gradient is inaccurate.
+        # We should avoid such points.
+        x[np.abs(x) < 0.005] = 0.02
+        self.inputs = {'X': x}
+        self.outputs = {'Y': np.abs(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
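The delta = 0.005 comment in TestAbs can be seen concretely in this standalone snippet (not part of the patch): a central difference that straddles the kink of |x| gives a badly wrong slope, which is why the tests move inputs away from such points.

```python
import numpy as np

delta = 0.005                 # same delta the gradient checker uses
x = 0.002                     # closer to the kink of |x| at 0 than delta

numeric = (np.abs(x + delta) - np.abs(x - delta)) / (2.0 * delta)
analytic = np.sign(x)

print(numeric, analytic)      # 0.4 vs 1.0: the finite difference is badly off
```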
+
+
+class TestRelu(OpTest):
+    def setUp(self):
+        self.op_type = "relu"
+        x = np.random.uniform(-1, 1, [11, 17]).astype("float32")
+        # The same reason as in TestAbs
+        x[np.abs(x) < 0.005] = 0.02
+        self.inputs = {'X': x}
+        self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestBRelu(OpTest):
+    def setUp(self):
+        self.op_type = "brelu"
+        x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
+        t_min = 1
+        t_max = 4
+        # The same reason as in TestAbs
+        x[np.abs(x - t_min) < 0.005] = t_min + 0.02
+        x[np.abs(x - t_max) < 0.005] = t_max + 0.02
+
+        self.inputs = {'X': x}
+        self.attrs = {'t_min': t_min, 't_max': t_max}
+        t = np.copy(x)
+        t[t < t_min] = t_min
+        t[t > t_max] = t_max
+        self.outputs = {'Y': t}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+
+
+class TestSoftRelu(OpTest):
+    def setUp(self):
+        self.op_type = "soft_relu"
+        x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
+        threshold = 2
+        # The same reason as in TestAbs
+        x[np.abs(x - threshold) < 0.005] = threshold + 0.02
+        x[np.abs(x + threshold) < 0.005] = -threshold + 0.02
+        self.inputs = {'X': x}
+        self.attrs = {'threshold': threshold}
+        t = np.copy(x)
+        t[t < -threshold] = -threshold
+        t[t > threshold] = threshold
+        self.outputs = {'Y': np.log(np.exp(t) + 1)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+
+
+class TestReciprocal(OpTest):
+    def setUp(self):
+        self.op_type = "reciprocal"
+        self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
+        self.outputs = {'Y': np.reciprocal(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.01)
+
+
+class TestLog(OpTest):
+    def setUp(self):
+        self.op_type = "log"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.log(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestSquare(OpTest):
+    def setUp(self):
+        self.op_type = "square"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        self.outputs = {'Y': np.square(self.inputs['X'])}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+class TestPow(OpTest):
+    def setUp(self):
+        self.op_type = "pow"
+        self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
+        self.attrs = {'factor': 3}
+        self.outputs = {'Y': np.power(self.inputs['X'], 3)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.02)
+
+
+class TestSTanh(OpTest):
+    def setUp(self):
+        self.op_type = "stanh"
+        self.inputs = {
+            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
+        }
+        scale_a = 2.0 / 3.0
+        scale_b = 1.7159
+        self.attrs = {'scale_a': scale_a, 'scale_b': scale_b}
+        self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(['X'], 'Y', max_relative_error=0.007)
+
+
+if __name__ == "__main__":
+    unittest.main()
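All of these tests follow one OpTest pattern: set `op_type`, build NumPy reference outputs, then call `check_output` and `check_grad`. As an illustration only, a hypothetical test for an additional activation would look like the sketch below; the op name `softsign` is an assumption for the example, not an operator added by this patch.

```python
import numpy as np
from op_test import OpTest


class TestSoftsign(OpTest):
    def setUp(self):
        self.op_type = "softsign"   # hypothetical op name, for illustration only
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
        }
        # reference result computed with NumPy, like the other tests
        self.outputs = {'Y': self.inputs['X'] / (1 + np.abs(self.inputs['X']))}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Y', max_relative_error=0.007)
```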
diff --git a/python/paddle/v2/framework/tests/test_conv2d_op.py b/python/paddle/v2/framework/tests/test_conv2d_op.py
index 3142a60a1ae7d1874d02b81a4bb90c1fc50d07b9..118a5fc1cde5f4a908b065d581956e0855d50a52 100644
--- a/python/paddle/v2/framework/tests/test_conv2d_op.py
+++ b/python/paddle/v2/framework/tests/test_conv2d_op.py
@@ -73,13 +73,22 @@ class TestConv2dOp(OpTest):
         self.check_output()
 
     def test_check_grad(self):
-        self.check_grad(set(['Input', 'Filter']), 'Output')
+        self.check_grad(
+            set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
 
     def test_check_grad_no_filter(self):
-        self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter']))
+        self.check_grad(
+            ['Input'],
+            'Output',
+            max_relative_error=0.05,
+            no_grad_set=set(['Filter']))
 
     def test_check_grad_no_input(self):
-        self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input']))
+        self.check_grad(
+            ['Filter'],
+            'Output',
+            max_relative_error=0.05,
+            no_grad_set=set(['Input']))
 
     def init_groups(self):
         self.groups = 1
diff --git a/python/paddle/v2/framework/tests/test_crop_op.py b/python/paddle/v2/framework/tests/test_crop_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..62c883bdc130021d06c33ded9c2865505da0b719
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_crop_op.py
@@ -0,0 +1,91 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+def crop(data, offsets, crop_shape):
+    def indexOf(shape, index):
+        result = []
+        for dim in reversed(shape):
+            result.append(index % dim)
+            index = index / dim
+        return result[::-1]
+
+    result = []
+    for i, value in enumerate(data.flatten()):
+        index = indexOf(data.shape, i)
+        selected = True
+        if len(index) == len(offsets):
+            for j, offset in enumerate(offsets):
+                selected = selected and index[j] >= offset and index[
+                    j] < crop_shape[j] + offset
+        if selected:
+            result.append(value)
+    return np.array(result).reshape(crop_shape)
+
+
+class TestCropOp(OpTest):
+    def setUp(self):
+        self.op_type = "crop"
+        self.crop_by_input = False
+        self.attrs = {}
+        self.initTestCase()
+        self.attrs['offsets'] = self.offsets
+        if self.crop_by_input:
+            self.inputs = {
+                'X': np.random.random(self.x_shape).astype("float32"),
+                'Y': np.random.random(self.crop_shape).astype("float32")
+            }
+        else:
+            self.attrs['shape'] = self.crop_shape
+            self.inputs = {
+                'X': np.random.random(self.x_shape).astype("float32"),
+            }
+        self.outputs = {
+            'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
+        }
+
+    def initTestCase(self):
+        self.x_shape = (8, 8)
+        self.crop_shape = (2, 2)
+        self.offsets = [1, 2]
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.006)
+
+
+class TestCase1(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (16, 8, 32)
+        self.crop_shape = [2, 2, 3]
+        self.offsets = [1, 5, 3]
+
+
+class TestCase2(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (4, 8)
+        self.crop_shape = [4, 8]
+        self.offsets = [0, 0]
+
+
+class TestCase3(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (4, 8, 16)
+        self.crop_shape = [2, 2, 3]
+        self.offsets = [1, 5, 3]
+        self.crop_by_input = True
+
+
+class TestCase4(TestCropOp):
+    def initTestCase(self):
+        self.x_shape = (4, 4)
+        self.crop_shape = [4, 4]
+        self.offsets = [0, 0]
+        self.crop_by_input = True
+
+
+if __name__ == '__main__':
+    unittest.main()
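The `crop()` helper above is the NumPy reference for the crop operator; it keeps the elements whose index along each axis lies in [offset, offset + crop_shape), which for the default case is just a plain slice. A quick standalone check (not part of the patch):

```python
import numpy as np

data = np.arange(64).reshape(8, 8)
offsets = [1, 2]
crop_shape = (2, 2)

# Equivalent of crop(data, offsets, crop_shape) for this 2-D case:
window = data[offsets[0]:offsets[0] + crop_shape[0],
              offsets[1]:offsets[1] + crop_shape[1]]
print(window)
# [[10 11]
#  [18 19]]
```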
diff --git a/python/paddle/v2/framework/tests/test_rank_loss_op.py b/python/paddle/v2/framework/tests/test_rank_loss_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e41ab1b3fd8fa8b62c5f3b914b752918119a265
--- /dev/null
+++ b/python/paddle/v2/framework/tests/test_rank_loss_op.py
@@ -0,0 +1,32 @@
+import unittest
+import numpy as np
+from op_test import OpTest
+
+
+class TestRankLossOp(OpTest):
+    def setUp(self):
+        self.op_type = "rank_loss"
+        batch_size = 5
+        # labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
+        label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32")
+        left = np.random.random((batch_size, 1)).astype("float32")
+        right = np.random.random((batch_size, 1)).astype("float32")
+        loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
+        self.inputs = {'Label': label, 'Left': left, 'Right': right}
+        self.outputs = {'Out': loss}
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad(self):
+        self.check_grad(["Left", "Right"], "Out")
+
+    def test_check_grad_ignore_left(self):
+        self.check_grad(["Right"], "Out", no_grad_set=set('Left'))
+
+    def test_check_grad_ignore_right(self):
+        self.check_grad(["Left"], "Out", no_grad_set=set('Right'))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/v2/framework/tests/test_sigmoid_op.py b/python/paddle/v2/framework/tests/test_sigmoid_op.py
deleted file mode 100644
index d65d887db4af58c40e4e78fdbfd8e8ee668b7ee3..0000000000000000000000000000000000000000
--- a/python/paddle/v2/framework/tests/test_sigmoid_op.py
+++ /dev/null
@@ -1,22 +0,0 @@
-import unittest
-import numpy as np
-from op_test import OpTest
-
-
-class TestSigmoidOp(OpTest):
-    def setUp(self):
-        self.op_type = "sigmoid"
-        self.inputs = {
-            'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
-        }
-        self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
-
-    def test_check_output(self):
-        self.check_output()
-
-    def test_check_grad(self):
-        self.check_grad(["X"], "Y", max_relative_error=0.007)
-
-
-if __name__ == '__main__':
-    unittest.main()
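Assuming the test layout implied by the paths above, the new suites can be run with standard unittest discovery; the directory and pattern here are assumptions taken from the file names in this patch.

```python
import unittest

# Discover and run the test modules added by this patch; adjust the path
# if the tests live elsewhere in your checkout.
suite = unittest.defaultTestLoader.discover(
    "python/paddle/v2/framework/tests", pattern="test_*_op.py")
unittest.TextTestRunner(verbosity=2).run(suite)
```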