diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 968fefcfafb53f697839b3c2f9b81b700ae1726c..7436e8c228db2caeb1421f8d78ddcf55f00deee4 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -21,6 +21,8 @@ cc_test(variable_test SRCS variable_test.cc) cc_library(scope SRCS scope.cc DEPS glog) cc_test(scope_test SRCS scope_test.cc DEPS scope) +cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) +cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc @@ -29,7 +31,8 @@ cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(shape_inference SRCS shape_inference.cc DEPS ddim attribute) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog shape_inference) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog + shape_inference data_transform) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry init) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS shape_inference op_info operator glog) @@ -65,6 +68,3 @@ cc_library(init SRCS init.cc DEPS gflags device_context place stringpiece) cc_test(init_test SRCS init_test.cc DEPS init) cc_test(op_kernel_type_test SRCS op_kernel_type_test.cc DEPS place device_context framework_proto) - -cc_library(data_transform SRCS data_transform.cc DEPS tensor framework_proto) -cc_test(data_transform_test SRCS data_transform_test.cc DEPS data_transform device_context) diff --git a/paddle/framework/data_transform.h b/paddle/framework/data_transform.h index c83c08ba5cee8d26bd9b951b53763801f99124cc..73f894a3e20ab779f8607e63a67139b0e8cce79a 100644 --- a/paddle/framework/data_transform.h +++ b/paddle/framework/data_transform.h @@ -32,17 +32,16 @@ using DataTransformFN = const Variable& in, Variable* out)>; using KernelTypePair = std::pair; -static void hash_combine(std::size_t& seed, const OpKernelType& t) { - OpKernelType::Hash kernel_type_hasher; - seed ^= kernel_type_hasher(t) + 0x9e3779b9 + (seed << 6) + (seed >> 2); -} - struct KernelTypePairHash { + static void HashCombine(const OpKernelType& t, std::size_t* seed) { + OpKernelType::Hash kernel_type_hasher; + (*seed) ^= kernel_type_hasher(t) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + } + size_t operator()(const KernelTypePair& kernel_pair) const { std::size_t seed = 0; - hash_combine(seed, kernel_pair.first); - hash_combine(seed, kernel_pair.second); - + HashCombine(kernel_pair.first, &seed); + HashCombine(kernel_pair.second, &seed); return seed; } }; diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 66840a2e037e7ca0fd1eacc64421865b170b47f8..886f73e7b81c35cac573bd041e6462eb2111bf85 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -15,6 +15,7 @@ limitations under the License. */ #include #include +#include "paddle/framework/data_transform.h" #include "paddle/framework/executor.h" #include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/operator.h" @@ -411,7 +412,38 @@ void OperatorWithKernel::Run(const Scope& scope, expected_kernel_key); } - kernel_iter->second->Compute(ctx); + if (actual_kernel_key == expected_kernel_key) { + kernel_iter->second->Compute(ctx); + } else { + Scope& op_scope = scope.NewScope(); + auto input_vars = this->InputVars(); + for (auto var_name : input_vars) { + op_scope.Var(var_name); + } + + // TODO(qijun) get appropriate DeviceContext from DeviceContext pool + platform::DeviceContext* trans_dev_ctx = nullptr; + std::vector trans_dev_ctx_vec{trans_dev_ctx}; + + // TODO(qijun) get appropriate DataTransformFN from global map + framework::DataTransformFN trans_fun = nullptr; + + // Wait for transform starting + dev_ctx->Wait(); + + for (auto var_name : input_vars) { + trans_fun(trans_dev_ctx_vec, *(scope.FindVar(var_name)), + op_scope.FindVar(var_name)); + } + // Wait for data transform finishing + for (auto ctx : trans_dev_ctx_vec) { + ctx->Wait(); + } + + // Create a new ExecutionContext + ExecutionContext op_ctx(*this, op_scope, *dev_ctx); + kernel_iter->second->Compute(op_ctx); + } } OpKernelType OperatorWithKernel::GetActualKernelType( diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 55f673c22df53027ccf1a39f8b70b766ea45d8f0..4188858a90daf8b2c10eb6960393de977d467371 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -22,8 +22,8 @@ class ActivationOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - ctx->SetOutputDim("Y", ctx->GetInputDim("X")); - ctx->ShareLoD("X", /*->*/ "Y"); + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + ctx->ShareLoD("X", /*->*/ "Out"); } }; @@ -32,7 +32,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext *ctx) const override { - ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y")); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); } }; @@ -41,11 +41,11 @@ class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { SigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Sigmoid operator"); - AddOutput("Y", "Output of Sigmoid operator"); + AddOutput("Out", "Output of Sigmoid operator"); AddComment(R"DOC( Sigmoid Activation Operator -$$y = \frac{1}{1 + e^{-x}}$$ +$$out = \frac{1}{1 + e^{-x}}$$ )DOC"); } @@ -56,11 +56,11 @@ class LogSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { LogSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of LogSigmoid operator"); - AddOutput("Y", "Output of LogSigmoid operator"); + AddOutput("Out", "Output of LogSigmoid operator"); AddComment(R"DOC( Logsigmoid Activation Operator -$$y = \log \frac{1}{1 + e^{-x}}$$ +$$out = \log \frac{1}{1 + e^{-x}}$$ )DOC"); } @@ -71,11 +71,11 @@ class ExpOpMaker : public framework::OpProtoAndCheckerMaker { ExpOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Exp operator"); - AddOutput("Y", "Output of Exp operator"); + AddOutput("Out", "Output of Exp operator"); AddComment(R"DOC( Exp Activation Operator. -$y = e^x$ +$out = e^x$ )DOC"); } @@ -86,11 +86,11 @@ class ReluOpMaker : public framework::OpProtoAndCheckerMaker { ReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Relu operator"); - AddOutput("Y", "Output of Relu operator"); + AddOutput("Out", "Output of Relu operator"); AddComment(R"DOC( Relu Activation Operator. -$y = \max(x, 0)$ +$out = \max(x, 0)$ )DOC"); } @@ -101,12 +101,12 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker { LeakyReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of LeakyRelu operator"); - AddOutput("Y", "Output of LeakyRelu operator"); + AddOutput("Out", "Output of LeakyRelu operator"); AddAttr("alpha", "The small negative slope").SetDefault(0.02f); AddComment(R"DOC( LeakyRelu Activation Operator. -$y = \max(x, \alpha * x)$ +$out = \max(x, \alpha * x)$ )DOC"); } @@ -117,13 +117,13 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker { SoftShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Softshrink operator"); - AddOutput("Y", "Output of Softshrink operator"); + AddOutput("Out", "Output of Softshrink operator"); AddAttr("lambda", "non-negative offset").SetDefault(0.5f); AddComment(R"DOC( Softshrink Activation Operator. $$ -y = \begin{cases} +out = \begin{cases} x - \lambda, \text{if } x > \lambda \\ x + \lambda, \text{if } x < -\lambda \\ 0, \text{otherwise} @@ -139,11 +139,11 @@ class TanhOpMaker : public framework::OpProtoAndCheckerMaker { TanhOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Tanh operator"); - AddOutput("Y", "Output of Tanh operator"); + AddOutput("Out", "Output of Tanh operator"); AddComment(R"DOC( Tanh Activation Operator. -$$y = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ +$$out = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ )DOC"); } @@ -154,11 +154,11 @@ class TanhShrinkOpMaker : public framework::OpProtoAndCheckerMaker { TanhShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of TanhShrink operator"); - AddOutput("Y", "Output of TanhShrink operator"); + AddOutput("Out", "Output of TanhShrink operator"); AddComment(R"DOC( TanhShrink Activation Operator. -$$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ +$$out = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$ )DOC"); } @@ -169,14 +169,14 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker { HardShrinkOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of HardShrink operator"); - AddOutput("Y", "Output of HardShrink operator"); + AddOutput("Out", "Output of HardShrink operator"); AddAttr("threshold", "The value of threshold for HardShrink") .SetDefault(0.5f); AddComment(R"DOC( HardShrink Activation Operator. $$ -y = \begin{cases} +out = \begin{cases} x, \text{if } x > \lambda \\ x, \text{if } x < -\lambda \\ 0, \text{otherwise} @@ -192,11 +192,11 @@ class SqrtOpMaker : public framework::OpProtoAndCheckerMaker { SqrtOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Sqrt operator"); - AddOutput("Y", "Output of Sqrt operator"); + AddOutput("Out", "Output of Sqrt operator"); AddComment(R"DOC( Sqrt Activation Operator. -$y = \sqrt{x}$ +$out = \sqrt{x}$ )DOC"); } @@ -207,11 +207,11 @@ class AbsOpMaker : public framework::OpProtoAndCheckerMaker { AbsOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Abs operator"); - AddOutput("Y", "Output of Abs operator"); + AddOutput("Out", "Output of Abs operator"); AddComment(R"DOC( Abs Activation Operator. -$y = |x|$ +$out = |x|$ )DOC"); } @@ -222,11 +222,11 @@ class CeilOpMaker : public framework::OpProtoAndCheckerMaker { CeilOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Ceil operator"); - AddOutput("Y", "Output of Ceil operator"); + AddOutput("Out", "Output of Ceil operator"); AddComment(R"DOC( Ceil Activation Operator. -$y = ceil(x)$ +$out = ceil(x)$ )DOC"); } @@ -237,11 +237,11 @@ class FloorOpMaker : public framework::OpProtoAndCheckerMaker { FloorOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Floor operator"); - AddOutput("Y", "Output of Floor operator"); + AddOutput("Out", "Output of Floor operator"); AddComment(R"DOC( Floor Activation Operator. -$y = floor(x)$ +$out = floor(x)$ )DOC"); } @@ -252,11 +252,11 @@ class RoundOpMaker : public framework::OpProtoAndCheckerMaker { RoundOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Round operator"); - AddOutput("Y", "Output of Round operator"); + AddOutput("Out", "Output of Round operator"); AddComment(R"DOC( Round Activation Operator. -$y = [x]$ +$out = [x]$ )DOC"); } @@ -267,11 +267,11 @@ class ReciprocalOpMaker : public framework::OpProtoAndCheckerMaker { ReciprocalOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Reciprocal operator"); - AddOutput("Y", "Output of Reciprocal operator"); + AddOutput("Out", "Output of Reciprocal operator"); AddComment(R"DOC( Reciprocal Activation Operator. -$$y = \frac{1}{x}$$ +$$out = \frac{1}{x}$$ )DOC"); } @@ -282,11 +282,11 @@ class LogOpMaker : public framework::OpProtoAndCheckerMaker { LogOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Log operator"); - AddOutput("Y", "Output of Log operator"); + AddOutput("Out", "Output of Log operator"); AddComment(R"DOC( Log Activation Operator. -$y = \ln(x)$ +$out = \ln(x)$ Natural logarithm of x. @@ -299,11 +299,11 @@ class SquareOpMaker : public framework::OpProtoAndCheckerMaker { SquareOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Square operator"); - AddOutput("Y", "Output of Square operator"); + AddOutput("Out", "Output of Square operator"); AddComment(R"DOC( Square Activation Operator. -$y = x^2$ +$out = x^2$ )DOC"); } @@ -314,11 +314,11 @@ class SoftplusOpMaker : public framework::OpProtoAndCheckerMaker { SoftplusOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Softplus operator"); - AddOutput("Y", "Output of Softplus operator"); + AddOutput("Out", "Output of Softplus operator"); AddComment(R"DOC( Softplus Activation Operator. -$y = \ln(1 + e^{x})$ +$out = \ln(1 + e^{x})$ )DOC"); } @@ -329,11 +329,11 @@ class SoftsignOpMaker : public framework::OpProtoAndCheckerMaker { SoftsignOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Softsign operator"); - AddOutput("Y", "Output of Softsign operator"); + AddOutput("Out", "Output of Softsign operator"); AddComment(R"DOC( Softsign Activation Operator. -$$y = \frac{x}{1 + |x|}$$ +$$out = \frac{x}{1 + |x|}$$ )DOC"); } @@ -344,7 +344,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker { BReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of BRelu operator"); - AddOutput("Y", "Output of BRelu operator"); + AddOutput("Out", "Output of BRelu operator"); AddAttr("t_min", "The min marginal value of BRelu") .SetDefault(static_cast(0)); AddAttr("t_max", "The max marginal value of BRelu") @@ -352,7 +352,7 @@ class BReluOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( BRelu Activation Operator. -$y = \max(\min(x, t_{min}), t_{max})$ +$out = \max(\min(x, t_{min}), t_{max})$ )DOC"); } @@ -363,13 +363,13 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker { SoftReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of SoftRelu operator"); - AddOutput("Y", "Output of SoftRelu operator"); + AddOutput("Out", "Output of SoftRelu operator"); AddAttr("threshold", "The threshold value of SoftRelu") .SetDefault(40.0f); AddComment(R"DOC( SoftRelu Activation Operator. -$y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$ +$out = \ln(1 + \exp(\max(\min(x, threshold), threshold))$ )DOC"); } @@ -380,7 +380,7 @@ class ELUOpMaker : public framework::OpProtoAndCheckerMaker { ELUOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of ELU operator"); - AddOutput("Y", "Output of ELU operator"); + AddOutput("Out", "Output of ELU operator"); AddAttr("alpha", "The alpha value of ELU").SetDefault(1.0f); AddComment(R"DOC( ELU Activation Operator. @@ -388,7 +388,7 @@ ELU Activation Operator. Applies the following element-wise computation on the input according to https://arxiv.org/abs/1511.07289. -$y = \max(0, x) + \min(0, \alpha * (e^x - 1))$ +$out = \max(0, x) + \min(0, \alpha * (e^x - 1))$ )DOC"); } @@ -399,13 +399,13 @@ class Relu6OpMaker : public framework::OpProtoAndCheckerMaker { Relu6OpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Relu6 operator"); - AddOutput("Y", "Output of Relu6 operator"); + AddOutput("Out", "Output of Relu6 operator"); AddAttr("threshold", "The threshold value of Relu6") .SetDefault(6.0f); AddComment(R"DOC( Relu6 Activation Operator. -$y = \min(\max(0, x), 6)$ +$out = \min(\max(0, x), 6)$ )DOC"); } @@ -416,12 +416,12 @@ class PowOpMaker : public framework::OpProtoAndCheckerMaker { PowOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Pow operator"); - AddOutput("Y", "Output of Pow operator"); + AddOutput("Out", "Output of Pow operator"); AddAttr("factor", "The exponential factor of Pow").SetDefault(1.0f); AddComment(R"DOC( Pow Activation Operator. -$y = x^{factor}$ +$out = x^{factor}$ )DOC"); } @@ -432,7 +432,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { STanhOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of STanh operator"); - AddOutput("Y", "Output of STanh operator"); + AddOutput("Out", "Output of STanh operator"); AddAttr("scale_a", "The scale parameter of a for the input") .SetDefault(2.0f / 3.0f); AddAttr("scale_b", "The scale parameter of b for the input") @@ -440,7 +440,7 @@ class STanhOpMaker : public framework::OpProtoAndCheckerMaker { AddComment(R"DOC( STanh Activation Operator. -$$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$ +$$out = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$ )DOC"); } @@ -451,14 +451,14 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker { ThresholdedReluOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of ThresholdedRelu operator"); - AddOutput("Y", "Output of ThresholdedRelu operator"); + AddOutput("Out", "Output of ThresholdedRelu operator"); AddAttr("threshold", "The threshold location of activation") .SetDefault(1.0f); AddComment(R"DOC( ThresholdedRelu Activation Operator. $$ -y = \begin{cases} +out = \begin{cases} x, \text{if } x > threshold \\ 0, \text{otherwise} \end{cases} @@ -473,7 +473,7 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker { HardSigmoidOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of HardSigmoid operator"); - AddOutput("Y", "Output of HardSigmoid operator"); + AddOutput("Out", "Output of HardSigmoid operator"); AddAttr("slope", "Slope for linear approximation of sigmoid") .SetDefault(0.2f); AddAttr("offset", "Offset for linear approximation of sigmoid") @@ -484,7 +484,7 @@ HardSigmoid Activation Operator. Segment-wise linear approximation of sigmoid(https://arxiv.org/abs/1603.00391), which is much faster than sigmoid. -$y = \max(0, \min(1, slope * x + shift))$ +$out = \max(0, \min(1, slope * x + shift))$ The slope should be positive. The offset can be either positive or negative. The default slope and shift are set according to the above reference. @@ -499,12 +499,12 @@ class SwishOpMaker : public framework::OpProtoAndCheckerMaker { SwishOpMaker(OpProto *proto, OpAttrChecker *op_checker) : framework::OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", "Input of Swish operator"); - AddOutput("Y", "Output of Swish operator"); + AddOutput("Out", "Output of Swish operator"); AddAttr("beta", "Constant beta of swish operator").SetDefault(1.0f); AddComment(R"DOC( Swish Activation Operator. -$$y = \frac{x}{1 + e^{- \beta x}}$$ +$$out = \frac{x}{1 + e^{- \beta x}}$$ )DOC"); } diff --git a/paddle/operators/activation_op.h b/paddle/operators/activation_op.h index 7b433df995afae7340ac97f721f962efdd2f07c0..0885f7c570b9b52dc51597347295734fd689da8d 100644 --- a/paddle/operators/activation_op.h +++ b/paddle/operators/activation_op.h @@ -27,11 +27,11 @@ class ActivationKernel void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); - auto* Y = context.Output("Y"); - Y->mutable_data(context.GetPlace()); + auto* Out = context.Output("Out"); + Out->mutable_data(context.GetPlace()); auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); + auto out = framework::EigenVector::Flatten(*Out); auto* place = context.template device_context().eigen_device(); Functor functor; @@ -40,7 +40,7 @@ class ActivationKernel for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } - functor(*place, x, y); + functor(*place, x, out); } }; @@ -51,14 +51,15 @@ class ActivationGradKernel using T = typename Functor::ELEMENT_TYPE; void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); - auto* Y = context.Input("Y"); - auto* dY = context.Input(framework::GradVarName("Y")); + auto* Out = context.Input("Out"); + auto* dOut = + context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); dX->mutable_data(context.GetPlace()); - auto dy = framework::EigenVector::Flatten(*dY); + auto dout = framework::EigenVector::Flatten(*dOut); auto x = framework::EigenVector::Flatten(*X); - auto y = framework::EigenVector::Flatten(*Y); + auto out = framework::EigenVector::Flatten(*Out); auto dx = framework::EigenVector::Flatten(*dX); auto* place = context.template device_context().eigen_device(); @@ -67,7 +68,7 @@ class ActivationGradKernel for (auto& attr : attrs) { *attr.second = context.Attr(attr.first); } - functor(*place, x, y, dy, dx); + functor(*place, x, out, dout, dx); } }; @@ -83,17 +84,18 @@ struct BaseActivationFunctor { // sigmoid(x) = 1 / (1 + exp(-x)) template struct SigmoidFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = static_cast(1) / (static_cast(1) + (-x).exp()); } }; template struct SigmoidGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * y * (static_cast(1) - y); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * out * (static_cast(1) - out); } }; @@ -101,7 +103,7 @@ struct SigmoidGradFunctor : public BaseActivationFunctor { // For numerical stability, we can use the log-sum-exp trick: // https://hips.seas.harvard.edu/blog/2013/01/09/computing-log-sum-exp/ // We can rewrite the above equation as: -// y = -log( exp(0) + exp(-x)) [since exp(0) = 1] +// out = -log( exp(0) + exp(-x)) [since exp(0) = 1] // = -log( exp(max(-x, 0) - max(-x, 0)) + exp(-x + max(-x, 0) - max(-x, 0))) // = -log( exp(max(-x, 0)) * exp(-max(-x, 0)) - exp(max(-x, 0)) * exp(-x - // max(-x, 0))) @@ -112,10 +114,10 @@ struct SigmoidGradFunctor : public BaseActivationFunctor { // + exp(-x - max(-x, 0)))) template struct LogSigmoidFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { + template + void operator()(Device d, X x, Out out) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) - y.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); + out.device(d) = -temp - (((-temp).exp() + (-x - temp).exp()).log()); } }; @@ -124,62 +126,66 @@ struct LogSigmoidFunctor : public BaseActivationFunctor { // exp(-x - max(-x, 0))) template struct LogSigmoidGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp = (-x).cwiseMax(static_cast(0)); // temp = max(-x, 0) dx.device(d) = - dy * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); + dout * ((-x - temp).exp() / ((-temp).exp() + (-x - temp).exp())); } }; // exp(x) = e^x template struct ExpFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.exp(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.exp(); } }; template struct ExpGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * y; + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * out; } }; // relu(x) = max(x, 0) template struct ReluFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.cwiseMax(static_cast(0)); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.cwiseMax(static_cast(0)); } }; template struct ReluGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * (x > static_cast(0)).template cast(); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (x > static_cast(0)).template cast(); } }; // tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.tanh(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.tanh(); } }; template struct TanhGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * (static_cast(1) - y * y); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast(1) - out * out); } }; @@ -187,17 +193,18 @@ struct TanhGradFunctor : public BaseActivationFunctor { // where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)) template struct TanhShrinkFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x - x.tanh(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x - x.tanh(); } }; template struct TanhShrinkGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * (x.tanh() * x.tanh()); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (x.tanh() * x.tanh()); } }; @@ -210,11 +217,11 @@ struct HardShrinkFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y) const { + template + void operator()(Device d, X x, Out out) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); - y.device(d) = x * (temp1 + temp2); + out.device(d) = x * (temp1 + temp2); } }; @@ -226,11 +233,12 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = (x < static_cast(threshold * -1)).template cast().eval(); auto temp2 = (x > static_cast(threshold)).template cast().eval(); - dx.device(d) = dy * (temp1 + temp2).template cast(); + dx.device(d) = dout * (temp1 + temp2).template cast(); } }; @@ -243,12 +251,12 @@ struct SoftShrinkFunctor : public BaseActivationFunctor { return {{"lambda", &lambda}}; } - template - void operator()(Device d, X x, Y y) const { + template + void operator()(Device d, X x, Out out) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); - y.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); + out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT); } }; @@ -258,46 +266,49 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"lambda", &lambda}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto lambdaT = static_cast(lambda); auto temp1 = (x > lambdaT).template cast().eval(); auto temp2 = (x < -lambdaT).template cast().eval(); - dx.device(d) = dy * (temp1 + temp2).template cast(); + dx.device(d) = dout * (temp1 + temp2).template cast(); } }; // sqrt(x) = x^(1/2) template struct SqrtFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.sqrt(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.sqrt(); } }; template struct SqrtGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - const Y y_conj = Eigen::numext::conj(y); - dx.device(d) = static_cast(0.5) * dy / y_conj; + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + const Out out_conj = Eigen::numext::conj(out); + dx.device(d) = static_cast(0.5) * dout / out_conj; } }; // ceil(x) = ceiling(x) template struct CeilFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.ceil(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.ceil(); } }; template struct ZeroGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { dx.device(d) = static_cast(0) / x; } }; @@ -305,86 +316,90 @@ struct ZeroGradFunctor : public BaseActivationFunctor { // floor(x) = flooring(x) template struct FloorFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.ceil(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.ceil(); } }; // round(x) = [x] template struct RoundFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.round(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.round(); } }; // abs(x) = |x| template struct AbsFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.abs(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.abs(); } }; template struct AbsGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * x.sign(); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * x.sign(); } }; // reciprocal(x) = 1 / x template struct ReciprocalFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = static_cast(1) / x; + template + void operator()(Device d, X x, Out out) const { + out.device(d) = static_cast(1) / x; } }; template struct ReciprocalGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * static_cast(-1) * y * y; + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(-1) * out * out; } }; // log(x) = natural logarithm of x template struct LogFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.log(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.log(); } }; template struct LogGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * (static_cast(1) / x); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (static_cast(1) / x); } }; // square(x) = x^2 template struct SquareFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.square(); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.square(); } }; template struct SquareGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * static_cast(2) * x; + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(2) * x; } }; @@ -399,9 +414,9 @@ struct BReluFunctor : public BaseActivationFunctor { return {{"t_min", &t_min}, {"t_max", &t_max}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.cwiseMax(static_cast(t_min)).cwiseMin(static_cast(t_max)); } }; @@ -413,9 +428,10 @@ struct BReluGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"t_min", &t_min}, {"t_max", &t_max}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * ((x > static_cast(t_min)) * (x < static_cast(t_max))) .template cast(); } @@ -430,9 +446,9 @@ struct Relu6Functor : public BaseActivationFunctor { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.cwiseMax(static_cast(0)).cwiseMin(static_cast(threshold)); } }; @@ -443,9 +459,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * ((x > static_cast(0)) * (x < static_cast(threshold))) .template cast(); } @@ -458,10 +475,10 @@ struct Relu6GradFunctor : public BaseActivationFunctor { // Then: softplus(x) = max(x, 0) + log(exp(-max(x, 0)) + exp(x - max(x, 0))) template struct SoftplusFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) { + template + void operator()(Device d, X x, Out out) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) - y.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); + out.device(d) = temp + (((-temp).exp() + (x - temp).exp()).log()); } }; @@ -471,19 +488,21 @@ struct SoftplusFunctor : public BaseActivationFunctor { // exp(x - max(x, 0))) template struct SoftplusGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) { auto temp = x.cwiseMax(static_cast(0)); // temp = max(x, 0) - dx.device(d) = dy * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); + dx.device(d) = + dout * ((x - temp).exp() / ((-temp).exp() + (x - temp).exp())); } }; // softsign(x) = x / (1 + |x|) template struct SoftsignFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y) { - y.device(d) = x / (static_cast(1) + x.abs()); + template + void operator()(Device d, X x, Out out) { + out.device(d) = x / (static_cast(1) + x.abs()); } }; @@ -491,10 +510,11 @@ struct SoftsignFunctor : public BaseActivationFunctor { // Taken from https://en.wikipedia.org/wiki/Activation_function template struct SoftsignGradFunctor : public BaseActivationFunctor { - template - void operator()(Device d, X x, Y y, dY dy, dX dx) { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) { dx.device(d) = - dy * (static_cast(1) / (static_cast(1) + x.abs()).square()); + dout * (static_cast(1) / (static_cast(1) + x.abs()).square()); } }; @@ -505,11 +525,11 @@ struct SoftReluFunctor : public BaseActivationFunctor { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y) const { + template + void operator()(Device d, X x, Out out) const { auto tmp = static_cast(threshold); auto temp = x.cwiseMax(-tmp).cwiseMin(tmp); - y.device(d) = (static_cast(1) + temp.exp()).log(); + out.device(d) = (static_cast(1) + temp.exp()).log(); } }; @@ -519,11 +539,12 @@ struct SoftReluGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto tmp = static_cast(threshold); auto temp = ((x > -tmp) * (x < tmp)).template cast().eval(); - dx.device(d) = dy * (static_cast(1) - (-y).exp()) * temp; + dx.device(d) = dout * (static_cast(1) - (-out).exp()) * temp; } }; @@ -534,9 +555,9 @@ struct LeakyReluFunctor : public BaseActivationFunctor { return {{"alpha", &alpha}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.cwiseMax(static_cast(alpha) * x); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.cwiseMax(static_cast(alpha) * x); } }; @@ -546,12 +567,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(alpha) * (x < static_cast(0)).template cast().eval(); auto temp2 = (x >= static_cast(0)).template cast().eval(); - dx.device(d) = dy * (temp1 + temp2).template cast(); + dx.device(d) = dout * (temp1 + temp2).template cast(); } }; @@ -562,11 +584,11 @@ struct ELUFunctor : public BaseActivationFunctor { return {{"alpha", &alpha}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.cwiseMax(static_cast(0)) + - (static_cast(alpha) * (x.exp() - static_cast(1))) - .cwiseMin(static_cast(0)); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.cwiseMax(static_cast(0)) + + (static_cast(alpha) * (x.exp() - static_cast(1))) + .cwiseMin(static_cast(0)); } }; @@ -576,10 +598,11 @@ struct ELUGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"alpha", &alpha}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * (x > static_cast(0)).template cast() + - dy * (y + static_cast(alpha)) * + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * (x > static_cast(0)).template cast() + + dout * (out + static_cast(alpha)) * (x < static_cast(0)).template cast(); } }; @@ -591,9 +614,9 @@ struct PowFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x.pow(static_cast(factor)); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x.pow(static_cast(factor)); } }; @@ -603,9 +626,10 @@ struct PowGradFunctor : public BaseActivationFunctor { typename BaseActivationFunctor::AttrPair GetAttrs() { return {{"factor", &factor}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = dy * static_cast(factor) * + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * static_cast(factor) * x.pow(static_cast(factor - static_cast(1))); } }; @@ -618,9 +642,9 @@ struct STanhFunctor : public BaseActivationFunctor { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = + template + void operator()(Device d, X x, Out out) const { + out.device(d) = static_cast(scale_b) * (static_cast(scale_a) * x).tanh(); } }; @@ -633,12 +657,13 @@ struct STanhGradFunctor : public BaseActivationFunctor { return {{"scale_a", &scale_a}, {"scale_b", &scale_b}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto a = static_cast(scale_a); auto b = static_cast(scale_b); auto temp = (a * x).tanh() * (a * x).tanh(); - dx.device(d) = dy * a * b * (static_cast(1) - temp); + dx.device(d) = dout * a * b * (static_cast(1) - temp); } }; @@ -649,10 +674,10 @@ struct ThresholdedReluFunctor : public BaseActivationFunctor { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y) const { + template + void operator()(Device d, X x, Out out) const { auto th = static_cast(threshold); - y.device(d) = (x > th).template cast() * x; + out.device(d) = (x > th).template cast() * x; } }; @@ -663,10 +688,11 @@ struct ThresholdedReluGradFunctor : public BaseActivationFunctor { return {{"threshold", &threshold}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto th = static_cast(threshold); - dx.device(d) = dy * (x > th).template cast(); + dx.device(d) = dout * (x > th).template cast(); } }; @@ -678,10 +704,11 @@ struct HardSigmoidFunctor : public BaseActivationFunctor { return {{"slope", &slope}, {"offset", &offset}}; } - template - void operator()(Device d, X x, Y y) const { + template + void operator()(Device d, X x, Out out) const { auto temp = x * static_cast(slope) + static_cast(offset); - y.device(d) = temp.cwiseMax(static_cast(0)).cwiseMin(static_cast(1)); + out.device(d) = + temp.cwiseMax(static_cast(0)).cwiseMin(static_cast(1)); } }; @@ -693,12 +720,13 @@ struct HardSigmoidGradFunctor : public BaseActivationFunctor { return {{"slope", &slope}, {"offset", &offset}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { - dx.device(d) = - dy * - ((y > static_cast(0)) * (y < static_cast(1))).template cast() * - static_cast(slope); + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { + dx.device(d) = dout * + ((out > static_cast(0)) * (out < static_cast(1))) + .template cast() * + static_cast(slope); } }; @@ -709,9 +737,9 @@ struct SwishFunctor : public BaseActivationFunctor { return {{"beta", &beta}}; } - template - void operator()(Device d, X x, Y y) const { - y.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); + template + void operator()(Device d, X x, Out out) const { + out.device(d) = x / (static_cast(1) + (static_cast(-beta) * x).exp()); } }; @@ -722,12 +750,13 @@ struct SwishGradFunctor : public BaseActivationFunctor { return {{"beta", &beta}}; } - template - void operator()(Device d, X x, Y y, dY dy, dX dx) const { + template + void operator()(Device d, X x, Out out, dOut dout, dX dx) const { auto temp1 = static_cast(1) / (static_cast(1) + (static_cast(-beta) * x).exp()); - auto temp2 = temp1 * (static_cast(1) - (beta * y)); - dx.device(d) = dy * ((beta * y) + temp2); + auto temp2 = temp1 * (static_cast(1) - (beta * out)); + dx.device(d) = dout * ((beta * out) + temp2); } }; diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index 6b3f19bb46c45b7dd8ec6c23ee449521b36d759c..e7306bc5f13377813e0bd49846bc834d501602eb 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -24,13 +24,13 @@ class SoftmaxOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SoftmaxOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Y"), - "Output(Y) of SoftmaxOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of SoftmaxOp should not be null."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE(x_dims.size() == 2UL, "The input of softmax op must be a matrix."); - ctx->SetOutputDim("Y", x_dims); + ctx->SetOutputDim("Out", x_dims); } }; @@ -41,7 +41,7 @@ class SoftmaxOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("X", "The input tensor of softmax. " "2-D with shape [batch_size, input_feature_dimensions]."); - AddOutput("Y", "The normalized values with the same shape as X."); + AddOutput("Out", "The normalized values with the same shape as X."); AddComment(R"DOC( Softmax Operator. @@ -59,7 +59,7 @@ exponential values of all the other dimensions is the output of the softmax operator. For each row $i$ and each column $j$ in Input(X), we have: - $$Y[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$ + $$Out[i, j] = \frac{\exp(X[i, j])}{\sum_j(exp(X[i, j])}$$ )DOC"); } @@ -70,12 +70,12 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); - PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), - "Input(Y@GRAD) should be not null."); - PADDLE_ENFORCE_EQ(ctx->GetInputDim("Y"), - ctx->GetInputDim(framework::GradVarName("Y")), - "Input(Y) and its gradients should have a same shape."); + PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should be not null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should be not null."); + PADDLE_ENFORCE_EQ(ctx->GetInputDim("Out"), + ctx->GetInputDim(framework::GradVarName("Out")), + "Input(Out) and its gradients should have a same shape."); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } diff --git a/paddle/operators/softmax_op.h b/paddle/operators/softmax_op.h index 0f8998b99e93b5ed6c9b43ad7adabc2d515c1ff1..63e379a3b31a6c75aab0c56b4ce1b988fa7f0318 100644 --- a/paddle/operators/softmax_op.h +++ b/paddle/operators/softmax_op.h @@ -26,13 +26,13 @@ class SoftmaxKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* X = context.Input("X"); - auto* Y = context.Output("Y"); + auto* Out = context.Output("Out"); // allocate memory on device. - Y->mutable_data(context.GetPlace()); + Out->mutable_data(context.GetPlace()); math::SoftmaxFunctor()( - context.template device_context(), X, Y); + context.template device_context(), X, Out); } }; @@ -40,15 +40,15 @@ template class SoftmaxGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { - auto* Y = context.Input("Y"); - auto* dY = context.Input(framework::GradVarName("Y")); + auto* Out = context.Input("Out"); + auto* dOut = context.Input(framework::GradVarName("Out")); auto* dX = context.Output(framework::GradVarName("X")); // allocate memory on device. dX->mutable_data(context.GetPlace()); math::SoftmaxGradFunctor()( - context.template device_context(), Y, dY, dX); + context.template device_context(), Out, dOut, dX); } }; diff --git a/python/paddle/v2/fluid/io.py b/python/paddle/v2/fluid/io.py index e147ac22ad289eb00c83def66974d875fcdc31f8..69a732fc45a1946f260cdd9a9c2da150b87c3ddd 100644 --- a/python/paddle/v2/fluid/io.py +++ b/python/paddle/v2/fluid/io.py @@ -180,10 +180,22 @@ def save_inference_model(dirname, :return: None """ + if isinstance(feeded_var_names, basestring): + feeded_var_names = [feeded_var_names] + else: + if not (bool(feeded_var_names) and all( + isinstance(name, basestring) for name in feeded_var_names)): + raise ValueError("'feed_var_names' should be a list of str.") + + if isinstance(target_vars, Variable): + feeded_var_names = [feeded_var_names] + else: + if not (bool(target_vars) and all( + isinstance(var, Variable) for var in target_vars)): + raise ValueError("'target_vars' should be a list of Variable.") + if main_program is None: main_program = default_main_program() - if not isinstance(target_vars, list): - target_vars = [target_vars] if not os.path.isdir(dirname): os.makedirs(dirname) diff --git a/python/paddle/v2/fluid/layer_helper.py b/python/paddle/v2/fluid/layer_helper.py index a076f26f7ff279812f762bd62f72195dc2376378..4469f7285efe1c31d0955c6dd4ba3ecac08070af 100644 --- a/python/paddle/v2/fluid/layer_helper.py +++ b/python/paddle/v2/fluid/layer_helper.py @@ -184,7 +184,7 @@ class LayerHelper(object): self.append_op( type=act_type, inputs={"X": [input_var]}, - outputs={"Y": [tmp]}, + outputs={"Out": [tmp]}, attrs=act) return tmp diff --git a/python/paddle/v2/fluid/layers/nn.py b/python/paddle/v2/fluid/layers/nn.py index ee9d7554fb1f0cc3b9229f046f8bcdd4dbd7c35d..2a462ee6cb47496b17a8f584e56ac2c8934b319a 100644 --- a/python/paddle/v2/fluid/layers/nn.py +++ b/python/paddle/v2/fluid/layers/nn.py @@ -386,7 +386,8 @@ def square_error_cost(input, label, **kwargs): square_out = helper.create_tmp_variable(dtype=input.dtype) helper.append_op( - type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]}) + type='square', inputs={'X': [minus_out]}, + outputs={'Out': [square_out]}) return square_out @@ -604,7 +605,7 @@ def sequence_pool(input, pool_type, **kwargs): sqrt : out.data = [2.82, 6.93, 4.24], where 2.82=(1+3)/sqrt(2), 6.93=(2+4+6)/sqrt(3), 4.24=(5+1)/sqrt(2) max : out.data = [3, 6, 5], where 3=max(1,3), 6=max(2,4,6), 5=max(5,1) - + Args: input(variable): The input variable which is a LoDTensor. pool_type (string): The pooling type of sequence_pool. @@ -616,7 +617,7 @@ def sequence_pool(input, pool_type, **kwargs): Examples: .. code-block:: python - + x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1) avg_x = fluid.layers.sequence_pool(input=x, pool_type='average') @@ -654,7 +655,7 @@ def sequence_first_step(input, **kwargs): out.dim = [3, 1] with condition len(x.lod[-1]) - 1 == out.dims[0] out.data = [1, 2, 5], where 1=first(1,3), 2=first(2,4,6), 5=first(5,1) - + Args: input(variable): The input variable which is a LoDTensor. @@ -664,7 +665,7 @@ def sequence_first_step(input, **kwargs): Examples: .. code-block:: python - + x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1) x_first_step = fluid.layers.sequence_first_step(input=x) @@ -687,7 +688,7 @@ def sequence_last_step(input, **kwargs): out.dim = [3, 1] with condition len(x.lod[-1]) - 1 == out.dims[0] out.data = [3, 6, 1], where 3=last(1,3), 6=last(2,4,6), 1=last(5,1) - + Args: input(variable): The input variable which is a LoDTensor. @@ -697,7 +698,7 @@ def sequence_last_step(input, **kwargs): Examples: .. code-block:: python - + x = fluid.layers.data(name='x', shape=[7, 1], dtype='float32', lod_level=1) x_last_step = fluid.layers.sequence_last_step(input=x) @@ -1132,7 +1133,7 @@ def reduce_sum(input, dim=None, keep_dim=False): Returns: Variable: The reduced Tensor variable. - + Examples: .. code-block:: python @@ -1176,7 +1177,7 @@ def reduce_mean(input, dim=None, keep_dim=False): Returns: Variable: The reduced Tensor variable. - + Examples: .. code-block:: python diff --git a/python/paddle/v2/fluid/tests/test_activation_op.py b/python/paddle/v2/fluid/tests/test_activation_op.py index b052374dc7ec3c5684d6adfda6b9d000c5e19fe0..03eb7deb9a35933e5a1676a262a371c69151e6d1 100644 --- a/python/paddle/v2/fluid/tests/test_activation_op.py +++ b/python/paddle/v2/fluid/tests/test_activation_op.py @@ -10,13 +10,13 @@ class TestExp(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': np.exp(self.inputs['X'])} + self.outputs = {'Out': np.exp(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestSigmoid(OpTest): @@ -25,13 +25,13 @@ class TestSigmoid(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': 1 / (1 + np.exp(-self.inputs['X']))} + self.outputs = {'Out': 1 / (1 + np.exp(-self.inputs['X']))} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.008) + self.check_grad(['X'], 'Out', max_relative_error=0.008) class TestLogSigmoid(OpTest): @@ -40,13 +40,13 @@ class TestLogSigmoid(OpTest): self.inputs = { 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': np.log(1 / (1 + np.exp(-self.inputs['X'])))} + self.outputs = {'Out': np.log(1 / (1 + np.exp(-self.inputs['X'])))} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.008) + self.check_grad(['X'], 'Out', max_relative_error=0.008) class TestTanh(OpTest): @@ -55,13 +55,13 @@ class TestTanh(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': np.tanh(self.inputs['X'])} + self.outputs = {'Out': np.tanh(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestTanhShrink(OpTest): @@ -70,13 +70,13 @@ class TestTanhShrink(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [10, 17]).astype("float32") } - self.outputs = {'Y': self.inputs['X'] - np.tanh(self.inputs['X'])} + self.outputs = {'Out': self.inputs['X'] - np.tanh(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.008) + self.check_grad(['X'], 'Out', max_relative_error=0.008) class TestHardShrink(OpTest): @@ -90,13 +90,13 @@ class TestHardShrink(OpTest): t = np.copy(x) t[(t >= -threshold) & (t <= threshold)] = 0 - self.outputs = {'Y': t} + self.outputs = {'Out': t} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.005) + self.check_grad(['X'], 'Out', max_relative_error=0.005) class TestSoftShrink(OpTest): @@ -110,13 +110,13 @@ class TestSoftShrink(OpTest): y = np.copy(self.inputs['X']) y = (y < -lambda_val) * (y + lambda_val) + (y > lambda_val) * ( y - lambda_val) - self.outputs = {'Y': y} + self.outputs = {'Out': y} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestSqrt(OpTest): @@ -125,13 +125,13 @@ class TestSqrt(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': np.sqrt(self.inputs['X'])} + self.outputs = {'Out': np.sqrt(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestAbs(OpTest): @@ -144,13 +144,13 @@ class TestAbs(OpTest): # we should avoid this x[np.abs(x) < 0.005] = 0.02 self.inputs = {'X': x} - self.outputs = {'Y': np.abs(self.inputs['X'])} + self.outputs = {'Out': np.abs(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestCeil(OpTest): @@ -158,13 +158,13 @@ class TestCeil(OpTest): self.op_type = "ceil" x = np.random.uniform(-1, 1, [4, 4]).astype("float32") self.inputs = {'X': x} - self.outputs = {'Y': np.ceil(self.inputs['X'])} + self.outputs = {'Out': np.ceil(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestFloor(OpTest): @@ -173,13 +173,13 @@ class TestFloor(OpTest): x = np.random.uniform(-1, 1, [4, 4]).astype("float32") self.inputs = {'X': x} # numpy floor need +1 - self.outputs = {'Y': np.floor(self.inputs['X']) + 1.0} + self.outputs = {'Out': np.floor(self.inputs['X']) + 1.0} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestRound(OpTest): @@ -187,13 +187,13 @@ class TestRound(OpTest): self.op_type = "round" x = np.random.uniform(-1, 1, [4, 4]).astype("float32") self.inputs = {'X': x} - self.outputs = {'Y': np.round(self.inputs['X'])} + self.outputs = {'Out': np.round(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestRelu(OpTest): @@ -203,13 +203,13 @@ class TestRelu(OpTest): # The same reason with TestAbs x[np.abs(x) < 0.005] = 0.02 self.inputs = {'X': x} - self.outputs = {'Y': np.maximum(self.inputs['X'], 0)} + self.outputs = {'Out': np.maximum(self.inputs['X'], 0)} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestBRelu(OpTest): @@ -227,13 +227,13 @@ class TestBRelu(OpTest): t = np.copy(x) t[t < t_min] = t_min t[t > t_max] = t_max - self.outputs = {'Y': t} + self.outputs = {'Out': t} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.02) + self.check_grad(['X'], 'Out', max_relative_error=0.02) class TestRelu6(OpTest): @@ -248,14 +248,14 @@ class TestRelu6(OpTest): self.inputs = {'X': x} self.attrs = {'threshold': threshold} self.outputs = { - 'Y': np.minimum(np.maximum(self.inputs['X'], 0), threshold) + 'Out': np.minimum(np.maximum(self.inputs['X'], 0), threshold) } def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.02) + self.check_grad(['X'], 'Out', max_relative_error=0.02) class TestSoftRelu(OpTest): @@ -271,13 +271,13 @@ class TestSoftRelu(OpTest): t = np.copy(x) t[t < -threshold] = -threshold t[t > threshold] = threshold - self.outputs = {'Y': np.log((np.exp(t) + 1))} + self.outputs = {'Out': np.log((np.exp(t) + 1))} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.02) + self.check_grad(['X'], 'Out', max_relative_error=0.02) class TestELU(OpTest): @@ -290,27 +290,27 @@ class TestELU(OpTest): self.inputs = {'X': x} self.attrs = {'alpha': alpha} self.outputs = { - 'Y': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) + 'Out': np.maximum(0, x) + np.minimum(0, alpha * (np.exp(x) - 1)) } def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.02) + self.check_grad(['X'], 'Out', max_relative_error=0.02) class TestReciprocal(OpTest): def setUp(self): self.op_type = "reciprocal" self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} - self.outputs = {'Y': np.reciprocal(self.inputs['X'])} + self.outputs = {'Out': np.reciprocal(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.01) + self.check_grad(['X'], 'Out', max_relative_error=0.01) class TestLog(OpTest): @@ -319,13 +319,13 @@ class TestLog(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': np.log(self.inputs['X'])} + self.outputs = {'Out': np.log(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestSquare(OpTest): @@ -334,13 +334,13 @@ class TestSquare(OpTest): self.inputs = { 'X': np.random.uniform(0.1, 1, [11, 17]).astype("float32") } - self.outputs = {'Y': np.square(self.inputs['X'])} + self.outputs = {'Out': np.square(self.inputs['X'])} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestPow(OpTest): @@ -348,13 +348,13 @@ class TestPow(OpTest): self.op_type = "pow" self.inputs = {'X': np.random.uniform(1, 2, [11, 17]).astype("float32")} self.attrs = {'factor': 3.0} - self.outputs = {'Y': np.power(self.inputs['X'], 3)} + self.outputs = {'Out': np.power(self.inputs['X'], 3)} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.02) + self.check_grad(['X'], 'Out', max_relative_error=0.02) class TestSTanh(OpTest): @@ -366,13 +366,13 @@ class TestSTanh(OpTest): scale_a = 2.0 / 3.0 scale_b = 1.7159 self.attrs = {'scale_a': scale_a, 'scale_b': scale_b} - self.outputs = {'Y': scale_b * np.tanh(self.inputs['X'] * scale_a)} + self.outputs = {'Out': scale_b * np.tanh(self.inputs['X'] * scale_a)} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestSoftplus(OpTest): @@ -381,13 +381,13 @@ class TestSoftplus(OpTest): self.inputs = { 'X': np.random.uniform(-1, 1, [11, 17]).astype("float64") } - self.outputs = {'Y': np.log(1 + np.exp(self.inputs['X']))} + self.outputs = {'Out': np.log(1 + np.exp(self.inputs['X']))} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestSoftsign(OpTest): @@ -397,14 +397,14 @@ class TestSoftsign(OpTest): 'X': np.random.uniform(-1, 1, [11, 17]).astype("float32") } self.outputs = { - 'Y': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X'])) + 'Out': np.divide(self.inputs['X'], 1 + np.abs(self.inputs['X'])) } def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.007) + self.check_grad(['X'], 'Out', max_relative_error=0.007) class TestThresholdedRelu(OpTest): @@ -419,13 +419,13 @@ class TestThresholdedRelu(OpTest): self.inputs = {'X': X} self.attrs = {'threshold': threshold} - self.outputs = {'Y': (X > threshold) * X} + self.outputs = {'Out': (X > threshold) * X} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=self.relative_error) + self.check_grad(['X'], 'Out', max_relative_error=self.relative_error) class TestHardSigmoid(OpTest): @@ -447,13 +447,13 @@ class TestHardSigmoid(OpTest): upper_threshold - 0.2 temp = X * slope + offset - self.outputs = {'Y': np.maximum(0.0, np.minimum(1.0, temp))} + self.outputs = {'Out': np.maximum(0.0, np.minimum(1.0, temp))} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.002) + self.check_grad(['X'], 'Out', max_relative_error=0.002) class TestSwish(OpTest): @@ -462,13 +462,13 @@ class TestSwish(OpTest): X = np.random.uniform(0.1, 1, [11, 17]).astype("float32") self.inputs = {'X': X} self.attrs = {'beta': 2.3} - self.outputs = {'Y': X * expit(self.attrs['beta'] * X)} + self.outputs = {'Out': X * expit(self.attrs['beta'] * X)} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.008) + self.check_grad(['X'], 'Out', max_relative_error=0.008) if __name__ == "__main__": diff --git a/python/paddle/v2/fluid/tests/test_net.py b/python/paddle/v2/fluid/tests/test_net.py index 318df08a9e73ac95cab73c34182bc6220ef6c681..d9fe55a8af5c750c5c926e875ddbb645f8abb1a0 100644 --- a/python/paddle/v2/fluid/tests/test_net.py +++ b/python/paddle/v2/fluid/tests/test_net.py @@ -7,7 +7,7 @@ def fc(X, W, Y): ret_v = core.Net.create() ret_v.append_op(Operator("mul", X="X", Y="W", Out="pre_activation")) - ret_v.append_op(Operator("sigmoid", X="pre_activation", Y=Y)) + ret_v.append_op(Operator("sigmoid", X="pre_activation", Out=Y)) ret_v.complete_add_op(True) return ret_v @@ -30,7 +30,7 @@ Op(plain_net), inputs:{all[W, X, Y]}, outputs:{all[Out, fc.out, pre_activation]} Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}. Op(plain_net), inputs:{all[W, X]}, outputs:{all[fc.out, pre_activation]}. Op(mul), inputs:{X[X], Y[W]}, outputs:{Out[pre_activation]}. - Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Y[fc.out]}. + Op(sigmoid), inputs:{X[pre_activation]}, outputs:{Out[fc.out]}. ''' self.assertEqual(expected, "\n" + str(net)) diff --git a/python/paddle/v2/fluid/tests/test_softmax_op.py b/python/paddle/v2/fluid/tests/test_softmax_op.py index b41c810d9a6269c934a434b085748a86deccb475..136fc0283afd6acf1de4baae5e681789662295ce 100644 --- a/python/paddle/v2/fluid/tests/test_softmax_op.py +++ b/python/paddle/v2/fluid/tests/test_softmax_op.py @@ -17,14 +17,14 @@ class TestSoftmaxOp(OpTest): 'X': np.random.uniform(0.1, 1, [10, 10]).astype("float32") } self.outputs = { - 'Y': np.apply_along_axis(stable_softmax, 1, self.inputs['X']) + 'Out': np.apply_along_axis(stable_softmax, 1, self.inputs['X']) } def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(['X'], 'Out') if __name__ == "__main__":