未验证 提交 0d46a108 编写于 作者: F Feiyu Chan 提交者: GitHub

[Pten] move paddle/operators/math/functors.h and compound_functors.h (#39514)

* move paddle/operators/math/functors.h
* move paddle/operators/math/compound_functors.h
上级 70714d1b
......@@ -20,8 +20,8 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
......
......@@ -20,12 +20,11 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_dropout_test.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace details = paddle::operators::details;
namespace math = paddle::operators::math;
/**
* @brief the unittest of fused_dropout_act_bias
......@@ -283,12 +282,14 @@ static void BaseTest(const bool is_fp16 = false) {
}
TEST(FusedDropout, GPUFusedDorpoutActBias) {
BaseTest<float, math::ReluFunctor<float>, math::ReluGradFunctor<float>>();
BaseTest<float, pten::funcs::ReluFunctor<float>,
pten::funcs::ReluGradFunctor<float>>();
BaseTest<float, paddle::operators::GeluFunctor<float>,
paddle::operators::GeluGradFunctor<float>>();
}
TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
BaseTest<double, math::ReluFunctor<double>, math::ReluGradFunctor<double>>();
BaseTest<double, pten::funcs::ReluFunctor<double>,
pten::funcs::ReluGradFunctor<double>>();
BaseTest<double, paddle::operators::GeluFunctor<double>,
paddle::operators::GeluGradFunctor<double>>();
}
......@@ -296,15 +297,16 @@ TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py
TEST(FusedDropout, GPUFusedDropoutActBiasFp16) {
using fp16 = platform::float16;
BaseTest<fp16, math::ReluFunctor<fp16>, math::ReluGradFunctor<fp16>>(true);
BaseTest<fp16, pten::funcs::ReluFunctor<fp16>,
pten::funcs::ReluGradFunctor<fp16>>(true);
}
TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
const int rows = 16;
const int cols = 16;
for (auto is_upscale_in_train : {true, false}) {
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
pten::funcs::ReluGradFunctor<float>>
test(rows, cols, 0, 1.0, is_upscale_in_train, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
......@@ -315,8 +317,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
const int rows = 16;
const int cols = 16;
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
pten::funcs::ReluGradFunctor<float>>
test(rows, cols, 0, 0.35, true, true);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
......@@ -326,8 +328,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
const int rows = 16;
const int cols = 16;
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
pten::funcs::ReluGradFunctor<float>>
test(rows, cols, 125, 0.0, false, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
......@@ -337,8 +339,8 @@ TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) {
const int rows = 256;
const int cols = 4096;
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
TestFusedDropoutActBias<float, pten::funcs::ReluFunctor<float>,
pten::funcs::ReluGradFunctor<float>>
test(rows, cols);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
......
......@@ -21,12 +21,12 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle {
namespace operators {
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle {
namespace operators {
......@@ -167,8 +167,8 @@ class FusedDropoutHelper {
dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
dropout_param_.is_test, src, bias, out, mask, ctx);
} else if (act_method == "relu") {
math::ReluFunctor<T> relu;
LaunchDropoutActBias<T, MaskType, math::ReluFunctor<T>>(
pten::funcs::ReluFunctor<T> relu;
LaunchDropoutActBias<T, MaskType, pten::funcs::ReluFunctor<T>>(
relu, dropout_param_.seed, rows_, cols_, increment,
dropout_param_.dropout_prob, dropout_param_.is_upscale_in_train,
dropout_param_.is_test, src, bias, out, mask, ctx);
......@@ -187,8 +187,8 @@ class FusedDropoutHelper {
gelu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
} else if (act_method == "relu") {
math::ReluGradFunctor<T> relu_grad;
LaunchDropoutActBiasGrad<T, MaskType, math::ReluGradFunctor<T>>(
pten::funcs::ReluGradFunctor<T> relu_grad;
LaunchDropoutActBiasGrad<T, MaskType, pten::funcs::ReluGradFunctor<T>>(
relu_grad, dout, mask, src, bias, dropout_param_.dropout_prob,
dropout_param_.is_upscale_in_train, rows_, cols_, d_src, d_bias, ctx);
} else {
......
......@@ -19,8 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/compound_functors.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/pten/kernels/funcs/compound_functors.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle {
namespace operators {
......@@ -53,22 +54,22 @@ static void RunBinaryCompoundFunctor(
// intermediate_out = Unary(Y)
// out = Binary(X, Unary(Y))
// In this case, the shape of intermediate_out and out are different.
paddle::operators::math::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>
compound_func(binary_functor, unary_functor);
int axis = ctx.Attr<int>("axis");
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
true /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
true /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::BinaryCompoundFunctor<
T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::BinaryCompoundFunctor<T, BinaryFunctor, UnaryFunctor>,
false /*KeepIntermediateValue*/,
false /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
}
......@@ -85,22 +86,22 @@ static void RunUnaryCompoundFunctors(
// In this case, the shape of intermediate_out and out are the same.
int axis = ctx.Attr<int>("axis");
paddle::operators::math::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>
compound_func(unary_functor, binary_functor);
if (ctx.Attr<bool>("save_intermediate_out")) {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
true /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
} else {
FusedElemwiseAndActComputeEx<DeviceContext, T,
paddle::operators::math::UnaryCompoundFunctor<
T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
FusedElemwiseAndActComputeEx<
DeviceContext, T,
pten::funcs::UnaryCompoundFunctor<T, UnaryFunctor, BinaryFunctor>,
false /*KeepIntermediateValue*/,
true /*SameShapeOfIntermediateOutAndOut*/>(
ctx, in_x, in_y, axis, compound_func, (*outputs)[0], (*outputs)[1]);
}
}
......@@ -120,13 +121,12 @@ static void RunBinaryCompoundGradFunctors(
int axis = ctx.Attr<int>("axis");
using BinaryCompoundDxFunctor =
paddle::operators::math::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
UnaryFunctor>;
using BinaryCompoundDyFunctor =
paddle::operators::math::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
pten::funcs::BinaryCompoundGradDxFunctor<T, BinaryGradFunctor,
UnaryFunctor>;
using BinaryCompoundDyFunctor = pten::funcs::BinaryCompoundGradDyFunctor<
T, BinaryGradFunctor, UnaryFunctor, UnaryGradFunctor, InPlace>;
using BinaryCompoundDIntermedaiteOutFunctor =
paddle::operators::math::BinaryCompoundGradDIntermedaiteOutFunctor<
pten::funcs::BinaryCompoundGradDIntermedaiteOutFunctor<
T, BinaryGradFunctor, UnaryFunctor>;
if (in_intermediate_out) {
......@@ -170,14 +170,12 @@ static void RunUnaryCompoundGradFunctors(
// Z = Unary(Binary(X, Y))
int axis = ctx.Attr<int>("axis");
using UnaryCompoundDxFunctor =
paddle::operators::math::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDyFunctor =
paddle::operators::math::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDxFunctor = pten::funcs::UnaryCompoundGradDxFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDyFunctor = pten::funcs::UnaryCompoundGradDyFunctor<
T, UnaryGradFunctor, BinaryFunctor, BinaryGradFunctor, InPlace>;
using UnaryCompoundDIntermediateFunctor =
paddle::operators::math::UnaryCompoundGradDIntermediateFunctor<
pten::funcs::UnaryCompoundGradDIntermediateFunctor<
T, UnaryGradFunctor, BinaryFunctor, InPlace>;
if (in_intermediate_out) {
......@@ -219,69 +217,60 @@ static void RunFunctors(const framework::ExecutionContext &ctx,
if (funcs_str == "elementwise_add,scale") {
// Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::ScaleFunctor<T>>(
ctx, paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
pten::funcs::ScaleFunctor<T>>(
ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ScaleFunctor<T>(scale),
in_x, in_y, outputs);
} else if (funcs_str == "scale,elementwise_add") {
// Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ScaleFunctor<T>,
pten::funcs::AddFunctor<T>>(
ctx, pten::funcs::ScaleFunctor<T>(scale), pten::funcs::AddFunctor<T>(),
in_x, in_y, outputs);
} else if (funcs_str == "elementwise_add,relu") {
// Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::ReluFunctor<T>>(
ctx, paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(), in_x, in_y, outputs);
RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::AddFunctor<T>,
pten::funcs::ReluFunctor<T>>(
ctx, pten::funcs::AddFunctor<T>(), pten::funcs::ReluFunctor<T>(), in_x,
in_y, outputs);
} else if (funcs_str == "relu,elementwise_add") {
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::ReluFunctor<T>,
pten::funcs::AddFunctor<T>>(
ctx, pten::funcs::ReluFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
in_y, outputs);
} else if (funcs_str == "elementwise_mul,scale") {
// Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::ScaleFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale), in_x, in_y, outputs);
RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
pten::funcs::ScaleFunctor<T>>(
ctx, pten::funcs::MultiplyFunctor<T>(),
pten::funcs::ScaleFunctor<T>(scale), in_x, in_y, outputs);
} else if (funcs_str == "tanh,elementwise_add") {
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::TanhFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::TanhFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::TanhFunctor<T>,
pten::funcs::AddFunctor<T>>(
ctx, pten::funcs::TanhFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
in_y, outputs);
} else if (funcs_str == "elementwise_mul,tanh") {
// Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::TanhFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::TanhFunctor<T>(), in_x, in_y, outputs);
RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
pten::funcs::TanhFunctor<T>>(
ctx, pten::funcs::MultiplyFunctor<T>(), pten::funcs::TanhFunctor<T>(),
in_x, in_y, outputs);
} else if (funcs_str == "elementwise_mul,sigmoid") {
// Z = Binary(X, Unary(Y))
RunBinaryCompoundFunctor<DeviceContext, T,
paddle::operators::math::MulFunctor<T>,
paddle::operators::math::SigmoidFunctor<T>>(
ctx, paddle::operators::math::MulFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(), in_x, in_y, outputs);
RunBinaryCompoundFunctor<DeviceContext, T, pten::funcs::MultiplyFunctor<T>,
pten::funcs::SigmoidFunctor<T>>(
ctx, pten::funcs::MultiplyFunctor<T>(),
pten::funcs::SigmoidFunctor<T>(), in_x, in_y, outputs);
} else if (funcs_str == "gelu,elementwise_add") {
// Z = Unary(Binary(X, Y))
RunUnaryCompoundFunctors<DeviceContext, T,
paddle::operators::math::GeluFunctor<T>,
paddle::operators::math::AddFunctor<T>>(
ctx, paddle::operators::math::GeluFunctor<T>(),
paddle::operators::math::AddFunctor<T>(), in_x, in_y, outputs);
RunUnaryCompoundFunctors<DeviceContext, T, pten::funcs::GeluFunctor<T>,
pten::funcs::AddFunctor<T>>(
ctx, pten::funcs::GeluFunctor<T>(), pten::funcs::AddFunctor<T>(), in_x,
in_y, outputs);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"%s has not been implemented.", funcs_str));
......@@ -301,95 +290,83 @@ static void RunGradFunctors(
if (funcs_str == "elementwise_add_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
RunBinaryCompoundGradFunctors<DeviceContext, T,
pten::funcs::AddGradFunctor<T>,
pten::funcs::ScaleFunctor<T>,
pten::funcs::ScaleGradFunctor<T>, InPlace>(
ctx, pten::funcs::AddGradFunctor<T>(),
pten::funcs::ScaleFunctor<T>(scale),
pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "scale_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ScaleGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::ScaleGradFunctor<T>(scale),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
DeviceContext, T, pten::funcs::ScaleGradFunctor<T>,
pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
ctx, pten::funcs::ScaleGradFunctor<T>(scale),
pten::funcs::AddFunctor<T>(), pten::funcs::AddGradFunctor<T>(), in_x,
in_y, in_out, in_intermediate_out, in_out_grad, x_grad, y_grad,
d_intermediate_out);
} else if (funcs_str == "elementwise_add_grad,relu_grad") {
// The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::AddGradFunctor<T>,
paddle::operators::math::ReluFunctor<T>,
paddle::operators::math::ReluGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::AddGradFunctor<T>(),
paddle::operators::math::ReluFunctor<T>(),
paddle::operators::math::ReluGradFunctor<T>(), in_x, in_y, in_out,
DeviceContext, T, pten::funcs::AddGradFunctor<T>,
pten::funcs::ReluFunctor<T>, pten::funcs::ReluGradFunctor<T>, InPlace>(
ctx, pten::funcs::AddGradFunctor<T>(), pten::funcs::ReluFunctor<T>(),
pten::funcs::ReluGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "relu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::ReluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::ReluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
DeviceContext, T, pten::funcs::ReluGradFunctor<T>,
pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
ctx, pten::funcs::ReluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,scale_grad") {
// The backward of Z = Binary(X, Unary(Y))
T scale = static_cast<T>(ctx.Attr<float>("scale"));
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::ScaleFunctor<T>,
paddle::operators::math::ScaleGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::ScaleFunctor<T>(scale),
paddle::operators::math::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
RunBinaryCompoundGradFunctors<DeviceContext, T,
pten::funcs::MulGradFunctor<T>,
pten::funcs::ScaleFunctor<T>,
pten::funcs::ScaleGradFunctor<T>, InPlace>(
ctx, pten::funcs::MulGradFunctor<T>(),
pten::funcs::ScaleFunctor<T>(scale),
pten::funcs::ScaleGradFunctor<T>(scale), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "tanh_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::TanhGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::TanhGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
DeviceContext, T, pten::funcs::TanhGradFunctor<T>,
pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
ctx, pten::funcs::TanhGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,tanh_grad") {
// The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::TanhFunctor<T>,
paddle::operators::math::TanhGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::TanhFunctor<T>(),
paddle::operators::math::TanhGradFunctor<T>(), in_x, in_y, in_out,
DeviceContext, T, pten::funcs::MulGradFunctor<T>,
pten::funcs::TanhFunctor<T>, pten::funcs::TanhGradFunctor<T>, InPlace>(
ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::TanhFunctor<T>(),
pten::funcs::TanhGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "elementwise_mul_grad,sigmoid_grad") {
// The backward of Z = Binary(X, Unary(Y))
RunBinaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::MulGradFunctor<T>,
paddle::operators::math::SigmoidFunctor<T>,
paddle::operators::math::SigmoidGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::MulGradFunctor<T>(),
paddle::operators::math::SigmoidFunctor<T>(),
paddle::operators::math::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
RunBinaryCompoundGradFunctors<DeviceContext, T,
pten::funcs::MulGradFunctor<T>,
pten::funcs::SigmoidFunctor<T>,
pten::funcs::SigmoidGradFunctor<T>, InPlace>(
ctx, pten::funcs::MulGradFunctor<T>(), pten::funcs::SigmoidFunctor<T>(),
pten::funcs::SigmoidGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else if (funcs_str == "gelu_grad,elementwise_add_grad") {
// The backward of Z = Unary(Binary(X, Y))
RunUnaryCompoundGradFunctors<
DeviceContext, T, paddle::operators::math::GeluGradFunctor<T>,
paddle::operators::math::AddFunctor<T>,
paddle::operators::math::AddGradFunctor<T>, InPlace>(
ctx, paddle::operators::math::GeluGradFunctor<T>(),
paddle::operators::math::AddFunctor<T>(),
paddle::operators::math::AddGradFunctor<T>(), in_x, in_y, in_out,
DeviceContext, T, pten::funcs::GeluGradFunctor<T>,
pten::funcs::AddFunctor<T>, pten::funcs::AddGradFunctor<T>, InPlace>(
ctx, pten::funcs::GeluGradFunctor<T>(), pten::funcs::AddFunctor<T>(),
pten::funcs::AddGradFunctor<T>(), in_x, in_y, in_out,
in_intermediate_out, in_out_grad, x_grad, y_grad, d_intermediate_out);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
......
......@@ -122,12 +122,12 @@ __global__ void FusedLayernormResidualDropoutBias(
__shared__ U shared_mean[32];
__shared__ U shared_var[32];
math::ReluFunctor<T> relu;
pten::funcs::ReluFunctor<T> relu;
U mean_val = 0;
U var_val = 0;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, true, false,
math::ReluFunctor<T>>(
pten::funcs::ReluFunctor<T>>(
row_id, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, &mean_val, &var_val, relu);
}
......
......@@ -115,12 +115,12 @@ __global__ void FusedResidualDropoutBias(
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
math::ReluFunctor<T> relu;
pten::funcs::ReluFunctor<T> relu;
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false, false,
math::ReluFunctor<T>>(
pten::funcs::ReluFunctor<T>>(
r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, nullptr, nullptr, relu);
}
......
......@@ -15,8 +15,9 @@
#include <limits>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/log_softmax_op.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/pten/kernels/funcs/elementwise_functor.h"
#include "paddle/pten/kernels/funcs/functors.h"
namespace paddle {
namespace operators {
......@@ -213,15 +214,15 @@ __global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
for (int d = threadIdx.x; d < dim_size; d += blockDim.x) {
const AccT value =
static_cast<AccT>(input[data_offset + d * dim_stride]);
max_value = math::MaxFunctor<AccT>()(max_value, value);
max_value = pten::funcs::MaxFunctor<AccT>()(max_value, value);
}
// If there are more than 1 threads along block x, reduce all max_values
// and get the global max_value, which is the max value along "axis".
// If there is only one thread along block x, no need to reduce, as the
// 'max_value' is the global max_value.
if (blockDim.x > 1) {
max_value =
BlockReduceAlongDimX<AccT, math::MaxFunctor>(sdata, max_value);
max_value = BlockReduceAlongDimX<AccT, pten::funcs::MaxFunctor>(
sdata, max_value);
}
// 2. reduce sum
......@@ -232,7 +233,7 @@ __global__ void LogSoftmaxForwardCUDAKernelNotLastAxis(
max_value);
}
if (blockDim.x > 1) {
sum = BlockReduceAlongDimX<AccT, math::AddFunctor>(sdata, sum);
sum = BlockReduceAlongDimX<AccT, pten::funcs::AddFunctor>(sdata, sum);
}
// 3. input-max-log_sum and write to output
......
......@@ -18,9 +18,8 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
namespace paddle {
namespace operators {
namespace math {
namespace pten {
namespace funcs {
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename BinaryFunctor, typename UnaryFunctor>
......@@ -69,8 +68,8 @@ struct BinaryCompoundGradDxFunctor {
return dout * d_binary_fun_.Dx(x, unary_fun_(y));
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
return dout * d_binary_fun_.Dx(x, intermediate_out);
}
......@@ -82,8 +81,11 @@ struct BinaryCompoundGradDxFunctor {
};
// Z = BinaryFunctor(X, UnaryFunctor(Y))
template <typename T, typename DBinaryFun, typename UnaryFun,
typename DUnaryFun, bool InPlace>
template <typename T,
typename DBinaryFun,
typename UnaryFun,
typename DUnaryFun,
bool InPlace>
struct BinaryCompoundGradDyFunctor {
BinaryCompoundGradDyFunctor(const DBinaryFun &d_binary_fun,
const UnaryFun &unary_fun,
......@@ -96,8 +98,8 @@ struct BinaryCompoundGradDyFunctor {
return dout * d_binary_fun_.Dy(x, unary_fun_(y)) * d_unary_fun_.UseX(y);
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
if (InPlace) {
return dout * d_binary_fun_.Dy(x, intermediate_out) *
d_unary_fun_.UseOut(intermediate_out);
......@@ -116,8 +118,11 @@ struct BinaryCompoundGradDyFunctor {
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool InPlace>
template <typename T,
typename DUnaryFun,
typename BinaryFun,
typename DBinaryFun,
bool InPlace>
struct UnaryCompoundGradDxFunctor {
UnaryCompoundGradDxFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
......@@ -136,8 +141,8 @@ struct UnaryCompoundGradDxFunctor {
return base * d_binary_fun_.Dx(x, y);
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
......@@ -156,8 +161,11 @@ struct UnaryCompoundGradDxFunctor {
};
// Z = UnaryFunctor(BinaryFunctor(X, Y))
template <typename T, typename DUnaryFun, typename BinaryFun,
typename DBinaryFun, bool InPlace>
template <typename T,
typename DUnaryFun,
typename BinaryFun,
typename DBinaryFun,
bool InPlace>
struct UnaryCompoundGradDyFunctor {
UnaryCompoundGradDyFunctor(const DUnaryFun &d_unary_fun,
const BinaryFun &binary_fun,
......@@ -176,8 +184,8 @@ struct UnaryCompoundGradDyFunctor {
return base * d_binary_fun_.Dy(x, y);
}
inline HOSTDEVICE T UseIntermediateOut(T x, T y, T intermediate_out, T out,
T dout) {
inline HOSTDEVICE T
UseIntermediateOut(T x, T y, T intermediate_out, T out, T dout) {
T base;
if (InPlace) {
base = dout * d_unary_fun_.UseOut(out);
......@@ -206,7 +214,9 @@ struct BinaryCompoundGradDIntermedaiteOutFunctor {
return dout * d_binary_fun_.Dy(x, unary_fun_(y));
}
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
inline HOSTDEVICE T UseIntermediateOut(T x,
T intermediate_out,
T out,
T dout) {
return dout * d_binary_fun_.Dy(x, intermediate_out);
}
......@@ -233,7 +243,9 @@ struct UnaryCompoundGradDIntermediateFunctor {
}
}
inline HOSTDEVICE T UseIntermediateOut(T x, T intermediate_out, T out,
inline HOSTDEVICE T UseIntermediateOut(T x,
T intermediate_out,
T out,
T dout) {
if (InPlace) {
return dout * d_unary_fun_.UseOut(out);
......@@ -249,6 +261,5 @@ struct UnaryCompoundGradDIntermediateFunctor {
BinaryFun binary_fun_;
};
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace pten
......@@ -17,16 +17,17 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math.h"
namespace paddle {
namespace operators {
namespace math {
// MulFunctor
template <typename T>
struct MulFunctor {
// out = x * y;
inline HOSTDEVICE T operator()(T x, T y) { return x * y; }
};
namespace pten {
namespace funcs {
// // MulFunctor
// // NOTE(chenfeiyu): IT IS NOLONGER USED, use pten::funcs::MultiplyFunctor
// instead
// template <typename T>
// struct MulFunctor {
// // out = x * y;
// inline HOSTDEVICE T operator()(T x, T y) { return x * y; }
// };
template <typename T>
struct MulGradFunctor {
......@@ -34,12 +35,13 @@ struct MulGradFunctor {
inline HOSTDEVICE T Dy(T x, T y) { return x; }
};
// AddFunctor
template <typename T>
struct AddFunctor {
// out = x + y;
inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
};
// // AddFunctor
// // NOTE(chenfeiyu): IT IS NOLONGER USED, use pten::funcs::AddFunctor instead
// template <typename T>
// struct AddFunctor {
// // out = x + y;
// inline HOSTDEVICE T operator()(T x, T y) { return x + y; }
// };
template <typename T>
struct MaxFunctor {
......@@ -102,7 +104,8 @@ struct TanhFunctor {
// y = 2 / (1 + e^-2x) - 1
T t0 = static_cast<T>(2) * x;
T t1 = (t0 < kMin) ? kMin : ((t0 > kMax) ? kMax : t0);
return static_cast<T>(2) / (static_cast<T>(1) + real_exp(-t1)) -
return static_cast<T>(2) /
(static_cast<T>(1) + paddle::operators::real_exp(-t1)) -
static_cast<T>(1);
}
};
......@@ -123,7 +126,8 @@ struct SigmoidFunctor {
inline HOSTDEVICE T operator()(T x) {
// y = 1 / (1 + e^-x)
T tmp = (x < kMin) ? kMin : ((x > kMax) ? kMax : x);
return static_cast<T>(1) / (static_cast<T>(1) + real_exp(-tmp));
return static_cast<T>(1) /
(static_cast<T>(1) + paddle::operators::real_exp(-tmp));
}
};
......@@ -138,7 +142,7 @@ struct SigmoidGradFunctor {
template <typename T>
struct GeluFunctor {
using MT = typename details::MPTypeTrait<T>::Type;
using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T operator()(T x) {
// this function is tanh approximation of gelu
// actual gelu is:
......@@ -154,7 +158,7 @@ struct GeluFunctor {
template <typename T>
struct GeluGradFunctor {
using MT = typename details::MPTypeTrait<T>::Type;
using MT = typename paddle::operators::details::MPTypeTrait<T>::Type;
inline HOSTDEVICE T UseX(T x) {
MT mx = static_cast<MT>(x);
MT tanh_out =
......@@ -193,6 +197,5 @@ struct GeluGradFunctor {
}
};
} // namespace math
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace pten
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册