提交 24458ae3 编写于 作者: Y Yibing Liu

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into profiler_tool

......@@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc)
cc_library(scope SRCS scope.cc DEPS glog)
cc_test(scope_test SRCS scope_test.cc DEPS scope)
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc
......@@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog
shape_inference data_transform)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog)
......@@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece)
cc_test(init_test SRCS init_test.cc DEPS init)
cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto)
cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto)
cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context)
......@@ -32,17 +32,16 @@ using DataTransformFN =
const Variable& in, Variable* out)>;
using KernelTypePair = std::pair<OpKernelType, OpKernelType>;
static void hash_combine(std::size_t& seed, const OpKernelType& t) {
OpKernelType::Hash kernel_type_hasher;
seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
struct KernelTypePairHash {
static void HashCombine(const OpKernelType& t, std::size_t* seed) {
OpKernelType::Hash kernel_type_hasher;
(*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
size_t operator()(const KernelTypePair& kernel_pair) const {
std::size_t seed = 0;
hash_combine(seed, kernel_pair.first);
hash_combine(seed, kernel_pair.second);
HashCombine(kernel_pair.first, &seed);
HashCombine(kernel_pair.second, &seed);
return seed;
}
};
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>
#include <atomic>
#include "paddle/framework/data_transform.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/operator.h"
......@@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope,
expected_kernel_key);
}
kernel_iter->second->Compute(ctx);
if (actual_kernel_key == expected_kernel_key) {
kernel_iter->second->Compute(ctx);
} else {
Scope& op_scope = scope.NewScope();
auto input_vars = this->InputVars();
for (auto var_name : input_vars) {
op_scope.Var(var_name);
}
// TODO(qijun) get appropriate DeviceContext from DeviceContext pool
platform::DeviceContext* trans_dev_ctx = nullptr;
std::vector<platform::DeviceContext*> trans_dev_ctx_vec{trans_dev_ctx};
// TODO(qijun) get appropriate DataTransformFN from global map
framework::DataTransformFN trans_fun = nullptr;
// Wait for transform starting
dev_ctx->Wait();
for (auto var_name : input_vars) {
trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)),
op_scope.FindVar(var_name));
}
// Wait for data transform finishing
for (auto ctx : trans_dev_ctx_vec) {
ctx->Wait();
}
// Create a new ExecutionContext
ExecutionContext op_ctx(*this, op_scope, *dev_ctx);
kernel_iter->second->Compute(op_ctx);
}
}
OpKernelType OperatorWithKernel::GetActualKernelType(
......
......@@ -22,8 +22,8 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
};
......@@ -32,7 +32,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
};
......@@ -41,11 +41,11 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sigmoid operator");
AddOutput("Y", "Output of Sigmoid operator");
AddOutput("Out", "Output of Sigmoid operator");
AddComment(R"DOC(
Sigmoid Activation Operator
$$y = \frac{1}{1 + e^{-x}}$$
$$out = \frac{1}{1 + e^{-x}}$$
)DOC");
}
......@@ -56,11 +56,11 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LogSigmoid operator");
AddOutput("Y", "Output of LogSigmoid operator");
AddOutput("Out", "Output of LogSigmoid operator");
AddComment(R"DOC(
Logsigmoid Activation Operator
$$y = \log \frac{1}{1 + e^{-x}}$$
$$out = \log \frac{1}{1 + e^{-x}}$$
)DOC");
}
......@@ -71,11 +71,11 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker {
ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Exp operator");
AddOutput("Y", "Output of Exp operator");
AddOutput("Out", "Output of Exp operator");
AddComment(R"DOC(
Exp Activation Operator.
$y = e^x$
$out = e^x$
)DOC");
}
......@@ -86,11 +86,11 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker {
ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu operator");
AddOutput("Y", "Output of Relu operator");
AddOutput("Out", "Output of Relu operator");
AddComment(R"DOC(
Relu Activation Operator.
$y = \max(x, 0)$
$out = \max(x, 0)$
)DOC");
}
......@@ -101,12 +101,12 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of LeakyRelu operator");
AddOutput("Y", "Output of LeakyRelu operator");
AddOutput("Out", "Output of LeakyRelu operator");
AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
AddComment(R"DOC(
LeakyRelu Activation Operator.
$y = \max(x, \alpha * x)$
$out = \max(x, \alpha * x)$
)DOC");
}
......@@ -117,13 +117,13 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softshrink operator");
AddOutput("Y", "Output of Softshrink operator");
AddOutput("Out", "Output of Softshrink operator");
AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
AddComment(R"DOC(
Softshrink Activation Operator.
$$
y = \begin{cases}
out = \begin{cases}
x - \lambda, \text{if } x > \lambda \\
x + \lambda, \text{if } x < -\lambda \\
0, \text{otherwise}
......@@ -139,11 +139,11 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker {
TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Tanh operator");
AddOutput("Y", "Output of Tanh operator");
AddOutput("Out", "Output of Tanh operator");
AddComment(R"DOC(
Tanh Activation Operator.
$$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC");
}
......@@ -154,11 +154,11 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of TanhShrink operator");
AddOutput("Y", "Output of TanhShrink operator");
AddOutput("Out", "Output of TanhShrink operator");
AddComment(R"DOC(
TanhShrink Activation Operator.
$$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
)DOC");
}
......@@ -169,14 +169,14 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardShrink operator");
AddOutput("Y", "Output of HardShrink operator");
AddOutput("Out", "Output of HardShrink operator");
AddAttr<float>("threshold", "The value of threshold for HardShrink")
.SetDefault(0.5f);
AddComment(R"DOC(
HardShrink Activation Operator.
$$
y = \begin{cases}
out = \begin{cases}
x, \text{if } x > \lambda \\
x, \text{if } x < -\lambda \\
0, \text{otherwise}
......@@ -192,11 +192,11 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker {
SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Sqrt operator");
AddOutput("Y", "Output of Sqrt operator");
AddOutput("Out", "Output of Sqrt operator");
AddComment(R"DOC(
Sqrt Activation Operator.
$y = \sqrt{x}$
$out = \sqrt{x}$
)DOC");
}
......@@ -207,11 +207,11 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker {
AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Abs operator");
AddOutput("Y", "Output of Abs operator");
AddOutput("Out", "Output of Abs operator");
AddComment(R"DOC(
Abs Activation Operator.
$y = |x|$
$out = |x|$
)DOC");
}
......@@ -222,11 +222,11 @@ class CeilOpMaker : public framework::OpProtoAndCheckerMaker {
CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Ceil operator");
AddOutput("Y", "Output of Ceil operator");
AddOutput("Out", "Output of Ceil operator");
AddComment(R"DOC(
Ceil Activation Operator.
$y = ceil(x)$
$out = ceil(x)$
)DOC");
}
......@@ -237,11 +237,11 @@ class FloorOpMaker : public framework::OpProtoAndCheckerMaker {
FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Floor operator");
AddOutput("Y", "Output of Floor operator");
AddOutput("Out", "Output of Floor operator");
AddComment(R"DOC(
Floor Activation Operator.
$y = floor(x)$
$out = floor(x)$
)DOC");
}
......@@ -252,11 +252,11 @@ class RoundOpMaker : public framework::OpProtoAndCheckerMaker {
RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Round operator");
AddOutput("Y", "Output of Round operator");
AddOutput("Out", "Output of Round operator");
AddComment(R"DOC(
Round Activation Operator.
$y = [x]$
$out = [x]$
)DOC");
}
......@@ -267,11 +267,11 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker {
ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Reciprocal operator");
AddOutput("Y", "Output of Reciprocal operator");
AddOutput("Out", "Output of Reciprocal operator");
AddComment(R"DOC(
Reciprocal Activation Operator.
$$y = \frac{1}{x}$$
$$out = \frac{1}{x}$$
)DOC");
}
......@@ -282,11 +282,11 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker {
LogOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Log operator");
AddOutput("Y", "Output of Log operator");
AddOutput("Out", "Output of Log operator");
AddComment(R"DOC(
Log Activation Operator.
$y = \ln(x)$
$out = \ln(x)$
Natural logarithm of x.
......@@ -299,11 +299,11 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker {
SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Square operator");
AddOutput("Y", "Output of Square operator");
AddOutput("Out", "Output of Square operator");
AddComment(R"DOC(
Square Activation Operator.
$y = x^2$
$out = x^2$
)DOC");
}
......@@ -314,11 +314,11 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker {
SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softplus operator");
AddOutput("Y", "Output of Softplus operator");
AddOutput("Out", "Output of Softplus operator");
AddComment(R"DOC(
Softplus Activation Operator.
$y = \ln(1 + e^{x})$
$out = \ln(1 + e^{x})$
)DOC");
}
......@@ -329,11 +329,11 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker {
SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Softsign operator");
AddOutput("Y", "Output of Softsign operator");
AddOutput("Out", "Output of Softsign operator");
AddComment(R"DOC(
Softsign Activation Operator.
$$y = \frac{x}{1 + |x|}$$
$$out = \frac{x}{1 + |x|}$$
)DOC");
}
......@@ -344,7 +344,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of BRelu operator");
AddOutput("Y", "Output of BRelu operator");
AddOutput("Out", "Output of BRelu operator");
AddAttr<float>("t_min", "The min marginal value of BRelu")
.SetDefault(static_cast<float>(0));
AddAttr<float>("t_max", "The max marginal value of BRelu")
......@@ -352,7 +352,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
BRelu Activation Operator.
$y = \max(\min(x, t_{min}), t_{max})$
$out = \max(\min(x, t_{min}), t_{max})$
)DOC");
}
......@@ -363,13 +363,13 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of SoftRelu operator");
AddOutput("Y", "Output of SoftRelu operator");
AddOutput("Out", "Output of SoftRelu operator");
AddAttr<float>("threshold", "The threshold value of SoftRelu")
.SetDefault(40.0f);
AddComment(R"DOC(
SoftRelu Activation Operator.
$y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
$out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
)DOC");
}
......@@ -380,7 +380,7 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ELU operator");
AddOutput("Y", "Output of ELU operator");
AddOutput("Out", "Output of ELU operator");
AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
AddComment(R"DOC(
ELU Activation Operator.
......@@ -388,7 +388,7 @@ ELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1511.07289.
$y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$
)DOC");
}
......@@ -399,13 +399,13 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Relu6 operator");
AddOutput("Y", "Output of Relu6 operator");
AddOutput("Out", "Output of Relu6 operator");
AddAttr<float>("threshold", "The threshold value of Relu6")
.SetDefault(6.0f);
AddComment(R"DOC(
Relu6 Activation Operator.
$y = \min(\max(0, x), 6)$
$out = \min(\max(0, x), 6)$
)DOC");
}
......@@ -416,12 +416,12 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker {
PowOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Pow operator");
AddOutput("Y", "Output of Pow operator");
AddOutput("Out", "Output of Pow operator");
AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
AddComment(R"DOC(
Pow Activation Operator.
$y = x^{factor}$
$out = x^{factor}$
)DOC");
}
......@@ -432,7 +432,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of STanh operator");
AddOutput("Y", "Output of STanh operator");
AddOutput("Out", "Output of STanh operator");
AddAttr<float>("scale_a", "The scale parameter of a for the input")
.SetDefault(2.0f / 3.0f);
AddAttr<float>("scale_b", "The scale parameter of b for the input")
......@@ -440,7 +440,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
AddComment(R"DOC(
STanh Activation Operator.
$$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
)DOC");
}
......@@ -451,14 +451,14 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of ThresholdedRelu operator");
AddOutput("Y", "Output of ThresholdedRelu operator");
AddOutput("Out", "Output of ThresholdedRelu operator");
AddAttr<float>("threshold", "The threshold location of activation")
.SetDefault(1.0f);
AddComment(R"DOC(
ThresholdedRelu Activation Operator.
$$
y = \begin{cases}
out = \begin{cases}
x, \text{if } x > threshold \\
0, \text{otherwise}
\end{cases}
......@@ -473,7 +473,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of HardSigmoid operator");
AddOutput("Y", "Output of HardSigmoid operator");
AddOutput("Out", "Output of HardSigmoid operator");
AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
.SetDefault(0.2f);
AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
......@@ -484,7 +484,7 @@ HardSigmoid Activation Operator.
Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391),
which is much faster than sigmoid.
$y = \max(0, \min(1, slope * x + shift))$
$out = \max(0, \min(1, slope * x + shift))$
The slope should be positive. The offset can be either positive or negative.
The default slope and shift are set according to the above reference.
......@@ -499,12 +499,12 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker {
SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker)
: framework::OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "Input of Swish operator");
AddOutput("Y", "Output of Swish operator");
AddOutput("Out", "Output of Swish operator");
AddAttr<float>("beta", "Constant beta of swish operator").SetDefault(1.0f);
AddComment(R"DOC(
Swish Activation Operator.
$$y = \frac{x}{1 + e^{- \beta x}}$$
$$out = \frac{x}{1 + e^{- \beta x}}$$
)DOC");
}
......
......@@ -27,11 +27,11 @@ class ActivationKernel
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Output<framework::Tensor>("Y");
Y->mutable_data<T>(context.GetPlace());
auto* Out = context.Output<framework::Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
Functor functor;
......@@ -40,7 +40,7 @@ class ActivationKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(*place, x, y);
functor(*place, x, out);
}
};
......@@ -51,14 +51,15 @@ class ActivationGradKernel
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<framework::Tensor>("X");
auto* Y = context.Input<framework::Tensor>("Y");
auto* dY = context.Input<framework::Tensor>(framework::GradVarName("Y"));
auto* Out = context.Input<framework::Tensor>("Out");
auto* dOut =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<framework::Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto dy = framework::EigenVector<T>::Flatten(*dY);
auto dout = framework::EigenVector<T>::Flatten(*dOut);
auto x = framework::EigenVector<T>::Flatten(*X);
auto y = framework::EigenVector<T>::Flatten(*Y);
auto out = framework::EigenVector<T>::Flatten(*Out);
auto dx = framework::EigenVector<T>::Flatten(*dX);
auto* place =
context.template device_context<DeviceContext>().eigen_device();
......@@ -67,7 +68,7 @@ class ActivationGradKernel
for (auto& attr : attrs) {
*attr.second = context.Attr<float>(attr.first);
}
functor(*place, x, y, dy, dx);
functor(*place, x, out, dout, dx);
}
};
......@@ -83,17 +84,18 @@ struct BaseActivationFunctor {
// sigmoid(x) = 1 / (1 + exp(-x))
template <typename T>
struct SigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = static_cast<T>(1) / (static_cast<T>(1) + (-x).exp());
}
};
template <typename T>
struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y * (static_cast<T>(1) - y);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out * (static_cast<T>(1) - out);
}
};
......@@ -101,7 +103,7 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
// For numerical stability, we can use the log-sum-exp trick:
// https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/
// We can rewrite the above equation as:
// y = -log( exp(0) + exp(-x)) [since exp(0) = 1]
// out = -log( exp(0) + exp(-x)) [since exp(0) = 1]
// = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0)))
// = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x -
// max(-x, 0)))
......@@ -112,10 +114,10 @@ struct SigmoidGradFunctor : public BaseActivationFunctor<T> {
// + exp(-x - max(-x, 0))))
template <typename T>
struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log());
}
};
......@@ -124,62 +126,66 @@ struct LogSigmoidFunctor : public BaseActivationFunctor<T> {
// exp(-x - max(-x, 0)))
template <typename T>
struct LogSigmoidGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp = (-x).cwiseMax(static_cast<T>(0)); // temp = max(-x, 0)
dx.device(d) =
dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp()));
}
};
// exp(x) = e^x
template <typename T>
struct ExpFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.exp();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.exp();
}
};
template <typename T>
struct ExpGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * y;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * out;
}
};
// relu(x) = max(x, 0)
template <typename T>
struct ReluFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0));
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0));
}
};
template <typename T>
struct ReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>();
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>();
}
};
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.tanh();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.tanh();
}
};
template <typename T>
struct TanhGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) - y * y);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) - out * out);
}
};
......@@ -187,17 +193,18 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
struct TanhShrinkFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x - x.tanh();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x - x.tanh();
}
};
template <typename T>
struct TanhShrinkGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x.tanh() * x.tanh());
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x.tanh() * x.tanh());
}
};
......@@ -210,11 +217,11 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
y.device(d) = x * (temp1 + temp2);
out.device(d) = x * (temp1 + temp2);
}
};
......@@ -226,11 +233,12 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
};
......@@ -243,12 +251,12 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
return {{"lambda", &lambda}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
}
};
......@@ -258,46 +266,49 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"lambda", &lambda}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
};
// sqrt(x) = x^(1/2)
template <typename T>
struct SqrtFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.sqrt();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.sqrt();
}
};
template <typename T>
struct SqrtGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
const Y y_conj = Eigen::numext::conj(y);
dx.device(d) = static_cast<T>(0.5) * dy / y_conj;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
const Out out_conj = Eigen::numext::conj(out);
dx.device(d) = static_cast<T>(0.5) * dout / out_conj;
}
};
// ceil(x) = ceiling(x)
template <typename T>
struct CeilFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.ceil();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
template <typename T>
struct ZeroGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = static_cast<T>(0) / x;
}
};
......@@ -305,86 +316,90 @@ struct ZeroGradFunctor : public BaseActivationFunctor<T> {
// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.ceil();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.ceil();
}
};
// round(x) = [x]
template <typename T>
struct RoundFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.round();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.round();
}
};
// abs(x) = |x|
template <typename T>
struct AbsFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.abs();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.abs();
}
};
template <typename T>
struct AbsGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * x.sign();
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * x.sign();
}
};
// reciprocal(x) = 1 / x
template <typename T>
struct ReciprocalFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = static_cast<T>(1) / x;
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = static_cast<T>(1) / x;
}
};
template <typename T>
struct ReciprocalGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(-1) * y * y;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(-1) * out * out;
}
};
// log(x) = natural logarithm of x
template <typename T>
struct LogFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.log();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.log();
}
};
template <typename T>
struct LogGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (static_cast<T>(1) / x);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (static_cast<T>(1) / x);
}
};
// square(x) = x^2
template <typename T>
struct SquareFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.square();
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.square();
}
};
template <typename T>
struct SquareGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(2) * x;
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(2) * x;
}
};
......@@ -399,9 +414,9 @@ struct BReluFunctor : public BaseActivationFunctor<T> {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
x.cwiseMax(static_cast<T>(t_min)).cwiseMin(static_cast<T>(t_max));
}
};
......@@ -413,9 +428,10 @@ struct BReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"t_min", &t_min}, {"t_max", &t_max}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((x > static_cast<T>(t_min)) * (x < static_cast<T>(t_max)))
.template cast<T>();
}
......@@ -430,9 +446,9 @@ struct Relu6Functor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
x.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(threshold));
}
};
......@@ -443,9 +459,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((x > static_cast<T>(0)) * (x < static_cast<T>(threshold)))
.template cast<T>();
}
......@@ -458,10 +475,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor<T> {
// Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0)))
template <typename T>
struct SoftplusFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log());
}
};
......@@ -471,19 +488,21 @@ struct SoftplusFunctor : public BaseActivationFunctor<T> {
// exp(x - max(x, 0)))
template <typename T>
struct SoftplusGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
auto temp = x.cwiseMax(static_cast<T>(0)); // temp = max(x, 0)
dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
dx.device(d) =
dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp()));
}
};
// softsign(x) = x / (1 + |x|)
template <typename T>
struct SoftsignFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) {
y.device(d) = x / (static_cast<T>(1) + x.abs());
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) {
out.device(d) = x / (static_cast<T>(1) + x.abs());
}
};
......@@ -491,10 +510,11 @@ struct SoftsignFunctor : public BaseActivationFunctor<T> {
// Taken from https://en.wikipedia.org/wiki/Activation_function
template <typename T>
struct SoftsignGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) {
dx.device(d) =
dy * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
dout * (static_cast<T>(1) / (static_cast<T>(1) + x.abs()).square());
}
};
......@@ -505,11 +525,11 @@ struct SoftReluFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto tmp = static_cast<T>(threshold);
auto temp = x.cwiseMax(-tmp).cwiseMin(tmp);
y.device(d) = (static_cast<T>(1) + temp.exp()).log();
out.device(d) = (static_cast<T>(1) + temp.exp()).log();
}
};
......@@ -519,11 +539,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = static_cast<T>(threshold);
auto temp = ((x > -tmp) * (x < tmp)).template cast<T>().eval();
dx.device(d) = dy * (static_cast<T>(1) - (-y).exp()) * temp;
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
}
};
......@@ -534,9 +555,9 @@ struct LeakyReluFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(alpha) * x);
}
};
......@@ -546,12 +567,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
dx.device(d) = dy * (temp1 + temp2).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
};
......@@ -562,11 +584,11 @@ struct ELUFunctor : public BaseActivationFunctor<T> {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.cwiseMax(static_cast<T>(0)) +
(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
.cwiseMin(static_cast<T>(0));
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.cwiseMax(static_cast<T>(0)) +
(static_cast<T>(alpha) * (x.exp() - static_cast<T>(1)))
.cwiseMin(static_cast<T>(0));
}
};
......@@ -576,10 +598,11 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * (x > static_cast<T>(0)).template cast<T>() +
dy * (y + static_cast<T>(alpha)) *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * (x > static_cast<T>(0)).template cast<T>() +
dout * (out + static_cast<T>(alpha)) *
(x < static_cast<T>(0)).template cast<T>();
}
};
......@@ -591,9 +614,9 @@ struct PowFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x.pow(static_cast<T>(factor));
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(static_cast<T>(factor));
}
};
......@@ -603,9 +626,10 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) = dy * static_cast<T>(factor) *
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout * static_cast<T>(factor) *
x.pow(static_cast<T>(factor - static_cast<T>(1)));
}
};
......@@ -618,9 +642,9 @@ struct STanhFunctor : public BaseActivationFunctor<T> {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) =
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
static_cast<T>(scale_b) * (static_cast<T>(scale_a) * x).tanh();
}
};
......@@ -633,12 +657,13 @@ struct STanhGradFunctor : public BaseActivationFunctor<T> {
return {{"scale_a", &scale_a}, {"scale_b", &scale_b}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto a = static_cast<T>(scale_a);
auto b = static_cast<T>(scale_b);
auto temp = (a * x).tanh() * (a * x).tanh();
dx.device(d) = dy * a * b * (static_cast<T>(1) - temp);
dx.device(d) = dout * a * b * (static_cast<T>(1) - temp);
}
};
......@@ -649,10 +674,10 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto th = static_cast<T>(threshold);
y.device(d) = (x > th).template cast<T>() * x;
out.device(d) = (x > th).template cast<T>() * x;
}
};
......@@ -663,10 +688,11 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor<T> {
return {{"threshold", &threshold}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto th = static_cast<T>(threshold);
dx.device(d) = dy * (x > th).template cast<T>();
dx.device(d) = dout * (x > th).template cast<T>();
}
};
......@@ -678,10 +704,11 @@ struct HardSigmoidFunctor : public BaseActivationFunctor<T> {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp = x * static_cast<T>(slope) + static_cast<T>(offset);
y.device(d) = temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
out.device(d) =
temp.cwiseMax(static_cast<T>(0)).cwiseMin(static_cast<T>(1));
}
};
......@@ -693,12 +720,13 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor<T> {
return {{"slope", &slope}, {"offset", &offset}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
dx.device(d) =
dy *
((y > static_cast<T>(0)) * (y < static_cast<T>(1))).template cast<T>() *
static_cast<T>(slope);
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
dx.device(d) = dout *
((out > static_cast<T>(0)) * (out < static_cast<T>(1)))
.template cast<T>() *
static_cast<T>(slope);
}
};
......@@ -709,9 +737,9 @@ struct SwishFunctor : public BaseActivationFunctor<T> {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Y>
void operator()(Device d, X x, Y y) const {
y.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x / (static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
}
};
......@@ -722,12 +750,13 @@ struct SwishGradFunctor : public BaseActivationFunctor<T> {
return {{"beta", &beta}};
}
template <typename Device, typename X, typename Y, typename dY, typename dX>
void operator()(Device d, X x, Y y, dY dy, dX dx) const {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(1) /
(static_cast<T>(1) + (static_cast<T>(-beta) * x).exp());
auto temp2 = temp1 * (static_cast<T>(1) - (beta * y));
dx.device(d) = dy * ((beta * y) + temp2);
auto temp2 = temp1 * (static_cast<T>(1) - (beta * out));
dx.device(d) = dout * ((beta * out) + temp2);
}
};
......
......@@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
"Output(Y) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SoftmaxOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(x_dims.size() == 2UL,
"The input of softmax op must be a matrix.");
ctx->SetOutputDim("Y", x_dims);
ctx->SetOutputDim("Out", x_dims);
}
};
......@@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("X",
"The input tensor of softmax. "
"2-D with shape [batch_size, input_feature_dimensions].");
AddOutput("Y", "The normalized values with the same shape as X.");
AddOutput("Out", "The normalized values with the same shape as X.");
AddComment(R"DOC(
Softmax Operator.
......@@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax
operator.
For each row $i$ and each column $j$ in Input(X), we have:
$$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
$$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$
)DOC");
}
......@@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"),
ctx->GetInputDim(framework::GradVarName("Y")),
"Input(Y) and its gradients should have a same shape.");
PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"),
ctx->GetInputDim(framework::GradVarName("Out")),
"Input(Out) and its gradients should have a same shape.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......
......@@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* X = context.Input<Tensor>("X");
auto* Y = context.Output<Tensor>("Y");
auto* Out = context.Output<Tensor>("Out");
// allocate memory on device.
Y->mutable_data<T>(context.GetPlace());
Out->mutable_data<T>(context.GetPlace());
math::SoftmaxFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), X, Y);
context.template device_context<DeviceContext>(), X, Out);
}
};
......@@ -40,15 +40,15 @@ template <typename DeviceContext, typename T>
class SoftmaxGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* Y = context.Input<Tensor>("Y");
auto* dY = context.Input<Tensor>(framework::GradVarName("Y"));
auto* Out = context.Input<Tensor>("Out");
auto* dOut = context.Input<Tensor>(framework::GradVarName("Out"));
auto* dX = context.Output<Tensor>(framework::GradVarName("X"));
// allocate memory on device.
dX->mutable_data<T>(context.GetPlace());
math::SoftmaxGradFunctor<DeviceContext, T>()(
context.template device_context<DeviceContext>(), Y, dY, dX);
context.template device_context<DeviceContext>(), Out, dOut, dX);
}
};
......
......@@ -180,10 +180,22 @@ def save_inference_model(dirname,
:return: None
"""
if isinstance(feeded_var_names, basestring):
feeded_var_names = [feeded_var_names]
else:
if not (bool(feeded_var_names) and all(
isinstance(name, basestring) for name in feeded_var_names)):
raise ValueError("'feed_var_names' should be a list of str.")
if isinstance(target_vars, Variable):
feeded_var_names = [feeded_var_names]
else:
if not (bool(target_vars) and all(
isinstance(var, Variable) for var in target_vars)):
raise ValueError("'target_vars' should be a list of Variable.")
if main_program is None:
main_program = default_main_program()
if not isinstance(target_vars, list):
target_vars = [target_vars]
if not os.path.isdir(dirname):
os.makedirs(dirname)
......
......@@ -184,7 +184,7 @@ class LayerHelper(object):
self.append_op(
type=act_type,
inputs={"X": [input_var]},
outputs={"Y": [tmp]},
outputs={"Out": [tmp]},
attrs=act)
return tmp
......
......@@ -386,7 +386,8 @@ def square_error_cost(input, label, **kwargs):
square_out = helper.create_tmp_variable(dtype=input.dtype)
helper.append_op(
type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
type='square', inputs={'X': [minus_out]},
outputs={'Out': [square_out]})
return square_out
......@@ -604,7 +605,7 @@ def sequence_pool(input, pool_type, **kwargs):
sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2),
6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2)
max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
pool_type (string): The pooling type of sequence_pool.
......@@ -616,7 +617,7 @@ def sequence_pool(input, pool_type, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
avg_x = fluid.layers.sequence_pool(input=x, pool_type='average')
......@@ -654,7 +655,7 @@ def sequence_first_step(input, **kwargs):
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
......@@ -664,7 +665,7 @@ def sequence_first_step(input, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_first_step = fluid.layers.sequence_first_step(input=x)
......@@ -687,7 +688,7 @@ def sequence_last_step(input, **kwargs):
out.dim = [3, 1]
with condition len(x.lod[-1]) - 1 == out.dims[0]
out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1)
Args:
input(variable): The input variable which is a LoDTensor.
......@@ -697,7 +698,7 @@ def sequence_last_step(input, **kwargs):
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[7, 1],
dtype='float32', lod_level=1)
x_last_step = fluid.layers.sequence_last_step(input=x)
......@@ -1132,7 +1133,7 @@ def reduce_sum(input, dim=None, keep_dim=False):
Returns:
Variable: The reduced Tensor variable.
Examples:
.. code-block:: python
......@@ -1176,7 +1177,7 @@ def reduce_mean(input, dim=None, keep_dim=False):
Returns:
Variable: The reduced Tensor variable.
Examples:
.. code-block:: python
......
......@@ -10,13 +10,13 @@ class TestExp(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.exp(self.inputs['X'])}
self.outputs = {'Out': np.exp(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestSigmoid(OpTest):
......@@ -25,13 +25,13 @@ class TestSigmoid(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))}
self.outputs = {'Out': 1 / (1 + np.exp(-self.inputs['X']))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
self.check_grad(['X'], 'Out', max_relative_error=0.008)
class TestLogSigmoid(OpTest):
......@@ -40,13 +40,13 @@ class TestLogSigmoid(OpTest):
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.log(1 / (1 + np.exp(-self.inputs['X'])))}
self.outputs = {'Out': np.log(1 / (1 + np.exp(-self.inputs['X'])))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
self.check_grad(['X'], 'Out', max_relative_error=0.008)
class TestTanh(OpTest):
......@@ -55,13 +55,13 @@ class TestTanh(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.tanh(self.inputs['X'])}
self.outputs = {'Out': np.tanh(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestTanhShrink(OpTest):
......@@ -70,13 +70,13 @@ class TestTanhShrink(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32")
}
self.outputs = {'Y': self.inputs['X'] - np.tanh(self.inputs['X'])}
self.outputs = {'Out': self.inputs['X'] - np.tanh(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
self.check_grad(['X'], 'Out', max_relative_error=0.008)
class TestHardShrink(OpTest):
......@@ -90,13 +90,13 @@ class TestHardShrink(OpTest):
t = np.copy(x)
t[(t >= -threshold) & (t <= threshold)] = 0
self.outputs = {'Y': t}
self.outputs = {'Out': t}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.005)
self.check_grad(['X'], 'Out', max_relative_error=0.005)
class TestSoftShrink(OpTest):
......@@ -110,13 +110,13 @@ class TestSoftShrink(OpTest):
y = np.copy(self.inputs['X'])
y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * (
y - lambda_val)
self.outputs = {'Y': y}
self.outputs = {'Out': y}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestSqrt(OpTest):
......@@ -125,13 +125,13 @@ class TestSqrt(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.sqrt(self.inputs['X'])}
self.outputs = {'Out': np.sqrt(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestAbs(OpTest):
......@@ -144,13 +144,13 @@ class TestAbs(OpTest):
# we should avoid this
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Y': np.abs(self.inputs['X'])}
self.outputs = {'Out': np.abs(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestCeil(OpTest):
......@@ -158,13 +158,13 @@ class TestCeil(OpTest):
self.op_type = "ceil"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Y': np.ceil(self.inputs['X'])}
self.outputs = {'Out': np.ceil(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestFloor(OpTest):
......@@ -173,13 +173,13 @@ class TestFloor(OpTest):
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
# numpy floor need +1
self.outputs = {'Y': np.floor(self.inputs['X']) + 1.0}
self.outputs = {'Out': np.floor(self.inputs['X']) + 1.0}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestRound(OpTest):
......@@ -187,13 +187,13 @@ class TestRound(OpTest):
self.op_type = "round"
x = np.random.uniform(-1, 1, [4, 4]).astype("float32")
self.inputs = {'X': x}
self.outputs = {'Y': np.round(self.inputs['X'])}
self.outputs = {'Out': np.round(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestRelu(OpTest):
......@@ -203,13 +203,13 @@ class TestRelu(OpTest):
# The same reason with TestAbs
x[np.abs(x) < 0.005] = 0.02
self.inputs = {'X': x}
self.outputs = {'Y': np.maximum(self.inputs['X'], 0)}
self.outputs = {'Out': np.maximum(self.inputs['X'], 0)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestBRelu(OpTest):
......@@ -227,13 +227,13 @@ class TestBRelu(OpTest):
t = np.copy(x)
t[t < t_min] = t_min
t[t > t_max] = t_max
self.outputs = {'Y': t}
self.outputs = {'Out': t}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestRelu6(OpTest):
......@@ -248,14 +248,14 @@ class TestRelu6(OpTest):
self.inputs = {'X': x}
self.attrs = {'threshold': threshold}
self.outputs = {
'Y': np.minimum(np.maximum(self.inputs['X'], 0), threshold)
'Out': np.minimum(np.maximum(self.inputs['X'], 0), threshold)
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestSoftRelu(OpTest):
......@@ -271,13 +271,13 @@ class TestSoftRelu(OpTest):
t = np.copy(x)
t[t < -threshold] = -threshold
t[t > threshold] = threshold
self.outputs = {'Y': np.log((np.exp(t) + 1))}
self.outputs = {'Out': np.log((np.exp(t) + 1))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestELU(OpTest):
......@@ -290,27 +290,27 @@ class TestELU(OpTest):
self.inputs = {'X': x}
self.attrs = {'alpha': alpha}
self.outputs = {
'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
'Out': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestReciprocal(OpTest):
def setUp(self):
self.op_type = "reciprocal"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.outputs = {'Y': np.reciprocal(self.inputs['X'])}
self.outputs = {'Out': np.reciprocal(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.01)
self.check_grad(['X'], 'Out', max_relative_error=0.01)
class TestLog(OpTest):
......@@ -319,13 +319,13 @@ class TestLog(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.log(self.inputs['X'])}
self.outputs = {'Out': np.log(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestSquare(OpTest):
......@@ -334,13 +334,13 @@ class TestSquare(OpTest):
self.inputs = {
'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32")
}
self.outputs = {'Y': np.square(self.inputs['X'])}
self.outputs = {'Out': np.square(self.inputs['X'])}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestPow(OpTest):
......@@ -348,13 +348,13 @@ class TestPow(OpTest):
self.op_type = "pow"
self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")}
self.attrs = {'factor': 3.0}
self.outputs = {'Y': np.power(self.inputs['X'], 3)}
self.outputs = {'Out': np.power(self.inputs['X'], 3)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.02)
self.check_grad(['X'], 'Out', max_relative_error=0.02)
class TestSTanh(OpTest):
......@@ -366,13 +366,13 @@ class TestSTanh(OpTest):
scale_a = 2.0 / 3.0
scale_b = 1.7159
self.attrs = {'scale_a': scale_a, 'scale_b': scale_b}
self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)}
self.outputs = {'Out': scale_b * np.tanh(self.inputs['X'] * scale_a)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestSoftplus(OpTest):
......@@ -381,13 +381,13 @@ class TestSoftplus(OpTest):
self.inputs = {
'X': np.random.uniform(-1, 1, [11, 17]).astype("float64")
}
self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))}
self.outputs = {'Out': np.log(1 + np.exp(self.inputs['X']))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestSoftsign(OpTest):
......@@ -397,14 +397,14 @@ class TestSoftsign(OpTest):
'X': np.random.uniform(-1, 1, [11, 17]).astype("float32")
}
self.outputs = {
'Y': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X']))
'Out': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X']))
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.007)
self.check_grad(['X'], 'Out', max_relative_error=0.007)
class TestThresholdedRelu(OpTest):
......@@ -419,13 +419,13 @@ class TestThresholdedRelu(OpTest):
self.inputs = {'X': X}
self.attrs = {'threshold': threshold}
self.outputs = {'Y': (X > threshold) * X}
self.outputs = {'Out': (X > threshold) * X}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=self.relative_error)
self.check_grad(['X'], 'Out', max_relative_error=self.relative_error)
class TestHardSigmoid(OpTest):
......@@ -447,13 +447,13 @@ class TestHardSigmoid(OpTest):
upper_threshold - 0.2
temp = X * slope + offset
self.outputs = {'Y': np.maximum(0.0, np.minimum(1.0, temp))}
self.outputs = {'Out': np.maximum(0.0, np.minimum(1.0, temp))}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.002)
self.check_grad(['X'], 'Out', max_relative_error=0.002)
class TestSwish(OpTest):
......@@ -462,13 +462,13 @@ class TestSwish(OpTest):
X = np.random.uniform(0.1, 1, [11, 17]).astype("float32")
self.inputs = {'X': X}
self.attrs = {'beta': 2.3}
self.outputs = {'Y': X * expit(self.attrs['beta'] * X)}
self.outputs = {'Out': X * expit(self.attrs['beta'] * X)}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.008)
self.check_grad(['X'], 'Out', max_relative_error=0.008)
if __name__ == "__main__":
......
......@@ -7,7 +7,7 @@ def fc(X, W, Y):
ret_v = core.Net.create()
ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation"))
ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y))
ret_v.append_op(Operator("sigmoid", X="pre_activation", Out=Y))
ret_v.complete_add_op(True)
return ret_v
......@@ -30,7 +30,7 @@ Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]}
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}.
Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}.
Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Y[fc.out]}.
Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Out[fc.out]}.
'''
self.assertEqual(expected, "\n" + str(net))
......
......@@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest):
'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32")
}
self.outputs = {
'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X'])
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
self.check_grad(['X'], 'Out')
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册