未验证 提交 0d46a108 编写于 作者: F Feiyu Chan 提交者: GitHub

[Pten] move paddle/operators/math/functors.h and compound_functors.h (#39514)

* move paddle/operators/math/functors.h
* move paddle/operators/math/compound_functors.h
上级 70714d1b
...@@ -20,8 +20,8 @@ limitations under the License. */ ...@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -20,12 +20,11 @@ limitations under the License. */ ...@@ -20,12 +20,11 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" #include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_dropout_test.h" #include "paddle/fluid/operators/fused/fused_dropout_test.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/pten/kernels/funcs/functors.h"
namespace framework = paddle::framework; namespace framework = paddle::framework;
namespace platform = paddle::platform; namespace platform = paddle::platform;
namespace details = paddle::operators::details; namespace details = paddle::operators::details;
namespace math = paddle::operators::math;
/** /**
* @brief the unittest of fused_dropout_act_bias * @brief the unittest of fused_dropout_act_bias
...@@ -283,12 +282,14 @@ static void BaseTest(const bool is_fp16 = false) { ...@@ -283,12 +282,14 @@ static void BaseTest(const bool is_fp16 = false) {
} }
TEST(FusedDropout, GPUFusedDorpoutActBias) { TEST(FusedDropout, GPUFusedDorpoutActBias) {
BaseTest<float, math::ReluFunctor<float>, math::ReluGradFunctor<float>>(); BaseTest<float, pten::funcs::ReluFunctor<float>,
pten::funcs::ReluGradFunctor<float>>();
BaseTest<float, paddle::operators::GeluFunctor<float>, BaseTest<float, paddle::operators::GeluFunctor<float>,
paddle::operators::GeluGradFunctor<float>>(); paddle::operators::GeluGradFunctor<float>>();
} }
TEST(FusedDropout, GPUFusedDropoutActBiasDouble) { TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
BaseTest<double, math::ReluFunctor<double>, math::ReluGradFunctor<double>>(); BaseTest<double, pten::funcs::ReluFunctor<double>,
pten::funcs::ReluGradFunctor<double>>();
BaseTest<double, paddle::operators::GeluFunctor<double>, BaseTest<double, paddle::operators::GeluFunctor<double>,
paddle::operators::GeluGradFunctor<double>>(); paddle::operators::GeluGradFunctor<double>>();
} }
...@@ -296,15 +297,16 @@ TEST(FusedDropout, GPUFusedDropoutActBiasDouble) { ...@@ -296,15 +297,16 @@ TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py // test fp16, For inference, check_grad is not required. ref: test_dropout_op.py
TEST(FusedDropout, GPUFusedDropoutActBiasFp16) { TEST(FusedDropout, GPUFusedDropoutActBiasFp16) {
using fp16 = platform::float16; using fp16 = platform::float16;
BaseTest<fp16, math::ReluFunctor<fp16>, math::ReluGradFunctor<fp16>>(true); BaseTest<fp16, pten::funcs::ReluFunctor<fp16>,
pten::funcs::ReluGradFunctor<fp16>>(true);
} }
TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
for (auto is_upscale_in_train : {true, false}) { for (auto is_upscale_in_train : {true, false}) {
TestFusedDropoutActBias<float, math::ReluFunctor<float>, TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
math::ReluGradFunctor<float>> pten::funcs::ReluGradFunctor<float>>
test(rows, cols, 0, 1.0, is_upscale_in_train, false); test(rows, cols, 0, 1.0, is_upscale_in_train, false);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
...@@ -315,8 +317,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { ...@@ -315,8 +317,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) { TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
TestFusedDropoutActBias<float, math::ReluFunctor<float>, TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
math::ReluGradFunctor<float>> pten::funcs::ReluGradFunctor<float>>
test(rows, cols, 0, 0.35, true, true); test(rows, cols, 0, 0.35, true, true);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
...@@ -326,8 +328,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) { ...@@ -326,8 +328,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
TEST(FusedDropout, GPUFusedDropoutActBiasSeed) { TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
const int rows = 16; const int rows = 16;
const int cols = 16; const int cols = 16;
TestFusedDropoutActBias<float, math::ReluFunctor<float>, TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
math::ReluGradFunctor<float>> pten::funcs::ReluGradFunctor<float>>
test(rows, cols, 125, 0.0, false, false); test(rows, cols, 125, 0.0, false, false);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
...@@ -337,8 +339,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasSeed) { ...@@ -337,8 +339,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) { TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) {
const int rows = 256; const int rows = 256;
const int cols = 4096; const int cols = 4096;
TestFusedDropoutActBias<float, math::ReluFunctor<float>, TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
math::ReluGradFunctor<float>> pten::funcs::ReluGradFunctor<float>>
test(rows, cols); test(rows, cols);
test.Run(); test.Run();
test.CheckOut(static_cast<float>(1e-5)); test.CheckOut(static_cast<float>(1e-5));
......
...@@ -21,12 +21,12 @@ limitations under the License. */ ...@@ -21,12 +21,12 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h" #include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" #include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h" #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" #include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/pten/kernels/funcs/functors.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -167,8 +167,8 @@ class FusedDropoutHelper { ...@@ -167,8 +167,8 @@ class FusedDropoutHelper {
dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
dropout_param_.is_test, src, bias, out, mask, ctx); dropout_param_.is_test, src, bias, out, mask, ctx);
} else if (act_method == "relu") { } else if (act_method == "relu") {
math::ReluFunctor<T> relu; pten::funcs::ReluFunctor<T> relu;
LaunchDropoutActBias<T, MaskType, math::ReluFunctor<T>>( LaunchDropoutActBias<T, MaskType, pten::funcs::ReluFunctor<T>>(
relu, dropout_param_.seed, rows_, cols_, increment, relu, dropout_param_.seed, rows_, cols_, increment,
dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train, dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
dropout_param_.is_test, src, bias, out, mask, ctx); dropout_param_.is_test, src, bias, out, mask, ctx);
...@@ -187,8 +187,8 @@ class FusedDropoutHelper { ...@@ -187,8 +187,8 @@ class FusedDropoutHelper {
gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
} else if (act_method == "relu") { } else if (act_method == "relu") {
math::ReluGradFunctor<T> relu_grad; pten::funcs::ReluGradFunctor<T> relu_grad;
LaunchDropoutActBiasGrad<T, MaskType, math::ReluGradFunctor<T>>( LaunchDropoutActBiasGrad<T, MaskType, pten::funcs::ReluGradFunctor<T>>(
relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob, relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx); dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
} else { } else {
......
...@@ -19,8 +19,9 @@ limitations under the License. */ ...@@ -19,8 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h" #include "paddle/pten/kernels/funcs/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h" #include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -53,22 +54,22 @@ static void RunBinaryCompoundFunctor( ...@@ -53,22 +54,22 @@ static void RunBinaryCompoundFunctor(
// intermediate_out = Unary(Y) // intermediate_out = Unary(Y)
// out = Binary(X, Unary(Y)) // out = Binary(X, Unary(Y))
// In this case, the shape of intermediate_out and out are different. // In this case, the shape of intermediate_out and out are different.
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor> pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor); compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
if (ctx.Attr<bool>("save_intermediate_out")) { if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T, FusedElemwiseAndActComputeEx<
paddle::operators::math::BinaryCompoundFunctor< DeviceContext, T,
T, BinaryFunctor, UnaryFunctor>, pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
true /*KeepIntermediateValue*/, true /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>( false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else { } else {
FusedElemwiseAndActComputeEx<DeviceContext, T, FusedElemwiseAndActComputeEx<
paddle::operators::math::BinaryCompoundFunctor< DeviceContext, T,
T, BinaryFunctor, UnaryFunctor>, pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/, false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>( false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} }
} }
...@@ -85,22 +86,22 @@ static void RunUnaryCompoundFunctors( ...@@ -85,22 +86,22 @@ static void RunUnaryCompoundFunctors(
// In this case, the shape of intermediate_out and out are the same. // In this case, the shape of intermediate_out and out are the same.
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor> pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
compound_func(unary_functor, binary_functor); compound_func(unary_functor, binary_functor);
if (ctx.Attr<bool>("save_intermediate_out")) { if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T, FusedElemwiseAndActComputeEx<
paddle::operators::math::UnaryCompoundFunctor< DeviceContext, T,
T, UnaryFunctor, BinaryFunctor>, pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/, true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>( true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else { } else {
FusedElemwiseAndActComputeEx<DeviceContext, T, FusedElemwiseAndActComputeEx<
paddle::operators::math::UnaryCompoundFunctor< DeviceContext, T,
T, UnaryFunctor, BinaryFunctor>, pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/, false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>( true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]); ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} }
} }
...@@ -120,13 +121,12 @@ static void RunBinaryCompoundGradFunctors( ...@@ -120,13 +121,12 @@ static void RunBinaryCompoundGradFunctors(
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
using BinaryCompoundDxFunctor = using BinaryCompoundDxFunctor =
paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor, pten::funcs::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
UnaryFunctor>; UnaryFunctor>;
using BinaryCompoundDyFunctor = using BinaryCompoundDyFunctor = pten::funcs::BinaryCompoundGradDyFunctor<
paddle::operators::math::BinaryCompoundGradDyFunctor< T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
using BinaryCompoundDIntermedaiteOutFunctor = using BinaryCompoundDIntermedaiteOutFunctor =
paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor< pten::funcs::BinaryCompoundGradDIntermedaiteOutFunctor<
T, BinaryGradFunctor, UnaryFunctor>; T, BinaryGradFunctor, UnaryFunctor>;
if (in_intermediate_out) { if (in_intermediate_out) {
...@@ -170,14 +170,12 @@ static void RunUnaryCompoundGradFunctors( ...@@ -170,14 +170,12 @@ static void RunUnaryCompoundGradFunctors(
// Z = Unary(Binary(X, Y)) // Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
using UnaryCompoundDxFunctor = using UnaryCompoundDxFunctor = pten::funcs::UnaryCompoundGradDxFunctor<
paddle::operators::math::UnaryCompoundGradDxFunctor< T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>; using UnaryCompoundDyFunctor = pten::funcs::UnaryCompoundGradDyFunctor<
using UnaryCompoundDyFunctor = T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
paddle::operators::math::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDIntermediateFunctor = using UnaryCompoundDIntermediateFunctor =
paddle::operators::math::UnaryCompoundGradDIntermediateFunctor< pten::funcs::UnaryCompoundGradDIntermediateFunctor<
T, UnaryGradFunctor, BinaryFunctor, InPlace>; T, UnaryGradFunctor, BinaryFunctor, InPlace>;
if (in_intermediate_out) { if (in_intermediate_out) {
...@@ -219,69 +217,60 @@ static void RunFunctors(const framework::ExecutionContext &ctx, ...@@ -219,69 +217,60 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
if (funcs_str == "elementwise_add,scale") { if (funcs_str == "elementwise_add,scale") {
// Z = Binary(X, Unary(Y)) // Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T, RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
paddle::operators::math::AddFunctor<T>, pten::funcs::ScaleFunctor<T>>(
paddle::operators::math::ScaleFunctor<T>>( ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ScaleFunctor<T>(scale),
ctx, paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else if (funcs_str == "scale,elementwise_add") { } else if (funcs_str == "scale,elementwise_add") {
// Z = Unary(Binary(X, Y)) // Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundFunctors<DeviceContext, T, RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ScaleFunctor<T>,
paddle::operators::math::ScaleFunctor<T>, pten::funcs::AddFunctor<T>>(
paddle::operators::math::AddFunctor<T>>( ctx, pten::funcs::ScaleFunctor<T>(scale), pten::funcs::AddFunctor<T>(),
ctx, paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_add,relu") { } else if (funcs_str == "elementwise_add,relu") {
// Z = Binary(X, Unary(Y)) // Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T, RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
paddle::operators::math::AddFunctor<T>, pten::funcs::ReluFunctor<T>>(
paddle::operators::math::ReluFunctor<T>>( ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ReluFunctor<T>(), in_x,
ctx, paddle::operators::math::AddFunctor<T>(), in_y, outputs);
paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "relu,elementwise_add") { } else if (funcs_str == "relu,elementwise_add") {
// Z = Unary(Binary(X, Y)) // Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T, RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ReluFunctor<T>,
paddle::operators::math::ReluFunctor<T>, pten::funcs::AddFunctor<T>>(
paddle::operators::math::AddFunctor<T>>( ctx, pten::funcs::ReluFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
ctx, paddle::operators::math::ReluFunctor<T>(), in_y, outputs);
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,scale") { } else if (funcs_str == "elementwise_mul,scale") {
// Z = Binary(X, Unary(Y)) // Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T, RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
paddle::operators::math::MulFunctor<T>, pten::funcs::ScaleFunctor<T>>(
paddle::operators::math::ScaleFunctor<T>>( ctx, pten::funcs::MultiplyFunctor<T>(),
ctx, paddle::operators::math::MulFunctor<T>(), pten::funcs::ScaleFunctor<T>(scale), in_x, in_y, outputs);
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else if (funcs_str == "tanh,elementwise_add") { } else if (funcs_str == "tanh,elementwise_add") {
// Z = Unary(Binary(X, Y)) // Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T, RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::TanhFunctor<T>,
paddle::operators::math::TanhFunctor<T>, pten::funcs::AddFunctor<T>>(
paddle::operators::math::AddFunctor<T>>( ctx, pten::funcs::TanhFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
ctx, paddle::operators::math::TanhFunctor<T>(), in_y, outputs);
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,tanh") { } else if (funcs_str == "elementwise_mul,tanh") {
// Z = Binary(X, Unary(Y)) // Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T, RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
paddle::operators::math::MulFunctor<T>, pten::funcs::TanhFunctor<T>>(
paddle::operators::math::TanhFunctor<T>>( ctx, pten::funcs::MultiplyFunctor<T>(), pten::funcs::TanhFunctor<T>(),
ctx, paddle::operators::math::MulFunctor<T>(), in_x, in_y, outputs);
paddle::operators::math::TanhFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,sigmoid") { } else if (funcs_str == "elementwise_mul,sigmoid") {
// Z = Binary(X, Unary(Y)) // Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T, RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
paddle::operators::math::MulFunctor<T>, pten::funcs::SigmoidFunctor<T>>(
paddle::operators::math::SigmoidFunctor<T>>( ctx, pten::funcs::MultiplyFunctor<T>(),
ctx, paddle::operators::math::MulFunctor<T>(), pten::funcs::SigmoidFunctor<T>(), in_x, in_y, outputs);
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "gelu,elementwise_add") { } else if (funcs_str == "gelu,elementwise_add") {
// Z = Unary(Binary(X, Y)) // Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T, RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::GeluFunctor<T>,
paddle::operators::math::GeluFunctor<T>, pten::funcs::AddFunctor<T>>(
paddle::operators::math::AddFunctor<T>>( ctx, pten::funcs::GeluFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
ctx, paddle::operators::math::GeluFunctor<T>(), in_y, outputs);
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str)); "%s has not been implemented.", funcs_str));
...@@ -301,95 +290,83 @@ static void RunGradFunctors( ...@@ -301,95 +290,83 @@ static void RunGradFunctors(
if (funcs_str == "elementwise_add_grad,scale_grad") { if (funcs_str == "elementwise_add_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y)) // The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors< RunBinaryCompoundGradFunctors<DeviceContext, T,
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>, pten::funcs::AddGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>, pten::funcs::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>, InPlace>( pten::funcs::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(), ctx, pten::funcs::AddGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), pten::funcs::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out, pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "scale_grad,elementwise_add_grad") { } else if (funcs_str == "scale_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y)) // The backward of Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundGradFunctors< RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>, DeviceContext, T, pten::funcs::ScaleGradFunctor<T>,
paddle::operators::math::AddFunctor<T>, pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
paddle::operators::math::AddGradFunctor<T>, InPlace>( ctx, pten::funcs::ScaleGradFunctor<T>(scale),
ctx, paddle::operators::math::ScaleGradFunctor<T>(scale), pten::funcs::AddFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x,
paddle::operators::math::AddFunctor<T>(), in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad,
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out, d_intermediate_out);
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_add_grad,relu_grad") { } else if (funcs_str == "elementwise_add_grad,relu_grad") {
// The backward of Z = Binary(X, Unary(Y)) // The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors< RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>, DeviceContext, T, pten::funcs::AddGradFunctor<T>,
paddle::operators::math::ReluFunctor<T>, pten::funcs::ReluFunctor<T>, pten::funcs::ReluGradFunctor<T>, InPlace>(
paddle::operators::math::ReluGradFunctor<T>, InPlace>( ctx, pten::funcs::AddGradFunctor<T>(), pten::funcs::ReluFunctor<T>(),
ctx, paddle::operators::math::AddGradFunctor<T>(), pten::funcs::ReluGradFunctor<T>(), in_x, in_y, in_out,
paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "relu_grad,elementwise_add_grad") { } else if (funcs_str == "relu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y)) // The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors< RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>, DeviceContext, T, pten::funcs::ReluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>, pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
paddle::operators::math::AddGradFunctor<T>, InPlace>( ctx, pten::funcs::ReluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
ctx, paddle::operators::math::ReluGradFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,scale_grad") { } else if (funcs_str == "elementwise_mul_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y)) // The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale")); T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors< RunBinaryCompoundGradFunctors<DeviceContext, T,
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>, pten::funcs::MulGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>, pten::funcs::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>, InPlace>( pten::funcs::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(), ctx, pten::funcs::MulGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), pten::funcs::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out, pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "tanh_grad,elementwise_add_grad") { } else if (funcs_str == "tanh_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y)) // The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors< RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::TanhGradFunctor<T>, DeviceContext, T, pten::funcs::TanhGradFunctor<T>,
paddle::operators::math::AddFunctor<T>, pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
paddle::operators::math::AddGradFunctor<T>, InPlace>( ctx, pten::funcs::TanhGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
ctx, paddle::operators::math::TanhGradFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,tanh_grad") { } else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
// The backward of Z = Binary(X, Unary(Y)) // The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors< RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>, DeviceContext, T, pten::funcs::MulGradFunctor<T>,
paddle::operators::math::TanhFunctor<T>, pten::funcs::TanhFunctor<T>, pten::funcs::TanhGradFunctor<T>, InPlace>(
paddle::operators::math::TanhGradFunctor<T>, InPlace>( ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::TanhFunctor<T>(),
ctx, paddle::operators::math::MulGradFunctor<T>(), pten::funcs::TanhGradFunctor<T>(), in_x, in_y, in_out,
paddle::operators::math::TanhFunctor<T>(),
paddle::operators::math::TanhGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") { } else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
// The backward of Z = Binary(X, Unary(Y)) // The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors< RunBinaryCompoundGradFunctors<DeviceContext, T,
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>, pten::funcs::MulGradFunctor<T>,
paddle::operators::math::SigmoidFunctor<T>, pten::funcs::SigmoidFunctor<T>,
paddle::operators::math::SigmoidGradFunctor<T>, InPlace>( pten::funcs::SigmoidGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(), ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::SigmoidFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), pten::funcs::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "gelu_grad,elementwise_add_grad") { } else if (funcs_str == "gelu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y)) // The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors< RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::GeluGradFunctor<T>, DeviceContext, T, pten::funcs::GeluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>, pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
paddle::operators::math::AddGradFunctor<T>, InPlace>( ctx, pten::funcs::GeluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
ctx, paddle::operators::math::GeluGradFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out); in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
......
...@@ -122,12 +122,12 @@ __global__ void FusedLayernormResidualDropoutBias( ...@@ -122,12 +122,12 @@ __global__ void FusedLayernormResidualDropoutBias(
__shared__ U shared_mean[32]; __shared__ U shared_mean[32];
__shared__ U shared_var[32]; __shared__ U shared_var[32];
math::ReluFunctor<T> relu; pten::funcs::ReluFunctor<T> relu;
U mean_val = 0; U mean_val = 0;
U var_val = 0; U var_val = 0;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, true, false, FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, true, false,
math::ReluFunctor<T>>( pten::funcs::ReluFunctor<T>>(
row_id, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, row_id, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, &mean_val, &var_val, relu); mask, is_test, &mean_val, &var_val, relu);
} }
......
...@@ -115,12 +115,12 @@ __global__ void FusedResidualDropoutBias( ...@@ -115,12 +115,12 @@ __global__ void FusedResidualDropoutBias(
curandStatePhilox4_32_10_t state; curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state); curand_init(seed, idx, increment, &state);
const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test); const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
math::ReluFunctor<T> relu; pten::funcs::ReluFunctor<T> relu;
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols; for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) { i += blockDim.x * gridDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false, false, FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false, false,
math::ReluFunctor<T>>( pten::funcs::ReluFunctor<T>>(
r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, nullptr, nullptr, relu); mask, is_test, nullptr, nullptr, relu);
} }
......
...@@ -15,8 +15,9 @@ ...@@ -15,8 +15,9 @@
#include <limits> #include <limits>
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/log_softmax_op.h" #include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -213,15 +214,15 @@ __global__ void LogSoftmaxForwardCUDAKernelNotLastAxis( ...@@ -213,15 +214,15 @@ __global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) { for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
const AccT value = const AccT value =
static_cast<AccT>(input[data_offset + d * dim_stride]); static_cast<AccT>(input[data_offset + d * dim_stride]);
max_value = math::MaxFunctor<AccT>()(max_value, value); max_value = pten::funcs::MaxFunctor<AccT>()(max_value, value);
} }
// If there are more than 1 threads along block x, reduce all max_values // If there are more than 1 threads along block x, reduce all max_values
// and get the global max_value, which is the max value along "axis". // and get the global max_value, which is the max value along "axis".
// If there is only one thread along block x, no need to reduce, as the // If there is only one thread along block x, no need to reduce, as the
// 'max_value' is the global max_value. // 'max_value' is the global max_value.
if (blockDim.x > 1) { if (blockDim.x > 1) {
max_value = max_value = BlockReduceAlongDimX<AccT, pten::funcs::MaxFunctor>(
BlockReduceAlongDimX<AccT, math::MaxFunctor>(sdata, max_value); sdata, max_value);
} }
// 2. reduce sum // 2. reduce sum
...@@ -232,7 +233,7 @@ __global__ void LogSoftmaxForwardCUDAKernelNotLastAxis( ...@@ -232,7 +233,7 @@ __global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
max_value); max_value);
} }
if (blockDim.x > 1) { if (blockDim.x > 1) {
sum = BlockReduceAlongDimX<AccT, math::AddFunctor>(sdata, sum); sum = BlockReduceAlongDimX<AccT, pten::funcs::AddFunctor>(sdata, sum);
} }
// 3. input-max-log_sum and write to output // 3. input-max-log_sum and write to output
......
...@@ -18,9 +18,8 @@ limitations under the License. */ ...@@ -18,9 +18,8 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
namespace paddle { namespace pten {
namespace operators { namespace funcs {
namespace math {
// Z = BinaryFunctor(X, UnaryFunctor(Y)) // Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename BinaryFunctor, typename UnaryFunctor> template <typename T, typename BinaryFunctor, typename UnaryFunctor>
...@@ -69,8 +68,8 @@ struct BinaryCompoundGradDxFunctor { ...@@ -69,8 +68,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, unary_fun_(y)); return dout * d_binary_fun_.Dx(x, unary_fun_(y));
} }
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, inline HOSTDEVICE T
T dout) { UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out); return dout * d_binary_fun_.Dx(x, intermediate_out);
} }
...@@ -82,8 +81,11 @@ struct BinaryCompoundGradDxFunctor { ...@@ -82,8 +81,11 @@ struct BinaryCompoundGradDxFunctor {
}; };
// Z = BinaryFunctor(X, UnaryFunctor(Y)) // Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun, template <typename T,
typename DUnaryFun, bool InPlace> typename DBinaryFun,
typename UnaryFun,
typename DUnaryFun,
bool InPlace>
struct BinaryCompoundGradDyFunctor { struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun, BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun, const UnaryFun &unary_fun,
...@@ -96,8 +98,8 @@ struct BinaryCompoundGradDyFunctor { ...@@ -96,8 +98,8 @@ struct BinaryCompoundGradDyFunctor {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y); return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
} }
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, inline HOSTDEVICE T
T dout) { UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
if (InPlace) { if (InPlace) {
return dout * d_binary_fun_.Dy(x, intermediate_out) * return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_.UseOut(intermediate_out); d_unary_fun_.UseOut(intermediate_out);
...@@ -116,8 +118,11 @@ struct BinaryCompoundGradDyFunctor { ...@@ -116,8 +118,11 @@ struct BinaryCompoundGradDyFunctor {
}; };
// Z = UnaryFunctor(BinaryFunctor(X, Y)) // Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun, template <typename T,
typename DBinaryFun, bool InPlace> typename DUnaryFun,
typename BinaryFun,
typename DBinaryFun,
bool InPlace>
struct UnaryCompoundGradDxFunctor { struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun, UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun, const BinaryFun &binary_fun,
...@@ -136,8 +141,8 @@ struct UnaryCompoundGradDxFunctor { ...@@ -136,8 +141,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y); return base * d_binary_fun_.Dx(x, y);
} }
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, inline HOSTDEVICE T
T dout) { UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
T base; T base;
if (InPlace) { if (InPlace) {
base = dout * d_unary_fun_.UseOut(out); base = dout * d_unary_fun_.UseOut(out);
...@@ -156,8 +161,11 @@ struct UnaryCompoundGradDxFunctor { ...@@ -156,8 +161,11 @@ struct UnaryCompoundGradDxFunctor {
}; };
// Z = UnaryFunctor(BinaryFunctor(X, Y)) // Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun, template <typename T,
typename DBinaryFun, bool InPlace> typename DUnaryFun,
typename BinaryFun,
typename DBinaryFun,
bool InPlace>
struct UnaryCompoundGradDyFunctor { struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun, UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun, const BinaryFun &binary_fun,
...@@ -176,8 +184,8 @@ struct UnaryCompoundGradDyFunctor { ...@@ -176,8 +184,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y); return base * d_binary_fun_.Dy(x, y);
} }
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out, inline HOSTDEVICE T
T dout) { UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
T base; T base;
if (InPlace) { if (InPlace) {
base = dout * d_unary_fun_.UseOut(out); base = dout * d_unary_fun_.UseOut(out);
...@@ -206,7 +214,9 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor { ...@@ -206,7 +214,9 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)); return dout * d_binary_fun_.Dy(x, unary_fun_(y));
} }
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out, inline HOSTDEVICE T UseIntermediateOut(T x,
T intermediate_out,
T out,
T dout) { T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out); return dout * d_binary_fun_.Dy(x, intermediate_out);
} }
...@@ -233,7 +243,9 @@ struct UnaryCompoundGradDIntermediateFunctor { ...@@ -233,7 +243,9 @@ struct UnaryCompoundGradDIntermediateFunctor {
} }
} }
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out, inline HOSTDEVICE T UseIntermediateOut(T x,
T intermediate_out,
T out,
T dout) { T dout) {
if (InPlace) { if (InPlace) {
return dout * d_unary_fun_.UseOut(out); return dout * d_unary_fun_.UseOut(out);
...@@ -249,6 +261,5 @@ struct UnaryCompoundGradDIntermediateFunctor { ...@@ -249,6 +261,5 @@ struct UnaryCompoundGradDIntermediateFunctor {
BinaryFun binary_fun_; BinaryFun binary_fun_;
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace pten
} // namespace paddle
...@@ -17,16 +17,17 @@ limitations under the License. */ ...@@ -17,16 +17,17 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math.h" #include "paddle/fluid/operators/math.h"
namespace paddle { namespace pten {
namespace operators { namespace funcs {
namespace math {
// // MulFunctor
// MulFunctor // // NOTE(chenfeiyu): IT IS NOLONGER USED, use pten::funcs::MultiplyFunctor
template <typename T> // instead
struct MulFunctor { // template <typename T>
// out = x * y; // struct MulFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x * y; } // // out = x * y;
}; // inline HOSTDEVICE T operator()(T x, T y) { return x * y; }
// };
template <typename T> template <typename T>
struct MulGradFunctor { struct MulGradFunctor {
...@@ -34,12 +35,13 @@ struct MulGradFunctor { ...@@ -34,12 +35,13 @@ struct MulGradFunctor {
inline HOSTDEVICE T Dy(T x, T y) { return x; } inline HOSTDEVICE T Dy(T x, T y) { return x; }
}; };
// AddFunctor // // AddFunctor
template <typename T> // // NOTE(chenfeiyu): IT IS NOLONGER USED, use pten::funcs::AddFunctor instead
struct AddFunctor { // template <typename T>
// out = x + y; // struct AddFunctor {
inline HOSTDEVICE T operator()(T x, T y) { return x + y; } // // out = x + y;
}; // inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
// };
template <typename T> template <typename T>
struct MaxFunctor { struct MaxFunctor {
...@@ -102,7 +104,8 @@ struct TanhFunctor { ...@@ -102,7 +104,8 @@ struct TanhFunctor {
// y = 2 / (1 + e^-2x) - 1 // y = 2 / (1 + e^-2x) - 1
T t0 = static_cast<T>(2) * x; T t0 = static_cast<T>(2) * x;
T t1 = (t0 < kMin) ? kMin : ((t0 > kMax) ? kMax : t0); T t1 = (t0 < kMin) ? kMin : ((t0 > kMax) ? kMax : t0);
return static_cast<T>(2) / (static_cast<T>(1) + real_exp(-t1)) - return static_cast<T>(2) /
(static_cast<T>(1) + paddle::operators::real_exp(-t1)) -
static_cast<T>(1); static_cast<T>(1);
} }
}; };
...@@ -123,7 +126,8 @@ struct SigmoidFunctor { ...@@ -123,7 +126,8 @@ struct SigmoidFunctor {
inline HOSTDEVICE T operator()(T x) { inline HOSTDEVICE T operator()(T x) {
// y = 1 / (1 + e^-x) // y = 1 / (1 + e^-x)
T tmp = (x < kMin) ? kMin : ((x > kMax) ? kMax : x); T tmp = (x < kMin) ? kMin : ((x > kMax) ? kMax : x);
return static_cast<T>(1) / (static_cast<T>(1) + real_exp(-tmp)); return static_cast<T>(1) /
(static_cast<T>(1) + paddle::operators::real_exp(-tmp));
} }
}; };
...@@ -138,7 +142,7 @@ struct SigmoidGradFunctor { ...@@ -138,7 +142,7 @@ struct SigmoidGradFunctor {
template <typename T> template <typename T>
struct GeluFunctor { struct GeluFunctor {
using MT = typename details::MPTypeTrait<T>::Type; using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T x) { inline HOSTDEVICE T operator()(T x) {
// this function is tanh approximation of gelu // this function is tanh approximation of gelu
// actual gelu is: // actual gelu is:
...@@ -154,7 +158,7 @@ struct GeluFunctor { ...@@ -154,7 +158,7 @@ struct GeluFunctor {
template <typename T> template <typename T>
struct GeluGradFunctor { struct GeluGradFunctor {
using MT = typename details::MPTypeTrait<T>::Type; using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T UseX(T x) { inline HOSTDEVICE T UseX(T x) {
MT mx = static_cast<MT>(x); MT mx = static_cast<MT>(x);
MT tanh_out = MT tanh_out =
...@@ -193,6 +197,5 @@ struct GeluGradFunctor { ...@@ -193,6 +197,5 @@ struct GeluGradFunctor {
} }
}; };
} // namespace math } // namespace funcs
} // namespace operators } // namespace pten
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册