diff --git a/doc/survey/op_fusion_design.md b/doc/survey/op_fusion_design.md new file mode 100644 index 0000000000000000000000000000000000000000..d6e48f4f58269b67450cb012f6dcc59e1083abba --- /dev/null +++ b/doc/survey/op_fusion_design.md @@ -0,0 +1,20 @@ +# Operator fusion +Fusing multiple operators together is an important method to optimize the program execution, particularly for GPU or other specialized accelerators. An obvious benefit is to avoid the overhead of saving the intermediate result back into global memory. + +There are generally two ways to fuse operators, fusing directly connected operators and fusing non directly connected operators. The first method is mainly used by [NNVM Compiler](https://github.com/dmlc/tvm/) and [XLA](https://www.tensorflow.org/performance/xla/). The second method is mainly used by Dynet and TensorFlow Fold to do auto-batching. The principle of fusing operator is according to some rules to combine multiple operations into one, for example, `Y = X * W` and `Z = Y + B` can be fused to `Z = X * W + B`, and `Y1 = X1 * W` and `Y2 = X2 * W` can be fused to `[Y1;Y2] = [X1;X2] * W`. In order to get a short-term profit, we decided to try to manually specify these rules. + +## Challenge +The challenge of fusing operators is: + - how to make the rules. + - how to implement these rules efficiently. + +### How to make the rules? + +The problem of determining the best single location for a fusion operator is an NP-hard combinatorial problem. After analysis the operators of the DL model, we found there are two group of operators can be fused explicitly, one is the simple and adjacent operations, for example, `tmp = x + y` and `z = Relu(tmp)`, and the other is the operators that have the same function, for example, a serials of `SGD` or `Momentum`. They usually appear in the model in a large number. So we should think about how to fuse them separately first. + +### How to implement these rules efficiently? +#### How to fuse the adjacent operations efficiently? +Here we use a template function to represent the fused operations. The pros of using a template function are that it is simple and efficient, and the cons are that it is not easy to expand, and it can only be used to express some simple operations. So taking into account our current needs, the template function is more appropriate. + +#### How to fuse the operators that have the same function efficiently? +We take SGD operator as an example, the training model may have hundreds of parameters and correspondingly have the same number of SGD operators. The expression(`w = w - lr*w_g`) of those operators is the same, so during of training, the executor will execute this expression hundreds time in CPU or other specialized accelerators. If we can fuse them and make the address of all `w` and all `w_g` continuous respectively, we only need execute one time. For some accelerators, the time of launching kernel is not neglected, so the time of hundreds of times of launching and executing kernel may be larger than launching and executing only once. There usually are many operators that similar to `SGD` in the DL model, such as `AllReduce` and `FC`. diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cc b/paddle/fluid/operators/fused_elemwise_activation_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a6fd0aeb021dce40339c32251af130d5984dccd2 --- /dev/null +++ b/paddle/fluid/operators/fused_elemwise_activation_op.cc @@ -0,0 +1,221 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/fused_elemwise_activation_op.h" + +namespace paddle { +namespace operators { + +class FusedElemwiseActivationOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE( + ctx->HasInput("X"), + "Input(X) of FusedElemwiseActivationOp op should not be null."); + PADDLE_ENFORCE( + ctx->HasInput("Y"), + "Input(Y) of FusedElemwiseActivationOp op should not be null."); + PADDLE_ENFORCE( + ctx->HasOutput("Out"), + "Output(Out) of FusedElemwiseActivationOp op should not be null."); + + auto x_dim = ctx->GetInputDim("X"); + auto y_dim = ctx->GetInputDim("Y"); + PADDLE_ENFORCE_GE(x_dim.size(), y_dim.size(), + "Rank of first input must >= rank of second input."); + + ctx->SetOutputDim("Out", x_dim); + ctx->ShareLoD("X", /*->*/ "Out"); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + PADDLE_ENFORCE_EQ(ctx.Input("X")->type(), + ctx.Input("Y")->type(), + "The element's type of input should be the same."); + auto input_data_type = + framework::ToDataType(ctx.Input("X")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; + +class FusedElemwiseActivationMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(vector)"); + AddInput("Y", "(vector)"); + AddOutput("Out", "vector"); + AddAttr("axis", + "axis is used by elementwise_op, the default value is -1.") + .SetDefault(-1); + AddAttr("scale", + "scale is used by scale_op, the default value is 0.0.") + .SetDefault(0.0); + AddAttr("recomputation", + "Whether to recompute the Out." + "fused_elemwise_activation_grad has two methods to get the " + "dx and dy, one " + "is to use the 'Out', and the other is not to use it. " + "The former method will save the time of recomputing the " + "'Out', but it must occupy the memory to store the 'out'. " + "While, the later method can avoid occupying the memory, " + "but it must recompute the 'Out'. The default value is true.") + .SetDefault(true); + AddAttr>("functor_list", + "The functors that should be fused.") + .AddCustomChecker([&](const std::vector &functor_list) { + PADDLE_ENFORCE(ValidCheck(functor_list)); + }); + + AddComment(R"DOC( +FusedElemwiseActivation Operator. + +At present, FusedElemwiseActivation only supports Two kinds of compound +operators (elementwise_op and activation_op): + + Z = Binary(X, Unary(Y)) + Z = Unary(Binary(X, Y)) + +The attributions of activation_op can be get from fused_elemwise_activation_op's +attributions. functor_list records the functors to be fused, for example +"scale,elementwise_add". + +)DOC"); + } + + private: + bool ValidCheck(const std::vector &functors) { + std::unordered_set unary_fun = {"scale", "relu"}; + std::unordered_set binary_fun = {"elementwise_add"}; + + std::string unary_fun_str; + if (binary_fun.count(functors[0])) { + unary_fun_str = functors[1]; + } else if (binary_fun.count(functors[1])) { + unary_fun_str = functors[0]; + } else { + PADDLE_THROW("%s and %s are not included in fused_list.", functors[0], + functors[1]); + } + PADDLE_ENFORCE_EQ(unary_fun.count(unary_fun_str), 1, + "%s is not included in fused_list.", unary_fun_str); + return true; + } +}; + +class FusedElemwiseActivationGradMaker + : public framework::SingleGradOpDescMaker { + public: + using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; + + protected: + std::unique_ptr Apply() const override { + auto *op_desc_ptr = new framework::OpDesc(); + op_desc_ptr->SetType(this->ForwardOpType() + "_grad"); + + for (auto &input_param : this->InputNames()) { + op_desc_ptr->SetInput(input_param, this->Input(input_param)); + op_desc_ptr->SetOutput(framework::GradVarName(input_param), + this->InputGrad(input_param, true)); + } + + for (auto &output_param : this->OutputNames()) { + op_desc_ptr->SetInput(output_param, this->Output(output_param)); + op_desc_ptr->SetInput(framework::GradVarName(output_param), + this->OutputGrad(output_param)); + } + op_desc_ptr->SetAttrMap(this->Attrs()); + + std::vector functor_names = + boost::get>( + op_desc_ptr->GetAttr("functor_list")); + functor_names[0] += "_grad"; + functor_names[1] += "_grad"; + op_desc_ptr->SetAttr("functor_list", functor_names); + return std::unique_ptr(op_desc_ptr); + } +}; + +class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null"); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), + "Rank of first input must >= rank of second input."); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type_index = ctx.Input("X")->type(); + PADDLE_ENFORCE_EQ(input_data_type_index, + ctx.Input("Y")->type(), + "The element's type of input should be the same."); + PADDLE_ENFORCE_EQ( + input_data_type_index, + ctx.Input(framework::GradVarName("Out"))->type(), + "The element's type of input should be the same."); + + auto input_data_type = framework::ToDataType(input_data_type_index); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR(fused_elemwise_activation, ops::FusedElemwiseActivationOp, + ops::FusedElemwiseActivationMaker, + ops::FusedElemwiseActivationGradMaker); +REGISTER_OPERATOR(fused_elemwise_activation_grad, + ops::FusedElemwiseActivationOpGrad); + +REGISTER_OP_CPU_KERNEL( + fused_elemwise_activation, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel); + +REGISTER_OP_CPU_KERNEL( + fused_elemwise_activation_grad, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel); diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.cu b/paddle/fluid/operators/fused_elemwise_activation_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e1d2b16b4b5e3a480777f834c2cbeb6d00a755e4 --- /dev/null +++ b/paddle/fluid/operators/fused_elemwise_activation_op.cu @@ -0,0 +1,30 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/fused_elemwise_activation_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL( + fused_elemwise_activation, + ops::FusedElemwiseActivationKernel, + ops::FusedElemwiseActivationKernel); + +REGISTER_OP_CUDA_KERNEL( + fused_elemwise_activation_grad, + ops::FusedElemwiseActivationGradKernel, + ops::FusedElemwiseActivationGradKernel); diff --git a/paddle/fluid/operators/fused_elemwise_activation_op.h b/paddle/fluid/operators/fused_elemwise_activation_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fe0017b824532b1210d0ae3e51983d63d081f12a --- /dev/null +++ b/paddle/fluid/operators/fused_elemwise_activation_op.h @@ -0,0 +1,425 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/fluid/framework/op_desc.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/detail/safe_ref.h" +#include "paddle/fluid/operators/elementwise_op_function.h" +#include "paddle/fluid/operators/math/functors.h" + +namespace math = paddle::operators::math; + +namespace paddle { +namespace operators { + +// CompoundFunctors +// For example: Z = Binary(X, Unary(Y)) +template +struct BinaryCompoundFunctor { + BinaryCompoundFunctor(const BinaryFun &binary_fun, const UnaryFun &unary_fun) + : binary_fun_(binary_fun), unary_fun_(unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y) { + return binary_fun_(x, unary_fun_(y)); + } + + private: + BinaryFun binary_fun_; + UnaryFun unary_fun_; +}; + +// For example: Z = Unary(Binary(X, Y)) +template +struct UnaryCompoundFunctor { + UnaryCompoundFunctor(const UnaryFun &unary_fun, const BinaryFun &binary_fun) + : unary_fun_(unary_fun), binary_fun_(binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y) { + return unary_fun_(binary_fun_(x, y)); + } + + private: + UnaryFun unary_fun_; + BinaryFun binary_fun_; +}; + +// FIXME(zcd): DBinaryFun and DUnaryFun have to method to get +// the dx, one is to use the 'out', and the other is not to use it. +// the former method will save the time of recomputing the +// 'out', but it must occupy the memory to store the 'out'. +// While the later method can avoid occupying this memory, +// but it must recompute the 'out'. + +template +struct BinaryCompoundGradDxFunctor { + BinaryCompoundGradDxFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun) + : d_binary_fun_(d_binary_fun), unary_fun_(unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + if (Recomputation) { + return dout * d_binary_fun_(x, unary_fun_(y)); + } else { + return dout * d_binary_fun_(x, unary_fun_(y), out); + } + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; +}; + +template +struct BinaryCompoundGradDyFunctor { + BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun, + const UnaryFun &unary_fun, + const DUnaryFun &d_unary_fun) + : d_binary_fun_(d_binary_fun), + unary_fun_(unary_fun), + d_unary_fun_(d_unary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + if (Recomputation) { + return dout * d_binary_fun_(unary_fun_(y), x) * d_unary_fun_(y); + } else { + return dout * d_binary_fun_(unary_fun_(y), x, out) * d_unary_fun_(y); + } + } + + private: + DBinaryFun d_binary_fun_; + UnaryFun unary_fun_; + DUnaryFun d_unary_fun_; +}; + +template +struct UnaryCompoundGradDxFunctor { + UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun, + const DBinaryFun &d_binary_fun) + : d_unary_fun_(d_unary_fun), + binary_fun_(binary_fun), + d_binary_fun_(d_binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(binary_fun_(x, y)); + } else { + base = dout * d_unary_fun_(binary_fun_(x, y), out); + } + return base * d_binary_fun_(x, y); + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; + DBinaryFun d_binary_fun_; +}; + +template +struct UnaryCompoundGradDyFunctor { + UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun, + const BinaryFun &binary_fun, + const DBinaryFun &d_binary_fun) + : d_unary_fun_(d_unary_fun), + binary_fun_(binary_fun), + d_binary_fun_(d_binary_fun) {} + + inline HOSTDEVICE T operator()(T x, T y, T out, T dout) { + T base; + if (Recomputation) { + base = dout * d_unary_fun_(binary_fun_(x, y)); + } else { + base = dout * d_unary_fun_(binary_fun_(x, y), out); + } + return base * d_binary_fun_(y, x); + } + + private: + DUnaryFun d_unary_fun_; + BinaryFun binary_fun_; + DBinaryFun d_binary_fun_; +}; + +template +static void RunBinaryCompoundFunctor(const framework::ExecutionContext &ctx, + const BinaryFunctor &binary_functor, + const UnaryFunctor &unary_functor, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + framework::Tensor *output) { + int axis = ctx.Attr("axis"); + using BinaryCompoundFunctor = + BinaryCompoundFunctor; + + ElementwiseComputeEx( + ctx, in_x, in_y, axis, + BinaryCompoundFunctor(binary_functor, unary_functor), output); +} + +template +static void RunUnaryCompoundFunctors(const framework::ExecutionContext &ctx, + const UnaryFunctor &unary_functor, + const BinaryFunctor &binary_functor, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + framework::Tensor *output) { + int axis = ctx.Attr("axis"); + + using UnaryCompoundFunctor = + UnaryCompoundFunctor; + + ElementwiseComputeEx( + ctx, in_x, in_y, axis, + UnaryCompoundFunctor(unary_functor, binary_functor), output); +} + +template +static void RunBinaryCompoundGradFunctors( + const framework::ExecutionContext &ctx, + const BinaryGradFunctor &binary_grad_functor, + const UnaryFunctor &unary_functor, + const UnaryGradFunctor &unary_grad_functor, const framework::Tensor *in_x, + const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_out_grad, framework::Tensor *x_grad, + framework::Tensor *y_grad) { + int axis = ctx.Attr("axis"); + + using BinaryCompoundDxFunctor = + BinaryCompoundGradDxFunctor; + using BinaryCompoundDyFunctor = + BinaryCompoundGradDyFunctor; + + ElemwiseGradCompute( + ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, + BinaryCompoundDxFunctor(binary_grad_functor, unary_functor), + BinaryCompoundDyFunctor(binary_grad_functor, unary_functor, + unary_grad_functor)); +} + +template +static void RunUnaryCompoundGradFunctors( + const framework::ExecutionContext &ctx, + const UnaryGradFunctor &unary_grad_functor, + const BinaryFunctor &binary_functor, + const BinaryGradFunctor &binary_grad_functor, const framework::Tensor *in_x, + const framework::Tensor *in_y, const framework::Tensor *in_out, + const framework::Tensor *in_out_grad, framework::Tensor *x_grad, + framework::Tensor *y_grad) { + int axis = ctx.Attr("axis"); + + using UnaryCompoundDxFunctor = + UnaryCompoundGradDxFunctor; + using UnaryCompoundDyFunctor = + UnaryCompoundGradDyFunctor; + + ElemwiseGradCompute( + ctx, *in_x, *in_y, *in_out, *in_out_grad, axis, x_grad, y_grad, + UnaryCompoundDxFunctor(unary_grad_functor, binary_functor, + binary_grad_functor), + UnaryCompoundDyFunctor(unary_grad_functor, binary_functor, + binary_grad_functor)); +} + +template +static void RunFunctors(const framework::ExecutionContext &ctx, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + framework::Tensor *output) { + auto &functors = ctx.Attr>("functor_list"); + auto funcs_str = functors[0] + "," + functors[1]; + // TODO(zcd): The following code can be refined. + if (funcs_str == "elementwise_add,scale") { + // Z = Binary(X, Unary(Y)) + T scale = static_cast(ctx.Attr("scale")); + RunBinaryCompoundFunctor, + math::ScaleFunctor>( + ctx, math::AddFunctor(), math::ScaleFunctor(scale), in_x, in_y, + output); + } else if (funcs_str == "scale,elementwise_add") { + // Z = Unary(Binary(X, Y)) + T scale = static_cast(ctx.Attr("scale")); + RunUnaryCompoundFunctors, + math::AddFunctor>( + ctx, math::ScaleFunctor(scale), math::AddFunctor(), in_x, in_y, + output); + } else if (funcs_str == "elementwise_add,relu") { + RunBinaryCompoundFunctor, + math::ReluFunctor>( + ctx, math::AddFunctor(), math::ReluFunctor(), in_x, in_y, output); + } else if (funcs_str == "relu,elementwise_add") { + RunUnaryCompoundFunctors, + math::AddFunctor>( + ctx, math::ReluFunctor(), math::AddFunctor(), in_x, in_y, output); + } else { + PADDLE_THROW("%s has not been implemented.", funcs_str); + } +} + +template +static void RunGradFunctors(const framework::ExecutionContext &ctx, + const framework::Tensor *in_x, + const framework::Tensor *in_y, + const framework::Tensor *in_out, + const framework::Tensor *in_out_grad, + framework::Tensor *x_grad, + framework::Tensor *y_grad) { + auto &functors = ctx.Attr>("functor_list"); + auto funcs_str = functors[0] + "," + functors[1]; + + bool recomputation = ctx.Attr("recomputation"); + + // TODO(zcd): The following code can be refined. for example, use registion + if (funcs_str == "elementwise_add_grad,scale_grad") { + // The backward of Z = Binary(X, Unary(Y)) + T scale = static_cast(ctx.Attr("scale")); + if (recomputation) { + RunBinaryCompoundGradFunctors, + math::ScaleFunctor, + math::ScaleGradFunctor, true>( + ctx, math::AddGradFunctor(), math::ScaleFunctor(scale), + math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad, + x_grad, y_grad); + } else { + RunBinaryCompoundGradFunctors, + math::ScaleFunctor, + math::ScaleGradFunctor, false>( + ctx, math::AddGradFunctor(), math::ScaleFunctor(scale), + math::ScaleGradFunctor(scale), in_x, in_y, in_out, in_out_grad, + x_grad, y_grad); + } + } else if (funcs_str == "scale_grad,elementwise_add_grad") { + // The backward of Z = Unary(Binary(X, Y)) + T scale = static_cast(ctx.Attr("scale")); + if (recomputation) { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + true>(ctx, math::ScaleGradFunctor(scale), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } else { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + false>(ctx, math::ScaleGradFunctor(scale), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } + } else if (funcs_str == "elementwise_add_grad,relu_grad") { + if (recomputation) { + RunBinaryCompoundGradFunctors, + math::ReluFunctor, + math::ReluGradFunctor, true>( + ctx, math::AddGradFunctor(), math::ReluFunctor(), + math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad, + y_grad); + } else { + RunBinaryCompoundGradFunctors, + math::ReluFunctor, + math::ReluGradFunctor, false>( + ctx, math::AddGradFunctor(), math::ReluFunctor(), + math::ReluGradFunctor(), in_x, in_y, in_out, in_out_grad, x_grad, + y_grad); + } + } else if (funcs_str == "relu_grad,elementwise_add_grad") { + if (recomputation) { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + true>(ctx, math::ReluGradFunctor(), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } else { + RunUnaryCompoundGradFunctors, + math::AddFunctor, math::AddGradFunctor, + false>(ctx, math::ReluGradFunctor(), + math::AddFunctor(), + math::AddGradFunctor(), in_x, in_y, + in_out, in_out_grad, x_grad, y_grad); + } + } else { + PADDLE_THROW("%s has not been implemented.", funcs_str); + } +} + +template +class FusedElemwiseActivationKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &in_x = detail::Ref(ctx.Input("X"), + "Cannot get input tensor %s, variable name = %s", + "X", ctx.op().Input("X")); + auto &in_y = detail::Ref(ctx.Input("Y"), + "Cannot get input tensor %s, variable name = %s", + "Y", ctx.op().Input("Y")); + auto &output = detail::Ref(ctx.Output("Out"), + "Cannot get input tensor %s, variable name = %s", + "Out", ctx.op().Output("Out")); + + RunFunctors(ctx, &in_x, &in_y, &output); + } +}; + +template +class FusedElemwiseActivationGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + auto &in_x = detail::Ref(ctx.Input("X"), + "Cannot get input tensor %s, variable name = %s", + "X", ctx.op().Input("X")); + auto &in_y = detail::Ref(ctx.Input("Y"), + "Cannot get input tensor %s, variable name = %s", + "Y", ctx.op().Input("Y")); + auto &in_out = detail::Ref(ctx.Input("Out"), + "Cannot get input tensor %s, variable name = %s", + "Out", ctx.op().Input("Out")); + auto &in_out_grad = + detail::Ref(ctx.Input(framework::GradVarName("Out")), + "Cannot get input tensor %s, variable name = %s", + framework::GradVarName("Out"), + ctx.op().Input(framework::GradVarName("Out"))); + + framework::Tensor *x_grad = + ctx.Output(framework::GradVarName("X")); + framework::Tensor *y_grad = + ctx.Output(framework::GradVarName("Y")); + + RunGradFunctors(ctx, &in_x, &in_y, &in_out, &in_out_grad, + x_grad, y_grad); + } +}; +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/functors.h b/paddle/fluid/operators/math/functors.h new file mode 100644 index 0000000000000000000000000000000000000000..ad2f49ccbf5ff37d33cc9e71c1a683571f4f8137 --- /dev/null +++ b/paddle/fluid/operators/math/functors.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace operators { +namespace math { + +// AddFunctor +template +struct AddFunctor { + // out = x + y; + inline HOSTDEVICE T operator()(T x, T y) { return x + y; } +}; + +template +struct AddGradFunctor { + inline HOSTDEVICE T operator()(T x, T y) { return 1; } + + inline HOSTDEVICE T operator()(T x, T y, T out) const { return 1; } +}; + +template +struct ScaleFunctor { + explicit ScaleFunctor(const T coeff) : coeff_(coeff) {} + + inline HOSTDEVICE T operator()(T ele) { return ele * coeff_; } + + private: + T coeff_; +}; + +template +struct ScaleGradFunctor { + explicit ScaleGradFunctor(T coeff) : coeff_(coeff) {} + + inline HOSTDEVICE T operator()(T x) { return coeff_; } + + inline HOSTDEVICE T operator()(T x, T out) { return coeff_; } + + private: + T coeff_; +}; + +template +struct ReluFunctor { + inline HOSTDEVICE T operator()(T x) { return x * (x > 0); } +}; + +template +struct ReluGradFunctor { + inline HOSTDEVICE T operator()(T x) { return x > 0 ? 1 : 0; } + + inline HOSTDEVICE T operator()(T x, T out) { return x > 0 ? 1 : 0; } +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/test_activation_op.py b/python/paddle/fluid/tests/unittests/test_activation_op.py index 5ed387fb1247f1a91147cb6981f1adc7c2eeb8a2..34f9cf0620fd1351111e93e16ed5f7e765d7078b 100644 --- a/python/paddle/fluid/tests/unittests/test_activation_op.py +++ b/python/paddle/fluid/tests/unittests/test_activation_op.py @@ -313,9 +313,9 @@ class TestAbs(OpTest): self.init_dtype() x = np.random.uniform(-1, 1, [4, 4]).astype(self.dtype) - # Because we set delta = 0.005 in caculating numeric gradient, + # Because we set delta = 0.005 in calculating numeric gradient, # if x is too small, such as 0.002, x_neg will be -0.003 - # x_pos will be 0.007, so the numeric gradient is unaccurate. + # x_pos will be 0.007, so the numeric gradient is inaccurate. # we should avoid this x[np.abs(x) < 0.005] = 0.02 out = np.abs(x) diff --git a/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ec0a939e9ec21952a6657ea849bb9844bb69cc8d --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_elemwise_activation_op.py @@ -0,0 +1,818 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest + +# scale + add +# TestElementwiseAddOp +# TestFusedOperatorsOp_scalar +# TestFusedOperatorsOp_scalar2 +# TestFusedOperatorsOp_Vector +# TestFusedOperatorsOp_broadcast_0 +# TestFusedOperatorsOp_broadcast_1 +# TestFusedOperatorsOp_broadcast_2 +# TestFusedOperatorsOp_broadcast_3 +# TestFusedOperatorsOp_broadcast_4 +# TestFusedOperatorsOp_rowwise_add_0 +# TestFusedOperatorsOp_rowwise_add_1 +# TestFusedOperatorsOp_channelwise_add + + +class TestElementwiseAddOp(OpTest): + def setUp(self): + self.op_type = "fused_elemwise_activation" + self.dtype = np.float32 + self.axis = -1 + + self.init_axis() + self.init_dtype() + self.init_input() + self.init_output() + self.init_attr() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.outputs = {'Out': self.out} + + def init_input(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["scale", "elementwise_add"] + } + + def init_dtype(self): + pass + + def init_axis(self): + pass + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.005) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.005, no_grad_set=set("X")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y')) + + +class TestFusedOperatorsOp_scalar(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +class TestFusedOperatorsOp_scalar2(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +class TestFusedOperatorsOp_Vector(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.random((32, )).astype(self.dtype) + self.y = np.random.random((32, )).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +class TestFusedOperatorsOp_broadcast_0(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(2).astype(self.dtype) + + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(2, 1, 1)) * self.scale + + +class TestFusedOperatorsOp_broadcast_1(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 3, 1)) * self.scale + + +class TestFusedOperatorsOp_broadcast_2(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(4).astype(self.dtype) + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 1, 4)) * self.scale + + +class TestFusedOperatorsOp_broadcast_3(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 3, 4, 1)) * self.scale + + +class TestFusedOperatorsOp_broadcast_4(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4, 5).astype(self.dtype) + self.y = np.random.rand(2, 1).astype(self.dtype) + + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(2, 1, 1, 1)) * self.scale + + +class TestFusedOperatorsOp_rowwise_add_0(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(3, 4).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 3, 4)) * self.scale + + +class TestFusedOperatorsOp_rowwise_add_1(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(2, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y.reshape(1, 1)) * self.scale + + +class TestFusedOperatorsOp_channelwise_add(TestElementwiseAddOp): + def init_input(self): + self.x = np.random.rand(3, 20, 20).astype(self.dtype) + self.y = np.random.rand(3, 1, 1).astype(self.dtype) + + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.scale = 0.1 + self.out = (self.x + self.y) * self.scale + + +# add + scale +# TestElementwiseAddOp_f_add_scale +# TestFusedOperatorsOp_scalar_f_add_scale +# TestFusedOperatorsOp_scalar2_f_add_scale +# TestFusedOperatorsOp_Vector_f_add_scale +# TestFusedOperatorsOp_broadcast_0_f_add_scale +# TestFusedOperatorsOp_broadcast_1_f_add_scale +# TestFusedOperatorsOp_broadcast_2_f_add_scale +# TestFusedOperatorsOp_broadcast_3_f_add_scale +# TestFusedOperatorsOp_broadcast_4_f_add_scale +# TestFusedOperatorsOp_rowwise_add_0_f_add_scale +# TestFusedOperatorsOp_rowwise_add_1_f_add_scale +# TestFusedOperatorsOp_channelwise_add_f_add_scale + + +class TestFusedOperatorsOp_f_add_scale(TestElementwiseAddOp): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_scalar_f_add_scale(TestFusedOperatorsOp_scalar): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_scalar2_f_add_scale(TestFusedOperatorsOp_scalar2): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_Vector_f_add_scale(TestFusedOperatorsOp_Vector): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_0_f_add_scale( + TestFusedOperatorsOp_broadcast_0): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(2, 1, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_1_f_add_scale( + TestFusedOperatorsOp_broadcast_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 3, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_2_f_add_scale( + TestFusedOperatorsOp_broadcast_2): + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 1, 4) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_3_f_add_scale( + TestFusedOperatorsOp_broadcast_3): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 3, 4, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_broadcast_4_f_add_scale( + TestFusedOperatorsOp_broadcast_4): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.scale = 0.2 + self.out = self.x + self.y.reshape(2, 1, 1, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_rowwise_add_0_f_add_scale( + TestFusedOperatorsOp_rowwise_add_0): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.1 + self.out = self.x + self.y.reshape(1, 3, 4) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_rowwise_add_1_f_add_scale( + TestFusedOperatorsOp_rowwise_add_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.scale = 0.2 + self.out = self.x + self.y.reshape(1, 1) * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +class TestFusedOperatorsOp_channelwise_add_f_add_scale( + TestFusedOperatorsOp_channelwise_add): + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.scale = 0.2 + self.out = self.x + self.y * self.scale + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'scale': self.scale, + 'functor_list': ["elementwise_add", "scale"] + } + + +# add + relu +# TestElementwiseAddOp_f_add_relu +# TestFusedOperatorsOp_scalar_f_add_relu +# TestFusedOperatorsOp_scalar2_f_add_relu +# TestFusedOperatorsOp_Vector_f_add_relu +# TestFusedOperatorsOp_broadcast_0_f_add_relu +# TestFusedOperatorsOp_broadcast_1_f_add_relu +# TestFusedOperatorsOp_broadcast_2_f_add_relu +# TestFusedOperatorsOp_broadcast_3_f_add_relu +# TestFusedOperatorsOp_broadcast_4_f_add_relu +# TestFusedOperatorsOp_rowwise_add_0_f_add_relu +# TestFusedOperatorsOp_rowwise_add_1_f_add_relu +# TestFusedOperatorsOp_channelwise_add_f_add_relu + + +class TestFusedOperatorsOp_f_add_relu(TestElementwiseAddOp): + def init_output(self): + # Copy from test_activation_op.py + # Because we set delta = 0.005 in calculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is inaccurate. + # we should avoid this + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_scalar_f_add_relu(TestFusedOperatorsOp_scalar): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_scalar2_f_add_relu(TestFusedOperatorsOp_scalar2): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_Vector_f_add_relu(TestFusedOperatorsOp_Vector): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_0_f_add_relu( + TestFusedOperatorsOp_broadcast_0): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(2, 1, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_1_f_add_relu( + TestFusedOperatorsOp_broadcast_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 3, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_2_f_add_relu( + TestFusedOperatorsOp_broadcast_2): + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 1, 4), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_3_f_add_relu( + TestFusedOperatorsOp_broadcast_3): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 3, 4, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_broadcast_4_f_add_relu( + TestFusedOperatorsOp_broadcast_4): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(2, 1, 1, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_rowwise_add_0_f_add_relu( + TestFusedOperatorsOp_rowwise_add_0): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 3, 4), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_rowwise_add_1_f_add_relu( + TestFusedOperatorsOp_rowwise_add_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y.reshape(1, 1), 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +class TestFusedOperatorsOp_channelwise_add_f_add_relu( + TestFusedOperatorsOp_channelwise_add): + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.y[np.abs(self.y) < 0.005] = 0.02 + self.out = self.x + np.maximum(self.y, 0) + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["elementwise_add", "relu"] + } + + +# relu + add +# TestElementwiseAddOp_f_relu_add +# TestFusedOperatorsOp_scalar_f_relu_add +# TestFusedOperatorsOp_scalar2_f_relu_add +# TestFusedOperatorsOp_Vector_f_relu_add +# TestFusedOperatorsOp_broadcast_0_f_relu_add +# TestFusedOperatorsOp_broadcast_1_f_relu_add +# TestFusedOperatorsOp_broadcast_2_f_relu_add +# TestFusedOperatorsOp_broadcast_3_f_relu_add +# TestFusedOperatorsOp_broadcast_4_f_relu_add +# TestFusedOperatorsOp_rowwise_add_0_f_relu_add +# TestFusedOperatorsOp_rowwise_add_1_f_relu_add +# TestFusedOperatorsOp_channelwise_add_f_relu_add + + +class TestFusedOperatorsOp_f_relu_add(TestElementwiseAddOp): + def init_output(self): + # Copy from test_activation_op.py + # Because we set delta = 0.005 in calculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is inaccurate. + # we should avoid this + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_scalar_f_relu_add(TestFusedOperatorsOp_scalar): + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_scalar2_f_relu_add(TestFusedOperatorsOp_scalar2): + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_Vector_f_relu_add(TestFusedOperatorsOp_Vector): + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_0_f_relu_add( + TestFusedOperatorsOp_broadcast_0): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.out = self.x + self.y.reshape(2, 1, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_1_f_relu_add( + TestFusedOperatorsOp_broadcast_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 3, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_2_f_relu_add( + TestFusedOperatorsOp_broadcast_2): + def init_output(self): + self.out = self.x + self.y.reshape(1, 1, 4) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_3_f_relu_add( + TestFusedOperatorsOp_broadcast_3): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 3, 4, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_broadcast_4_f_relu_add( + TestFusedOperatorsOp_broadcast_4): + def init_axis(self): + self.axis = 0 + + def init_output(self): + self.out = self.x + self.y.reshape(2, 1, 1, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_rowwise_add_0_f_relu_add( + TestFusedOperatorsOp_rowwise_add_0): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 3, 4) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_rowwise_add_1_f_relu_add( + TestFusedOperatorsOp_rowwise_add_1): + def init_axis(self): + self.axis = 1 + + def init_output(self): + self.out = self.x + self.y.reshape(1, 1) + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +class TestFusedOperatorsOp_channelwise_add_f_relu_add( + TestFusedOperatorsOp_channelwise_add): + def init_axis(self): + self.axis = -1 + + def init_output(self): + self.out = self.x + self.y + self.out = np.maximum(self.out, 0) + self.out[np.abs(self.out) < 0.005] = 0.02 + + def init_attr(self): + self.attrs = { + 'axis': self.axis, + 'functor_list': ["relu", "elementwise_add"] + } + + +if __name__ == '__main__': + unittest.main()