提交 5f2cd1a4 编写于 作者: F fengjiayi

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_program_proto

......@@ -25,7 +25,7 @@ function(target_circle_link_libraries TARGET_NAME)
endif()
endforeach()
if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "Clang" OR "${CMAKE_CXX_COMPILER_ID}" STREQUAL "AppleClang")
if(IOS AND NOT IOS_ENABLE_BITCODE)
if(NOT IOS_ENABLE_BITCODE)
list(APPEND LIBS "-undefined dynamic_lookup")
endif()
endif()
......
......@@ -292,5 +292,13 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim stride(const DDim& ddim) {
std::vector<int64_t> strides(ddim.size());
strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
}
return framework::make_ddim(strides);
}
} // namespace framework
} // namespace paddle
......@@ -121,6 +121,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims);
DDim flatten_to_1d(const DDim& src);
DDim stride(const DDim& ddim);
} // namespace framework
} // namespace paddle
......
......@@ -55,6 +55,13 @@ function(op_library TARGET)
set(pybind_flag 1)
endif()
# activation_op contains several operators
if ("${TARGET}" STREQUAL "activation_op")
set(pybind_flag 1)
# It's enough to just adding one operator to pybind
file(APPEND ${pybind_file} "USE_OP(sigmoid);\n")
endif()
# pybind USE_NO_KERNEL_OP
file(READ ${TARGET}.cc TARGET_CONTENT)
string(REGEX MATCH "OperatorWithKernel" regex_result "${TARGET_CONTENT}")
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/activation_op.h"
namespace paddle {
namespace operators {
class ActivationOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<framework::Tensor>("X")->dims());
}
};
class ActivationOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<framework::Tensor>("Y")->dims());
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sigmoid operator");
AddOutput("Y", "Output of Sigmoid operator");
AddComment("Sigmoid activation operator, sigmoid = 1 / (1 + exp(-x))");
}
};
class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ExpOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
AddOutput("Y", "Output of Exp operator");
AddComment("Exp activation operator, exp(x) = e^x");
}
};
class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Y", "Output of Relu operator");
AddComment("Relu activation operator, relu(x) = max(x, 0)");
}
};
class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Y", "Output of Tanh operator");
AddComment(
"Tanh activation operator, tanh = (exp(x) - exp(-x)) / (exp(x) + "
"exp(-x))");
}
};
class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SqrtOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Y", "Output of Sqrt operator");
AddComment("Sqrt activation operator, sqrt(x) = x^(1/2)");
}
};
class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
AbsOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Y", "Output of Abs operator");
AddComment("Abs activation operator, abs(x) = |x|");
}
};
class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
public:
ReciprocalOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Reciprocal operator");
AddOutput("Y", "Output of Reciprocal operator");
AddComment("Reciprocal activation operator, reciprocal(x) = 1 / x");
}
};
class LogOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Log operator");
AddOutput("Y", "Output of Log operator");
AddComment("Log activation operator, log(x) = natural logarithm of x");
}
};
class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SquareOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Square operator");
AddOutput("Y", "Output of Square operator");
AddComment("Square activation operator, square(x) = x^2");
}
};
template <typename AttrType>
class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
AddOutput("Y", "Output of BRelu operator");
AddComment("BRelu activation operator, brelu = max(min(x, t_min), t_max)");
AddAttr<AttrType>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<AttrType>(0));
AddAttr<AttrType>("t_max", "The max marginal value of BRelu")
.SetDefault(static_cast<AttrType>(24));
}
};
template <typename AttrType>
class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SoftReluOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Y", "Output of SoftRelu operator");
AddComment(
"SoftRelu activation operator, soft_relu = log(1 + exp(max(min(x, "
"threshold), threshold)))");
AddAttr<AttrType>("threshold", "The threshold value of SoftRelu")
.SetDefault(static_cast<AttrType>(40));
}
};
template <typename AttrType>
class PowOpMaker : public framework::OpProtoAndCheckerMaker {
public:
PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
AddOutput("Y", "Output of Pow operator");
AddComment("Pow activation operator, pow(x, factor) = x^factor");
AddAttr<AttrType>("factor", "The exponential factor of Pow")
.SetDefault(static_cast<AttrType>(1));
}
};
template <typename AttrType>
class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
public:
STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
AddOutput("Y", "Output of STanh operator");
AddComment("STanh activation operator, stanh = b * tanh(a * x)");
AddAttr<AttrType>("scale_a", "The scale parameter of a for the input")
.SetDefault(static_cast<AttrType>(2 / 3));
AddAttr<AttrType>("scale_b", "The scale parameter of b for the input")
.SetDefault(static_cast<AttrType>(1.7159));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::ActivationOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_CPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_CPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_CPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_CPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP(reciprocal, ops::ActivationOp, ops::ReciprocalOpMaker,
reciprocal_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_CPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP(log, ops::ActivationOp, ops::LogOpMaker, log_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::CPUPlace, float, ops::LogFunctor>);
REGISTER_OP_CPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP(square, ops::ActivationOp, ops::SquareOpMaker, square_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::CPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_CPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::CPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
soft_relu_grad, ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(pow, ops::PowKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
ops::ActivationOpGrad);
REGISTER_OP_CPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SigmoidFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SigmoidGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
exp,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::ExpFunctor>);
REGISTER_OP_GPU_KERNEL(exp_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::ExpGradFunctor>);
REGISTER_OP_GPU_KERNEL(relu,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReluFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
relu_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReluGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
tanh,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::TanhFunctor>);
REGISTER_OP_GPU_KERNEL(
tanh_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::TanhGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
sqrt,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::SqrtFunctor>);
REGISTER_OP_GPU_KERNEL(
sqrt_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SqrtGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
abs,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::AbsFunctor>);
REGISTER_OP_GPU_KERNEL(abs_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace,
float, ops::AbsGradFunctor>);
REGISTER_OP_GPU_KERNEL(reciprocal,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
reciprocal_grad,
ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::ReciprocalGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(
log,
ops::ActivationKernel<paddle::platform::GPUPlace, float, ops::LogFunctor>);
REGISTER_OP_GPU_KERNEL(
log_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::LogGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(square,
ops::ActivationKernel<paddle::platform::GPUPlace, float,
ops::SquareFunctor>);
REGISTER_OP_GPU_KERNEL(
square_grad, ops::ActivationGradKernel<paddle::platform::GPUPlace, float,
ops::SquareGradFunctor<float>>);
REGISTER_OP_GPU_KERNEL(brelu,
ops::BReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(brelu_grad,
ops::BReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(soft_relu,
ops::SoftReluKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
soft_relu_grad, ops::SoftReluGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow, ops::PowKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(pow_grad,
ops::PowGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh,
ops::STanhKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(stanh_grad,
ops::STanhGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T, typename Functor>
class ActivationKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, y);
}
};
template <typename Place, typename T, typename Functor>
class ActivationGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
Functor functor;
functor(place, x, y, dy, dx);
}
};
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
}
};
template <typename T>
struct SigmoidGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * y * (static_cast<T>(1) - y);
}
};
// exp(x) = e^x
struct ExpFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.exp();
}
};
struct ExpGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * y;
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
}
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
struct TanhFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (static_cast<T>(1) - y * y);
}
};
// sqrt(x) = x^(1/2)
struct SqrtFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
}
};
// abs(x) = |x|
struct AbsFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.abs();
}
};
struct AbsGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * x.sign();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = static_cast<T>(1) / x;
}
};
template <typename T>
struct ReciprocalGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * static_cast<T>(-1) * y * y;
}
};
// log(x) = natural logarithm of x
struct LogFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.log();
}
};
template <typename T>
struct LogGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * (static_cast<T>(1) / x);
}
};
// square(x) = x^2
struct SquareFunctor {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x.square();
}
};
template <typename T>
struct SquareGradFunctor {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
dx.device(d) = dy * static_cast<T>(2) * x;
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min"));
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.cwiseMax(t_min).cwiseMin(t_max);
}
};
template <typename Place, typename T, typename AttrType = T>
class BReluGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto t_min = static_cast<T>(context.Attr<AttrType>("t_min"));
auto t_max = static_cast<T>(context.Attr<AttrType>("t_max"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * ((x > t_min) * (x < t_max)).template cast<T>();
}
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
auto temp = x.cwiseMax(-threshold).cwiseMin(threshold).eval();
y.device(place) = (static_cast<T>(1) + temp.exp()).log();
}
};
template <typename Place, typename T, typename AttrType = T>
class SoftReluGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto threshold = static_cast<T>(context.Attr<AttrType>("threshold"));
dX->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto temp = ((x > -threshold) * (x < threshold)).template cast<T>().eval();
dx.device(place) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
}
};
template <typename Place, typename T, typename AttrType = T>
class PowKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto factor = static_cast<T>(context.Attr<AttrType>("factor"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = x.pow(factor);
}
};
template <typename Place, typename T, typename AttrType = T>
class PowGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto factor = static_cast<T>(context.Attr<AttrType>("factor"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
dx.device(place) = dy * factor * x.pow(factor - static_cast<T>(1));
}
};
template <typename Place, typename T, typename AttrType = T>
class STanhKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
Y->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto place = context.GetEigenDevice<Place>();
y.device(place) = scale_b * (scale_a * x).tanh();
}
};
template <typename Place, typename T, typename AttrType = T>
class STanhGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
auto scale_a = static_cast<T>(context.Attr<AttrType>("scale_a"));
auto scale_b = static_cast<T>(context.Attr<AttrType>("scale_b"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto x = framework::EigenVector<T>::Flatten(*X);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto place = context.GetEigenDevice<Place>();
auto temp = (scale_a * x).tanh() * (scale_a * x).tanh();
dx.device(place) = dy * scale_a * scale_b * (static_cast<T>(1) - temp);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/crop_op.h"
#include <boost/lexical_cast.hpp>
namespace paddle {
namespace operators {
using framework::Tensor;
using framework::LoDTensor;
class CropOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Output(Out) of CropOp should not be null.");
auto x_dim = ctx.Input<LoDTensor>("X")->dims();
auto *y = ctx.Input<LoDTensor>("Y");
auto *out = ctx.Output<LoDTensor>("Out");
if (y == nullptr) {
auto shape = Attr<std::vector<int>>("shape");
PADDLE_ENFORCE_EQ(
int64_t(shape.size()), x_dim.size(),
"Shape size should be equal to dimention size of input tensor.");
std::vector<int64_t> tensor_shape(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
tensor_shape[i] = static_cast<int64_t>(shape[i]);
}
out->Resize(framework::make_ddim(tensor_shape));
} else {
PADDLE_ENFORCE_EQ(framework::arity(x_dim), framework::arity(y->dims()),
"Tensor rank of both CropOp's "
"inputs must be same.");
out->Resize(y->dims());
}
}
};
class CropOpMaker : public framework::OpProtoAndCheckerMaker {
public:
CropOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X",
"The input of pad op. "
"The input should be a k-D tensor(k > 0 and k < 7)");
AddInput("Y",
"The input used as reference for cropping"
" with the same dimension as X. ");
AddOutput("Out",
"The output of crop op "
"with the same dimension as X.");
AddAttr<std::vector<int>>("offsets",
"A list<int> describing offsets to be cropped."
"The size of offsets list should be as same as "
"dimension size of input X.");
AddAttr<std::vector<int>>("shape",
"A list<int> describing the shape of output."
"The size of shape list should be as same as "
"dimension size of input X.")
.SetDefault(std::vector<int>());
AddComment(R"DOC(
Crop Operator.
Crop input into output, as specified by offsets and shape.
There are two ways to set shape:
1. referenc input: crop input X as shape as reference input.
The dimension of reference input should
be as same as input X.
2. shape list: crop input X by shape described by a list<int>.
The size of shape list should be as same as
dimension size of input X.
The input should be a k-D tensor(k > 0 and k < 7). As an example:
Given:
X = [[0, 1, 2, 0, 0]
[0, 3, 4, 0, 0]
[0, 0, 0, 0, 0]]
and
offsets = [0, 1]
and
shape = [2, 2]
then we get
Out = [[1, 2],
[3, 4]]
)DOC");
}
};
class CropOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) should not be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
auto x_dims = ctx.Input<LoDTensor>("X")->dims();
auto *x_grad = ctx.Output<LoDTensor>(framework::GradVarName("X"));
if (x_grad != nullptr) {
x_grad->Resize(x_dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(crop, ops::CropOp, ops::CropOpMaker, crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_CPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::CPUPlace, float>);
......@@ -13,11 +13,9 @@
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/sigmoid_op.h"
#include "paddle/operators/crop_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(crop, ops::CropKernel<float>);
REGISTER_OP_GPU_KERNEL(crop_grad,
ops::CropGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 CropdleCropdle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/strided_memcpy.h"
namespace paddle {
namespace operators { // Internal
template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
using framework::Tensor;
template <typename T>
class CropKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
const T* x_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
auto x_stride = framework::stride(x->dims());
auto out_stride = framework::stride(out->dims());
auto offsets = context.Attr<std::vector<int>>("offsets");
PADDLE_ENFORCE_EQ(
x->dims().size(), offsets.size(),
"Offsets size should be equal to dimension size of input tensor.");
int64_t offset = 0;
for (int i = 0; i < offsets.size(); ++i) {
offset += (x_stride[i] * offsets[i]);
}
StridedMemcpy<T>(context.device_context(), x_data + offset, x_stride,
out->dims(), out_stride, out_data);
}
};
template <typename Place, typename T, size_t D>
void CropGradFunction(const framework::ExecutionContext& context) {
auto* d_x = context.Output<Tensor>(framework::GradVarName("X"));
if (d_x != nullptr) {
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
d_x->mutable_data<T>(context.GetPlace());
auto offsets = context.Attr<std::vector<int>>("offsets");
Eigen::array<std::pair<int, int>, D> paddings;
for (int i = 0; i < D; ++i) {
paddings[i].first = offsets[i];
paddings[i].second = d_x->dims()[i] - d_out->dims()[i] - offsets[i];
}
auto d_x_tensor = EigenTensor<T, D>::From(*d_x);
auto d_out_tensor = EigenTensor<T, D>::From(*d_out);
d_x_tensor.device(context.GetEigenDevice<Place>()) =
d_out_tensor.pad(paddings, 0);
}
}
template <typename Place, typename T>
class CropGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
size_t rank =
context.Input<Tensor>(framework::GradVarName("Out"))->dims().size();
switch (rank) {
case 1:
CropGradFunction<Place, T, 1>(context);
break;
case 2:
CropGradFunction<Place, T, 2>(context);
break;
case 3:
CropGradFunction<Place, T, 3>(context);
break;
case 4:
CropGradFunction<Place, T, 4>(context);
break;
case 5:
CropGradFunction<Place, T, 5>(context);
break;
case 6:
CropGradFunction<Place, T, 6>(context);
break;
default:
PADDLE_THROW(
"CropOp only support tensors with no more than 6 dimensions.");
}
}
};
} // namespace operators
} // namespace paddle
......@@ -75,9 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel {
framework::DDim output_matrix_shape = {output_channels,
output_height * output_width};
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
// convolution operator: im2col + gemm
int in_step = input_channels / groups;
int out_step = output_channels / groups;
......@@ -87,14 +84,14 @@ class GemmConv2DKernel : public framework::OpKernel {
for (int g = 0; g < groups; g++) {
// im2col
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
device_context);
im2col(context.device_context(), in_slice, col, strides[0], strides[1],
paddings[0], paddings[1]);
// gemm
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(filter_slice, false, col_matrix, false, T(1.0),
&out_slice, T(0.0), device_context);
math::matmul<Place, T>(context.device_context(), filter_slice, false,
col_matrix, false, T(1.0), &out_slice, T(0.0));
}
}
}
......@@ -160,9 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape);
auto* device_context =
const_cast<platform::DeviceContext*>(context.device_context_);
// convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm
int in_step = input_channels / groups;
......@@ -184,14 +178,15 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice =
filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(filter_slice, true, out_grad_slice, false,
T(1.0), &col_matrix, T(0.0), device_context);
math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
// col2im
Tensor in_grad_slice =
in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
paddings[1], device_context);
col2im(context.device_context(), in_grad_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
}
}
}
......@@ -212,15 +207,15 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
Tensor out_grad_slice =
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(in_slice, col, strides[0], strides[1], paddings[0],
paddings[1], device_context);
im2col(context.device_context(), in_slice, col, strides[0],
strides[1], paddings[0], paddings[1]);
// gemm
Tensor filter_grad_slice =
filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(out_grad_slice, false, col_matrix, true,
T(1.0), &filter_grad_slice, T(1.0),
device_context);
math::matmul<Place, T>(context.device_context(), out_grad_slice,
false, col_matrix, true, T(1.0),
&filter_grad_slice, T(1.0));
}
}
}
......
......@@ -27,9 +27,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -79,9 +80,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -137,9 +138,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -197,9 +199,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::CPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......
......@@ -64,9 +64,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -84,9 +85,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1);
dim3 grid(block_x, block_y);
im2col<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im2col<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>());
......@@ -149,9 +150,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
......@@ -174,9 +175,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
// To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions.
col2im<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
col2im<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height,
......@@ -235,9 +236,10 @@ template <class T>
class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -268,9 +270,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
im2colOCF<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
im2colOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
......@@ -318,9 +320,9 @@ template <class T>
class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
platform::GPUPlace, T> {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) {
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width) {
PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0];
......@@ -351,9 +353,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height);
col2imOCF<T><<<
grid, threads, 0,
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>(
col2imOCF<T><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
.stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width);
......
......@@ -72,17 +72,18 @@ enum class ColFormat { kCFO = 0, kOCF = 1 };
template <ColFormat Format, typename Place, typename T>
class Im2ColFunctor {
public:
void operator()(const framework::Tensor& im, framework::Tensor& col,
void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context);
int padding_width);
};
template <ColFormat Format, typename Place, typename T>
class Col2ImFunctor {
public:
void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context);
void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height,
int stride_width, int padding_height, int padding_width);
};
} // namespace math
......
......@@ -78,8 +78,8 @@ void testIm2col() {
PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU
}
im2col(input, output_cfo, stride, stride, padding, padding, context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context);
im2col(*context, input, output_cfo, stride, stride, padding, padding);
im2col_ocf(*context, input, output_ocf, stride, stride, padding, padding);
float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rank_loss_op.h"
namespace paddle {
namespace operators {
class RankLossOp : public framework::OperatorWithKernel {
public:
RankLossOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
// input check
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) shouldn't be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
"Input(Left) shouldn't be null");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
"Input(Right) shouldn't be null");
auto label_dims = ctx.Input<framework::Tensor>("Label")->dims();
auto left_dims = ctx.Input<framework::Tensor>("Left")->dims();
auto right_dims = ctx.Input<framework::Tensor>("Right")->dims();
PADDLE_ENFORCE((label_dims == left_dims) && (left_dims == right_dims),
"All inputs must have the same size");
PADDLE_ENFORCE((label_dims.size() == 2) && (label_dims[1] == 1),
"All inputs must be row vector with size batch_size x 1.");
ctx.Output<framework::LoDTensor>("Out")->Resize(label_dims);
}
};
class RankLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
RankLossOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Label",
"The label indicating A ranked higher than B or not, row vector.");
AddInput("Left", "The output of RankNet for doc A, vector.");
AddInput("Right", "The output of RankNet for doc B, vetor");
AddOutput("Out", "The output loss of RankLoss operator, vector.");
AddComment(R"DOC(RankLoss operator
Rank loss operator for RankNet[1]. RankNet is a pairwise ranking model with
one training sample consisting of a pair of doc A and B, and the label P
indicating that A is ranked higher than B or not:
P = {0, 1} or {0, 0.5, 1}, where 0.5 means no information about the rank of
the input pair.
The RankLoss operator contains three inputs: Left (o_i), Right (o_j) and Label
(P_{i,j}), which represent the output of RankNet for two docs and the label
respectively, and yields the rank loss C_{i,j} by following the expression
\f[
C_{i,j} = -\tilde{P_{ij}} * o_{i,j} + log(1 + e^{o_{i,j}}) \\
o_{i,j} = o_i - o_j \\
\tilde{P_{i,j}} = \left \{0, 0.5, 1 \right \} \ or \ \left \{0, 1 \right \}
\f]
The operator can take inputs of one sample or in batch.
[1]. Chris Burges, Tal Shaked, Erin Renshaw, et al. Learning to
Rank using Gradient Descent.
http://icml.cc/2015/wp-content/uploads/2015/06/icml_ranking.pdf
)DOC");
}
};
class RankLossGradOp : public framework::OperatorWithKernel {
public:
RankLossGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"),
"Input(Label) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Left"),
"Input(Left) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Right"),
"Input(Right) shouldn't be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
auto dims = ctx.Input<framework::Tensor>("Left")->dims();
auto *left_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Left"));
auto *right_grad =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Right"));
if (left_grad) {
left_grad->Resize(dims);
}
if (right_grad) {
right_grad->Resize(dims);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(rank_loss, ops::RankLossOp, ops::RankLossOpMaker, rank_loss_grad,
ops::RankLossGradOp);
REGISTER_OP_CPU_KERNEL(rank_loss,
ops::RankLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
rank_loss_grad, ops::RankLossGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/rank_loss_op.h"
REGISTER_OP_GPU_KERNEL(
rank_loss,
paddle::operators::RankLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
rank_loss_grad,
paddle::operators::RankLossGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename Place, typename T>
class RankLossKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_t = ctx.Output<framework::LoDTensor>("Out");
auto* label_t = ctx.Input<framework::Tensor>("Label");
auto* left_t = ctx.Input<framework::Tensor>("Left");
auto* right_t = ctx.Input<framework::Tensor>("Right");
out_t->mutable_data<T>(ctx.GetPlace());
auto out = framework::EigenVector<T>::Flatten(*out_t);
auto label = framework::EigenVector<T>::Flatten(*label_t);
auto left = framework::EigenVector<T>::Flatten(*left_t);
auto right = framework::EigenVector<T>::Flatten(*right_t);
auto& dev = ctx.GetEigenDevice<Place>();
out.device(dev) =
(1. + (left - right).exp()).log() - label * (left - right);
}
};
template <typename Place, typename T>
class RankLossGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_left_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Left"));
auto* d_right_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("Right"));
auto* d_out_t = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* label_t = ctx.Input<framework::Tensor>("Label");
auto* left_t = ctx.Input<framework::Tensor>("Left");
auto* right_t = ctx.Input<framework::Tensor>("Right");
auto& dev = ctx.GetEigenDevice<Place>();
auto d_out = framework::EigenVector<T>::Flatten(*d_out_t);
auto label = framework::EigenVector<T>::Flatten(*label_t);
auto left = framework::EigenVector<T>::Flatten(*left_t);
auto right = framework::EigenVector<T>::Flatten(*right_t);
// compute d_left
if (d_left_t) {
d_left_t->mutable_data<T>(ctx.GetPlace());
auto d_left = framework::EigenVector<T>::Flatten(*d_left_t);
d_left.device(dev) = d_out * (1. / (1. + (right - left).exp()) - label);
}
// compute d_right
if (d_right_t) {
d_right_t->mutable_data<T>(ctx.GetPlace());
auto d_right = framework::EigenVector<T>::Flatten(*d_right_t);
d_right.device(dev) =
-d_out * (1.0 / (1. + (right - left).exp()) - label);
}
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/sigmoid_op.h"
namespace paddle {
namespace operators {
class SigmoidOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of SigmoidOp should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Y"),
"Output(Y) of SigmoidOp should not be null.");
ctx.Output<framework::LoDTensor>("Y")->Resize(
ctx.Input<Tensor>("X")->dims());
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
public:
SigmoidOpMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "sigmoid input");
AddOutput("Y", "sigmoid output");
AddComment("Sigmoid function");
}
};
class SigmoidOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
ctx.Output<framework::LoDTensor>(framework::GradVarName("X"))
->Resize(ctx.Input<Tensor>("Y")->dims());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(sigmoid, ops::SigmoidOp, ops::SigmoidOpMaker, sigmoid_grad,
ops::SigmoidOpGrad);
REGISTER_OP_CPU_KERNEL(sigmoid,
ops::SigmoidKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
sigmoid_grad, ops::SigmoidGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T>
class SigmoidKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<Tensor>("X");
auto output = context.Output<Tensor>("Y");
output->mutable_data<T>(context.GetPlace());
// The clipping is used in Paddle's raw implenmention
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = context.GetEigenDevice<Place>();
Y.device(place) = 1. / (1. + (-X).exp());
}
};
template <typename Place, typename T>
class SigmoidGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto Y_t = context.Input<Tensor>("Y");
auto dY_t = context.Input<Tensor>(framework::GradVarName("Y"));
auto dX_t = context.Output<Tensor>(framework::GradVarName("X"));
dX_t->mutable_data<T>(context.GetPlace());
auto dX = EigenVector<T>::Flatten(*dX_t);
auto Y = EigenVector<T>::Flatten(*Y_t);
auto dY = EigenVector<T>::Flatten(*dY_t);
dX.device(context.GetEigenDevice<Place>()) = dY * Y * (1. - Y);
}
};
} // namespace operators
} // namespace paddle
import unittest
import numpy as np
from op_test import OpTest
class TestExp(OpTest):
def setUp(self):
self.op_type = "exp"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.exp(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSigmoid(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
class TestTanh(OpTest):
def setUp(self):
self.op_type = "tanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.tanh(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSqrt(OpTest):
def setUp(self):
self.op_type = "sqrt"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.sqrt(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestAbs(OpTest):
def setUp(self):
self.op_type = "abs"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
# Because we set delta = 0.005 in caculating numeric gradient,
# if x is too small, such as 0.002, x_neg will be -0.003
# x_pos will be 0.007, so the numeric gradient is unaccurate.
# we should avoid this
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Y': np.abs(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestRelu(OpTest):
def setUp(self):
self.op_type = "relu"
x = np.random.uniform(-1, 1, [11, 17]).astype("float32")
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestBRelu(OpTest):
def setUp(self):
self.op_type = "brelu"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
t_min = 1
t_max = 4
# The same with TestAbs
x[np.abs(x - t_min) < 0.005] = t_min + 0.02
x[np.abs(x - t_max) < 0.005] = t_max + 0.02
self.inputs = {'X': x}
self.attrs = {'t_min': t_min, 't_max': t_max}
t = np.copy(x)
t[t < t_min] = t_min
t[t > t_max] = t_max
self.outputs = {'Y': t}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestSoftRelu(OpTest):
def setUp(self):
self.op_type = "soft_relu"
x = np.random.uniform(-3, 3, [4, 4]).astype("float32")
threshold = 2
# The same reason with TestAbs
x[np.abs(x - threshold) < 0.005] = threshold + 0.02
x[np.abs(x + threshold) < 0.005] = -threshold + 0.02
self.inputs = {'X': x}
self.attrs = {'threshold': threshold}
t = np.copy(x)
t[t < -threshold] = -threshold
t[t > threshold] = threshold
self.outputs = {'Y': np.log((np.exp(t) + 1))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestReciprocal(OpTest):
def setUp(self):
self.op_type = "reciprocal"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.outputs = {'Y': np.reciprocal(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.01)
class TestLog(OpTest):
def setUp(self):
self.op_type = "log"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.log(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestSquare(OpTest):
def setUp(self):
self.op_type = "square"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.square(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
class TestPow(OpTest):
def setUp(self):
self.op_type = "pow"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.attrs = {'factor': 3}
self.outputs = {'Y': np.power(self.inputs['X'], 3)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
class TestSTanh(OpTest):
def setUp(self):
self.op_type = "stanh"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
scale_a = 2.0 / 3.0
scale_b = 1.7159
self.attrs = {'scale_a': scale_a, 'scale_b': scale_b}
self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
if __name__ == "__main__":
unittest.main()
......@@ -73,13 +73,22 @@ class TestConv2dOp(OpTest):
self.check_output()
def test_check_grad(self):
self.check_grad(set(['Input', 'Filter']), 'Output')
self.check_grad(
set(['Input', 'Filter']), 'Output', max_relative_error=0.05)
def test_check_grad_no_filter(self):
self.check_grad(['Input'], 'Output', no_grad_set=set(['Filter']))
self.check_grad(
['Input'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Filter']))
def test_check_grad_no_input(self):
self.check_grad(['Filter'], 'Output', no_grad_set=set(['Input']))
self.check_grad(
['Filter'],
'Output',
max_relative_error=0.05,
no_grad_set=set(['Input']))
def init_groups(self):
self.groups = 1
......
import unittest
import numpy as np
from op_test import OpTest
def crop(data, offsets, crop_shape):
def indexOf(shape, index):
result = []
for dim in reversed(shape):
result.append(index % dim)
index = index / dim
return result[::-1]
result = []
for i, value in enumerate(data.flatten()):
index = indexOf(data.shape, i)
selected = True
if len(index) == len(offsets):
for j, offset in enumerate(offsets):
selected = selected and index[j] >= offset and index[
j] < crop_shape[j] + offset
if selected:
result.append(value)
return np.array(result).reshape(crop_shape)
class TestCropOp(OpTest):
def setUp(self):
self.op_type = "crop"
self.crop_by_input = False
self.attrs = {}
self.initTestCase()
self.attrs['offsets'] = self.offsets
if self.crop_by_input:
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
'Y': np.random.random(self.crop_shape).astype("float32")
}
else:
self.attrs['shape'] = self.crop_shape
self.inputs = {
'X': np.random.random(self.x_shape).astype("float32"),
}
self.outputs = {
'Out': crop(self.inputs['X'], self.offsets, self.crop_shape)
}
def initTestCase(self):
self.x_shape = (8, 8)
self.crop_shape = (2, 2)
self.offsets = [1, 2]
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.006)
class TestCase1(TestCropOp):
def initTestCase(self):
self.x_shape = (16, 8, 32)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
class TestCase2(TestCropOp):
def initTestCase(self):
self.x_shape = (4, 8)
self.crop_shape = [4, 8]
self.offsets = [0, 0]
class TestCase3(TestCropOp):
def initTestCase(self):
self.x_shape = (4, 8, 16)
self.crop_shape = [2, 2, 3]
self.offsets = [1, 5, 3]
self.crop_by_input = True
class TestCase4(TestCropOp):
def initTestCase(self):
self.x_shape = (4, 4)
self.crop_shape = [4, 4]
self.offsets = [0, 0]
self.crop_by_input = True
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestRankLossOp(OpTest):
def setUp(self):
self.op_type = "rank_loss"
batch_size = 5
# labels_{i} = {0, 1.0} or {0, 0.5, 1.0}
label = np.random.randint(0, 2, size=(batch_size, 1)).astype("float32")
left = np.random.random((batch_size, 1)).astype("float32")
right = np.random.random((batch_size, 1)).astype("float32")
loss = np.log(1.0 + np.exp(left - right)) - label * (left - right)
self.inputs = {'Label': label, 'Left': left, 'Right': right}
self.outputs = {'Out': loss}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Left", "Right"], "Out")
def test_check_grad_ignore_left(self):
self.check_grad(["Right"], "Out", no_grad_set=set('Left'))
def test_check_grad_ignore_right(self):
self.check_grad(["Left"], "Out", no_grad_set=set('Right'))
if __name__ == '__main__':
unittest.main()
import unittest
import numpy as np
from op_test import OpTest
class TestSigmoidOp(OpTest):
def setUp(self):
self.op_type = "sigmoid"
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["X"], "Y", max_relative_error=0.007)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册