未验证 提交 8fd39f3e 编写于 作者: Y Yiqun Liu 提交者: GitHub

Enhance fused_elementwise_activation op and add python api in contrib.layers (#17236)

* Enhance fused_elementwise_activation op.
test=develop

* Move the api fused_elementwise_activation to contrib.
test=develop

* Add including files.
test=develop

* Add the support of sigmoid in fused_elementwise_activetion op.

* Update API.spec.
test=develop
上级 ac92e4c0
......@@ -423,6 +423,7 @@ paddle.fluid.contrib.multi_download (ArgSpec(args=['client', 'hdfs_path', 'local
paddle.fluid.contrib.multi_upload (ArgSpec(args=['client', 'hdfs_path', 'local_path', 'multi_processes', 'overwrite', 'sync'], varargs=None, keywords=None, defaults=(5, False, True)), ('document', '183f34c83d30dbe16e09e8716c41958a'))
paddle.fluid.contrib.extend_with_decoupled_weight_decay (ArgSpec(args=['base_optimizer'], varargs=None, keywords=None, defaults=None), ('document', 'a1095dfd4ec725747f662d69cd7659d4'))
paddle.fluid.contrib.mixed_precision.decorate (ArgSpec(args=['optimizer', 'init_loss_scaling', 'incr_every_n_steps', 'decr_every_n_nan_or_inf', 'incr_ratio', 'decr_ratio', 'use_dynamic_loss_scaling'], varargs=None, keywords=None, defaults=(1.0, 1000, 2, 2.0, 0.8, False)), ('document', 'bdb8f9dbb0d94b3957272c53eeee9818'))
paddle.fluid.contrib.fused_elemwise_activation (ArgSpec(args=['x', 'y', 'functor_list', 'axis', 'scale', 'save_intermediate_out'], varargs=None, keywords=None, defaults=(-1, 0.0, True)), ('document', '1c4b247a2858cea8d9d8750693688270'))
paddle.fluid.transpiler.DistributeTranspiler.__init__ (ArgSpec(args=['self', 'config'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_program (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', 'b1951949c6d21698290aa8ac69afee32'))
paddle.fluid.transpiler.DistributeTranspiler.get_pserver_programs (ArgSpec(args=['self', 'endpoint'], varargs=None, keywords=None, defaults=None), ('document', 'c89fc350f975ef827f5448d68af388cf'))
......
......@@ -26,7 +26,7 @@ namespace framework {
namespace ir {
void FuseElewiseAddActPass::ApplyImpl(ir::Graph *graph) const {
std::unordered_set<std::string> act_types = {"relu", "scale"};
std::unordered_set<std::string> act_types = {"relu", "scale", "tanh"};
graph = FuseActElewiseAdd(graph, act_types);
graph = FuseElewiseAddAct(graph, act_types);
// backward
......
......@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_elemwise_activation_op.h"
#include <memory>
#include <unordered_set>
namespace paddle {
namespace operators {
......@@ -48,7 +50,10 @@ bool InputXCanBeAbsent(const std::vector<std::string> &functor_list) {
* out.
*/
static bool IsSupportedCompound(const std::vector<std::string> &functors) {
static std::unordered_set<std::string> unary_fun = {"scale", "relu"};
PADDLE_ENFORCE_EQ(functors.size(), 2UL);
static std::unordered_set<std::string> unary_fun = {"scale", "relu", "tanh",
"sigmoid"};
static std::unordered_set<std::string> binary_fun = {"elementwise_add",
"elementwise_mul"};
......
......@@ -255,6 +255,27 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
paddle::operators::math::ScaleFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else if (funcs_str == "tanh,elementwise_add") {
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::TanhFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::TanhFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,tanh") {
// Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::TanhFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::TanhFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,sigmoid") {
// Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::SigmoidFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
}
......@@ -293,6 +314,7 @@ static void RunGradFunctors(
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_add_grad,relu_grad") {
// The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ReluFunctor<T>,
......@@ -302,6 +324,7 @@ static void RunGradFunctors(
paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "relu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
......@@ -321,6 +344,36 @@ static void RunGradFunctors(
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "tanh_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::TanhGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::TanhGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
// The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::TanhFunctor<T>,
paddle::operators::math::TanhGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::TanhFunctor<T>(),
paddle::operators::math::TanhGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
// The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::SigmoidFunctor<T>,
paddle::operators::math::SigmoidGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(),
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW("%s has not been implemented.", funcs_str);
}
......
......@@ -78,6 +78,48 @@ struct ReluGradFunctor {
inline HOSTDEVICE T UseXAndOut(T x, T out) { return out > 0 ? 1 : 0; }
};
template <typename T>
struct TanhFunctor {
const T kMin = static_cast<T>(-40);
const T kMax = static_cast<T>(13);
inline HOSTDEVICE T operator()(T x) {
// y = 2 / (1 + e^-2x) - 1
T t0 = 2 * x;
T t1 = (t0 < kMin) ? kMin : ((t0 > kMax) ? kMax : t0);
return static_cast<T>(2) / (static_cast<T>(1) + std::exp(-t1)) -
static_cast<T>(1);
}
};
template <typename T>
struct TanhGradFunctor {
inline HOSTDEVICE T UseX(T x) { return static_cast<T>(1) - x * x; }
inline HOSTDEVICE T UseOut(T out) { return static_cast<T>(1) - out * out; }
inline HOSTDEVICE T UseXAndOut(T x, T out) {
return static_cast<T>(1) - out * out;
}
};
template <typename T>
struct SigmoidFunctor {
const T kMin = static_cast<T>(-40);
const T kMax = static_cast<T>(13);
inline HOSTDEVICE T operator()(T x) {
// y = 1 / (1 + e^-x)
T tmp = (x < kMin) ? kMin : ((x > kMax) ? kMax : x);
return static_cast<T>(1) / (static_cast<T>(1) + std::exp(-tmp));
}
};
template <typename T>
struct SigmoidGradFunctor {
inline HOSTDEVICE T UseX(T x) { return x * (static_cast<T>(1) - x); }
inline HOSTDEVICE T UseOut(T out) { return out * (static_cast<T>(1) - out); }
inline HOSTDEVICE T UseXAndOut(T x, T out) {
return out * (static_cast<T>(1) - out);
}
};
} // namespace math
} // namespace operators
} // namespace paddle
......@@ -36,6 +36,8 @@ from . import model_stat
from .model_stat import *
from . import mixed_precision
from .mixed_precision import *
from . import layers
from .layers import *
__all__ = []
__all__ += decoder.__all__
......@@ -48,3 +50,4 @@ __all__ += slim.__all__
__all__ += utils.__all__
__all__ += extend_optimizer.__all__
__all__ += ['mixed_precision']
__all__ += layers.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import nn
from .nn import *
__all__ = []
__all__ += nn.__all__
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Contrib layers just related to the neural network.
"""
from __future__ import print_function
import numpy as np
import six
import os
import inspect
from paddle.fluid.layer_helper import LayerHelper
__all__ = ['fused_elemwise_activation', ]
def fused_elemwise_activation(x,
y,
functor_list,
axis=-1,
scale=0.0,
save_intermediate_out=True):
"""
**Fused elementwise_add/mul and activation layers**
This function computes an elementwise_add/mul cooperated with an activation.
.. math::
out = Unary(Binary(x, y))
or
.. math::
out = Binary(x, Unary(y))
Unary operators can be: `scale`, `relu`, `tanh`. Binary operators can be:
`elementwise_add`, `elementwise_mul`.
Args:
x (Variable): left operation of the binary operator.
y (Variable): right operator of the binary operator.
functor_list (list of str): types of operator which will be executed
by this layer. For example, ['elementwise_add', 'relu']
(out = elementwise_add(x, relu(y))),
or ['relu', 'elemmentwise_add'] (out = relu(elementwise_add(x, y))).
axis (int32, default -1): axis of elementwise op.
scale (float32, default 0): parameter of scale op.
save_intermediate_out (bool, default True): whether to save the
intermediate result, Unary(y) or Binary(x, y).
Returns:
Variable: The computation result.
"""
if isinstance(functor_list, str):
functor_list = functor_list.split(',')
if not isinstance(functor_list, list) or len(functor_list) != 2:
raise ValueError(
'functor_list should be a list of str, and the length should be 2.')
helper = LayerHelper('fused_elemwise_activation', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
intermediate_out = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='fused_elemwise_activation',
inputs={'X': x,
'Y': y},
outputs={'Out': out,
'IntermediateOut': intermediate_out},
attrs={
'axis': axis,
'scale': scale,
'save_intermediate_out': save_intermediate_out,
'functor_list': functor_list
})
return out
......@@ -121,6 +121,7 @@ packages=['paddle',
'paddle.fluid.contrib.utils',
'paddle.fluid.contrib.extend_optimizer',
'paddle.fluid.contrib.mixed_precision',
'paddle.fluid.contrib.layers',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details',
'paddle.fluid.incubate',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册