未验证 提交 b18cbfb2 编写于 作者: Z zhangkaihuo 提交者: GitHub

add op: fused_feedforward(forward) (#35843)

这个PR只包含fused_feedforward前向的代码。

相关kernel实现:fused_dropout_act_bias, fused_residual_dropout_bias, fused_layernorm_residual_dropout_bias

fused_feedforward是一个融合算子,该算子对transformer模型的feed forward层的算子进行融合和封装,使得前端只呈现一个接口,通过融合减少部分访存和kernel launch的时间,以此提升性能。
上级 087c3abe
......@@ -217,7 +217,8 @@ function(op_library TARGET)
"fusion_transpose_flatten_concat_op" "fusion_conv_inception_op"
"sync_batch_norm_op" "sparse_attention_op" "dgc_op" "fused_fc_elementwise_layernorm_op"
"skip_layernorm_op" "multihead_matmul_op" "fusion_group_op" "fused_bn_activation_op" "fused_embedding_eltwise_layernorm_op" "fusion_gru_op" "fusion_lstm_op"
"fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op")
"fused_bn_add_activation_op" "fused_attention_op" "resnet_unit_op" "fused_feedforward_op")
if ("${TARGET}" STREQUAL "${manual_pybind_op}")
set(pybind_flag 1)
endif()
......
......@@ -18,6 +18,7 @@ register_operators(EXCLUDES
fused_bn_add_activation_op
fused_attention_op
fused_transformer_op
fused_feedforward_op
resnet_unit_op)
# fusion_gru_op does not have CUDA kernel
......@@ -79,6 +80,11 @@ if (WITH_GPU OR WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
op_library(fused_feedforward_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_feedforward);\n")
# fused_attention_op
op_library(fused_attention_op)
file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_attention);\n")
......
......@@ -110,27 +110,34 @@ inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias,
}
__syncthreads();
// reduce sum
T sum = static_cast<T>(0);
T sum[2] = {static_cast<T>(0)};
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 5; // warp id
int y = tid & 31; // thread id on warp 0~31
// need BlockSizeX * VecSize warps
if (x < BlockSizeX * VecSize) {
for (int j = x; j < BlockSizeX * VecSize; j += 32) {
// reduce 128 to 32
#pragma unroll
for (int i = 0; i < (BlockSizeY >> 5); i++) {
sum += cache[x][y + i * 32];
sum[(j >> 5)] += cache[j][y + i * 32];
}
}
int reduce_num_pre_thread = (BlockSizeX * VecSize + 31) / 32;
// reduce 32 to 1
sum = WarpReduceSum(sum);
for (int i = 0; i < reduce_num_pre_thread; i++) {
sum[i] = WarpReduceSum(sum[i]);
}
// save sum to dbias
int bias_id = blockIdx.x * blockDim.x * VecSize + x;
if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) {
dbias[bias_id] = sum;
if (y == 0 && x < BlockSizeX * VecSize) {
for (int i = 0; i < reduce_num_pre_thread; i++) {
int bias_id = blockIdx.x * BlockSizeX * VecSize + x + i * 32;
if (bias_id < cols) {
dbias[bias_id] = sum[i];
}
}
}
}
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FusedFeedForwardOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "fused_feedforward");
OP_INOUT_CHECK(context->HasInput("Linear1Weight"), "Input", "Linear1Weight",
"fused_feedforward");
OP_INOUT_CHECK(context->HasInput("Linear2Weight"), "Input", "Linear2Weight",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout1Mask"), "Output", "Dropout1Mask",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout2Mask"), "Output", "Dropout2Mask",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Mean"), "Output", "Ln1Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Variance"), "Output", "Ln1Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Mean"), "Output", "Ln2Mean",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln2Variance"), "Output", "Ln2Variance",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Linear1Out"), "Output", "Linear1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Ln1Out"), "Output", "Ln1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout1Out"), "Output", "Dropout1Out",
"fused_feedforward");
OP_INOUT_CHECK(context->HasOutput("Dropout2Out"), "Output", "Dropout2Out",
"fused_feedforward");
auto dim_x = context->GetInputDim("X");
auto mat_dim_x =
math::CreateMatrixDescriptor(RowMatrixFromVector(dim_x), 0, false);
// verify for the pre layer_norm, the feature size must be larger than 1
PADDLE_ENFORCE_GT(
mat_dim_x.width_, static_cast<size_t>(1),
platform::errors::InvalidArgument("Product from the X shape[1] to "
"shape[n-1] must be larger than 1!"));
auto dim_Linear1Weight = context->GetInputDim("Linear1Weight");
auto tmp_dim_x = dim_x;
tmp_dim_x[dim_x.size() - 1] =
dim_Linear1Weight[dim_Linear1Weight.size() - 1];
context->SetOutputDim("Out", dim_x);
if (context->Attrs().Get<bool>("dropout1_is_test") == false) {
context->SetOutputDim("Dropout1Mask", tmp_dim_x);
}
context->SetOutputDim("Dropout1Out", tmp_dim_x);
context->SetOutputDim("Linear1Out", tmp_dim_x);
context->SetOutputDim("Ln1Out", dim_x);
context->SetOutputDim("Dropout2Out", dim_x);
if (context->Attrs().Get<bool>("dropout2_is_test") == false) {
context->SetOutputDim("Dropout2Mask", dim_x);
}
framework::DDim mean_dim =
framework::make_ddim({mat_dim_x.batch_size_ * mat_dim_x.height_});
context->SetOutputDim("Ln1Mean", mean_dim);
context->SetOutputDim("Ln1Variance", mean_dim);
context->SetOutputDim("Ln2Mean", mean_dim);
context->SetOutputDim("Ln2Variance", mean_dim);
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
class FusedFeedForwardOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "The input of FusedFeedForward op");
AddInput(
"Dropout1Seed",
"The seed of first dropout op, it has higher priority than the attr "
"fix_seed and seed")
.AsDispensable();
AddInput(
"Dropout2Seed",
"The seed of second dropout op, it has higher priority than the attr "
"fix_seed and seed")
.AsDispensable();
AddInput("Linear1Weight", "The linear1 weight of FusedFeedForward op");
AddInput("Linear1Bias", "The linear1 bias of FusedFeedForward op")
.AsDispensable();
AddInput("Linear2Weight", "The linear2 weight of FusedFeedForward op");
AddInput("Linear2Bias", "The linear2 bias input of FusedFeedForward op")
.AsDispensable();
AddInput("Ln1Scale", "The layer_norm1 scale of FusedFeedForward op")
.AsDispensable();
AddInput("Ln1Bias", "The layer_norm1 bias of FusedFeedForward op")
.AsDispensable();
AddInput("Ln2Scale", "The layer_norm2 scale of FusedFeedForward op")
.AsDispensable();
AddInput("Ln2Bias", "The layer_norm2 bias of FusedFeedForward op")
.AsDispensable();
AddOutput("Out", "The output of FusedFeedForward op");
AddOutput("Dropout1Mask", "The mask of dropout1").AsIntermediate();
AddOutput("Dropout2Mask", "The mask of dropout2").AsIntermediate();
AddOutput("Ln1Mean", "The mean of layer_norm1").AsIntermediate();
AddOutput("Ln1Variance", "The variance of layer_norm1").AsIntermediate();
AddOutput("Ln2Mean", "The mean of layer_nomr2").AsIntermediate();
AddOutput("Ln2Variance", "The variance of layer_norm2").AsIntermediate();
AddOutput("Linear1Out", "The output of linear1").AsIntermediate();
AddOutput("Ln1Out", "The output of layer_norm1").AsIntermediate();
AddOutput("Dropout1Out", "The output of dropout1").AsIntermediate();
AddOutput("Dropout2Out", "The output of dropout2").AsIntermediate();
AddAttr<bool>("pre_layer_norm", "true is pre layernorm").SetDefault(false);
AddAttr<float>("ln1_epsilon", "epsilon of pre layer_norm")
.SetDefault(1e-5f);
AddAttr<float>("ln2_epsilon", "epsilon of post layer_norm")
.SetDefault(1e-5f);
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
AddAttr<float>("dropout1_rate", "the dropout rate of first dropout")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(
drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'dropout1_rate' must be between 0.0 and 1.0."));
});
AddAttr<float>("dropout2_rate", "the dropout rate of second dropout")
.SetDefault(.5f)
.AddCustomChecker([](const float &drop_p) {
PADDLE_ENFORCE_EQ(
drop_p >= 0.0f && drop_p <= 1.0f, true,
platform::errors::InvalidArgument(
"'dropout2_rate' must be between 0.0 and 1.0."));
});
AddAttr<std::string>("dropout1_implementation",
"the dropout implementation of first dropout")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout1_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<std::string>("dropout2_implementation",
"the dropout implementation of second dropout")
.SetDefault("downgrade_in_infer")
.AddCustomChecker([](const std::string &type) {
PADDLE_ENFORCE_EQ(
type == "downgrade_in_infer" || type == "upscale_in_train", true,
platform::errors::InvalidArgument(
"dropout2_implementation can only be downgrade_in_infer or "
"upscale_in_train"));
});
AddAttr<bool>("dropout1_is_test", "the is_test of first dropout")
.SetDefault(false);
AddAttr<bool>("dropout2_is_test", "the is_test of second dropout")
.SetDefault(false);
AddAttr<bool>("dropout1_fix_seed", "the is_test of first dropout")
.SetDefault(false);
AddAttr<bool>("dropout2_fix_seed", "the is_test of second dropout")
.SetDefault(false);
AddAttr<int>("dropout1_seed", "Dropout1 random seed.").SetDefault(0);
AddAttr<int>("dropout2_seed", "Dropout2 random seed.").SetDefault(0);
AddComment(R"DOC(
the function of fused_feedforward operator is the same as the following pseudo code:
residual = src;
ln1_out = src;
if(pre_layer_norm){
ln1_out = layer_norm(src);
}
out = linear(dropout(activation(dropout(linear(ln1_out)))));
if(!pre_layer_norm) {
out = layer_norm(out);
}
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(fused_feedforward, ops::FusedFeedForwardOp,
ops::FusedFeedForwardOpMaker);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/matmul_v2_op.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class FusedFeedForwardKernel : public framework::OpKernel<T> {
public:
void MatMul(const platform::CUDADeviceContext& ctx,
const framework::Tensor& a, const framework::Tensor& b,
framework::Tensor* c) const {
auto blas = math::GetBlas<DeviceContext, T>(ctx);
auto a_2d = FoldInitDims(a);
auto b_2d = FoldInitDims(b);
auto mat_dim_a = math::CreateMatrixDescriptor(a_2d.dims(), 0, false);
auto mat_dim_b = math::CreateMatrixDescriptor(b_2d.dims(), 0, false);
T alpha = static_cast<T>(1.0);
blas.MatMul(a, mat_dim_a, b, mat_dim_b, alpha, c, T(0));
}
void FFN(const framework::Tensor& x, const framework::Tensor& linear1_weight,
const framework::Tensor* linear1_bias,
const framework::Tensor& linear2_weight,
const framework::Tensor* linear2_bias,
const framework::Tensor* ln1_scale,
const framework::Tensor* ln1_bias,
const framework::Tensor* ln2_scale,
const framework::Tensor* ln2_bias, framework::Tensor* out,
framework::Tensor* dropout1_mask, framework::Tensor* dropout2_mask,
framework::Tensor* ln1_mean, framework::Tensor* ln1_variance,
framework::Tensor* ln2_mean, framework::Tensor* ln2_variance,
framework::Tensor* linear1_out, framework::Tensor* ln1_out,
framework::Tensor* dropout1_out, framework::Tensor* dropout2_out,
const int bsz_seq, const int d_model, const int dim_feedforward,
const std::string& act_method, const bool pre_layer_norm,
const float epsilon1, const float epsilon2,
const DropoutParam& dropout_param1,
const DropoutParam& dropout_param2,
const platform::CUDADeviceContext& ctx) const {
FusedDropoutLayerNormHelper<T, uint8_t> pre_layernorm_helper(
bsz_seq, d_model, epsilon1);
FusedDropoutHelper<T, uint8_t> fused_act_dropout_helper(
ctx, bsz_seq, dim_feedforward, dropout_param1);
FusedDropoutLayerNormHelper<T, uint8_t> fused_dropout_layernorm_helper(
ctx, bsz_seq, d_model, dropout_param2, epsilon2);
auto place = ctx.GetPlace();
using U = LayerNormParamType<T>;
const framework::Tensor* in = &x;
const U* ln1_scale_ptr =
ln1_scale == nullptr ? nullptr : ln1_scale->data<U>();
const U* ln1_bias_ptr = ln1_bias == nullptr ? nullptr : ln1_bias->data<U>();
const U* ln2_scale_ptr =
ln2_scale == nullptr ? nullptr : ln2_scale->data<U>();
const U* ln2_bias_ptr = ln2_bias == nullptr ? nullptr : ln2_bias->data<U>();
const T* linear1_bias_ptr =
linear1_bias == nullptr ? nullptr : linear1_bias->data<T>();
const T* linear2_bias_ptr =
linear2_bias == nullptr ? nullptr : linear2_bias->data<T>();
if (pre_layer_norm) {
pre_layernorm_helper.LayerNorm(
ctx, x.data<T>(), ln1_scale_ptr, ln1_bias_ptr, ln1_out->data<T>(),
ln1_mean->data<U>(), ln1_variance->data<U>());
in = ln1_out;
}
MatMul(ctx, *in, linear1_weight, linear1_out);
fused_act_dropout_helper.DropoutActBias(
ctx, linear1_out->data<T>(), linear1_bias_ptr, act_method,
dropout1_out->data<T>(), dropout1_mask->data<uint8_t>());
framework::Tensor linear2_out;
linear2_out.mutable_data<T>({bsz_seq, d_model}, place);
MatMul(ctx, *dropout1_out, linear2_weight, &linear2_out);
if (!pre_layer_norm) {
fused_dropout_layernorm_helper.LayernormResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
ln2_scale_ptr, ln2_bias_ptr, dropout2_out->data<T>(),
dropout2_mask->data<uint8_t>(), out->data<T>(), ln2_mean->data<U>(),
ln2_variance->data<U>());
} else {
fused_dropout_layernorm_helper.ResidualDropoutBias(
ctx, linear2_out.data<T>(), x.data<T>(), linear2_bias_ptr,
out->data<T>(), dropout2_mask->data<uint8_t>());
}
}
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* linear1_weight = context.Input<framework::Tensor>("Linear1Weight");
auto* linear1_bias = context.Input<framework::Tensor>("Linear1Bias");
auto* linear2_weight = context.Input<framework::Tensor>("Linear2Weight");
auto* linear2_bias = context.Input<framework::Tensor>("Linear2Bias");
auto* ln1_scale = context.Input<framework::Tensor>("Ln1Scale");
auto* ln1_bias = context.Input<framework::Tensor>("Ln1Bias");
auto* ln2_scale = context.Input<framework::Tensor>("Ln2Scale");
auto* ln2_bias = context.Input<framework::Tensor>("Ln2Bias");
auto* ln1_mean = context.Output<framework::Tensor>("Ln1Mean");
auto* ln1_variance = context.Output<framework::Tensor>("Ln1Variance");
auto* ln2_mean = context.Output<framework::Tensor>("Ln2Mean");
auto* ln2_variance = context.Output<framework::Tensor>("Ln2Variance");
auto* out = context.Output<framework::Tensor>("Out");
auto* dropout1_mask = context.Output<framework::Tensor>("Dropout1Mask");
auto* dropout2_mask = context.Output<framework::Tensor>("Dropout2Mask");
auto* linear1_out = context.Output<framework::Tensor>("Linear1Out");
auto* ln1_out = context.Output<framework::Tensor>("Ln1Out");
auto* dropout1_out = context.Output<framework::Tensor>("Dropout1Out");
auto* dropout2_out = context.Output<framework::Tensor>("Dropout2Out");
const std::string act_method = context.Attr<std::string>("act_method");
const bool pre_layer_norm = context.Attr<bool>("pre_layer_norm");
const float epsilon1 = context.Attr<float>("ln1_epsilon");
const float epsilon2 = context.Attr<float>("ln2_epsilon");
DropoutParam dropout_param1(context, 1);
DropoutParam dropout_param2(context, 2);
using U = LayerNormParamType<T>;
auto place = context.GetPlace();
out->mutable_data<T>(place);
dropout1_mask->mutable_data<uint8_t>(place);
dropout2_mask->mutable_data<uint8_t>(place);
ln1_mean->mutable_data<U>(place);
ln1_variance->mutable_data<U>(place);
ln2_mean->mutable_data<U>(place);
ln2_variance->mutable_data<U>(place);
linear1_out->mutable_data<T>(place);
ln1_out->mutable_data<T>(place);
dropout1_out->mutable_data<T>(place);
dropout2_out->mutable_data<T>(place);
auto x_dim = x->dims();
auto mat_dim_x =
math::CreateMatrixDescriptor(RowMatrixFromVector(x_dim), 0, false);
auto dim = linear1_weight->dims();
int d_model = dim[0];
int dim_feedforward = dim[dim.size() - 1];
int bsz_seq = mat_dim_x.batch_size_ * mat_dim_x.height_;
FFN(*x, *linear1_weight, linear1_bias, *linear2_weight, linear2_bias,
ln1_scale, ln1_bias, ln2_scale, ln2_bias, out, dropout1_mask,
dropout2_mask, ln1_mean, ln1_variance, ln2_mean, ln2_variance,
linear1_out, ln1_out, dropout1_out, dropout2_out, bsz_seq, d_model,
dim_feedforward, act_method, pre_layer_norm, epsilon1, epsilon2,
dropout_param1, dropout_param2, context.cuda_device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
fused_feedforward,
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext, float>,
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext, double>,
ops::FusedFeedForwardKernel<paddle::platform::CUDADeviceContext,
paddle::platform::float16>);
......@@ -71,6 +71,9 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"sparse_momentum", {"Param", "Grad", "Velocity", "Index", "LearningRate"}},
{"rnn", {"Input", "PreState", "WeightList", "SequenceLength"}},
{"run_program", {"X", "Params"}},
{"fused_feedforward",
{"Dropout1Seed", "Dropout2Seed", "Linear1Bias", "Linear2Bias", "Ln1Scale",
"Ln1Bias", "Ln2Scale", "Ln2Bias"}},
{"faster_tokenizer", {"Text", "Vocab", "TextPair"}},
{"matrix_rank", {"X", "TolTensor"}},
{"adam",
......
......@@ -98,6 +98,8 @@ foreach(TEST_OP ${MIXED_DIST_TEST_OPS})
endforeach()
if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
endif()
......@@ -377,14 +379,14 @@ function(bash_test_modules TARGET_NAME)
if(WITH_COVERAGE)
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_START_BASH}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${bash_test_modules_ENVS}
bash ${CMAKE_CURRENT_BINARY_DIR}/${bash_test_modules_START_BASH}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......@@ -419,14 +421,14 @@ function(parallel_bash_test_modules TARGET_NAME)
if(WITH_COVERAGE)
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${parallel_bash_test_modules_ENVS} UnitTests=${uts_string}
WITH_COVERAGE=ON COVERAGE_FILE=${PADDLE_BINARY_DIR}/python-coverage.data
bash ${CMAKE_CURRENT_BINARY_DIR}/${parallel_bash_test_modules_START_BASH}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
else()
add_test(NAME ${TARGET_NAME}
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${PADDLE_BINARY_DIR}/python
TEST_TARGET_NAME=${TARGET_NAME} TEST_TIMEOUT=${timeout} ${parallel_bash_test_modules_ENVS} UnitTests=${uts_string}
bash ${CMAKE_CURRENT_BINARY_DIR}/${parallel_bash_test_modules_START_BASH}
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.nn.layer import transformer
import paddle.nn.functional as F
from paddle.nn.layer.norm import LayerNorm
from paddle.nn.layer.common import Linear, Dropout
import unittest
from op_test import OpTest
class TestFusedFFNOp(OpTest):
def getDtype(self):
self.dtype = "float32"
self.layer_norm_dtype = "float32"
def getShape(self):
self.batch_size = np.random.randint(1, 64)
self.query_length = np.random.randint(32, 256)
self.d_model = np.random.randint(32, 1024)
self.dim_feedforward = np.random.randint(32, 1024)
def getDiff(self):
self.rtol = 1e-3
self.atol = 1e-4
def getActivation(self):
self.act_method = "gelu"
def getNormalizeBefore(self):
self.pre_layer_norm = False
def setUp(self):
paddle.disable_static()
self.__class__.op_type = "fused_feedforward"
self.getDtype()
self.getShape()
self.getDiff()
self.getActivation()
self.getNormalizeBefore()
paddle.set_default_dtype(self.dtype)
self.weight_attr = None
self.bias_attr = None
self.weight_attrs = transformer._convert_param_attr_to_list(
self.weight_attr, 2)
self.bias_attrs = transformer._convert_param_attr_to_list(
self.bias_attr, 2)
self.linear1 = Linear(
self.d_model,
self.dim_feedforward,
self.weight_attrs[1],
bias_attr=self.bias_attrs[1])
self.linear2 = Linear(
self.dim_feedforward,
self.d_model,
self.weight_attrs[1],
bias_attr=self.bias_attrs[1])
paddle.set_default_dtype(self.layer_norm_dtype)
self.norm1 = LayerNorm(self.d_model)
self.norm2 = LayerNorm(self.d_model)
self.dropout = Dropout(0.0, mode="upscale_in_train")
self.dropout1 = Dropout(0.0, mode="upscale_in_train")
self.dropout2 = Dropout(0.0, mode="upscale_in_train")
self.activation = getattr(F, self.act_method)
self.src = np.random.random((self.batch_size, self.query_length,
self.d_model)).astype(self.dtype)
def Base(self):
paddle.disable_static()
tensor_src = paddle.to_tensor(self.src, stop_gradient=False)
residual = paddle.to_tensor(self.src)
if self.pre_layer_norm:
ln1_out = self.norm1(tensor_src)
linear2_out = self.linear2(
self.dropout(self.activation(self.linear1(ln1_out))))
dropout2_out = residual + self.dropout2(linear2_out)
else:
linear2_out = self.linear2(
self.dropout(self.activation(self.linear1(tensor_src))))
dropout2_out = residual + self.dropout2(linear2_out)
dropout2_out = self.norm2(dropout2_out)
return dropout2_out
def FusedFFN(self):
paddle.disable_static()
linear1_weight = paddle.to_tensor(
self.linear1.weight, stop_gradient=False)
linear1_bias = paddle.to_tensor(self.linear1.bias, stop_gradient=False)
linear2_weight = paddle.to_tensor(
self.linear2.weight, stop_gradient=False)
linear2_bias = paddle.to_tensor(self.linear2.bias, stop_gradient=False)
ln1_scale = paddle.to_tensor(self.norm1.weight, stop_gradient=False)
ln1_bias = paddle.to_tensor(self.norm1.bias, stop_gradient=False)
ln2_scale = paddle.to_tensor(self.norm2.weight, stop_gradient=False)
ln2_bias = paddle.to_tensor(self.norm2.bias, stop_gradient=False)
x = paddle.to_tensor(self.src, stop_gradient=False)
out = F.fused_feedforward(
x,
linear1_weight,
linear2_weight,
linear1_bias,
linear2_bias,
ln1_scale,
ln1_bias,
ln2_scale,
ln2_bias,
0.0,
0.0,
activation=self.act_method,
pre_layer_norm=self.pre_layer_norm)
return out
def test_fused_ffn(self):
base_out = self.Base()
fused_out = self.FusedFFN()
np.testing.assert_allclose(
base_out.numpy(), fused_out.numpy(), rtol=self.rtol, atol=self.atol)
class TestFusedFFNOpFp16(TestFusedFFNOp):
def getDtype(self):
self.dtype = "float16"
self.layer_norm_dtype = "float32"
def getDiff(self):
self.rtol = 1e-1
self.atol = 1e-2
def getShape(self):
self.batch_size = 8
self.query_length = 128
self.d_model = 512
self.dim_feedforward = 512
class TestFusedFFNOpFp64(TestFusedFFNOp):
def getDtype(self):
self.dtype = "float64"
self.layer_norm_dtype = "float64"
class TestFusedFFNOpActivation(TestFusedFFNOp):
def getActivation(self):
self.act_method = "relu"
class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp):
def getNormalizeBefore(self):
self.pre_layer_norm = True
def getShape(self):
self.batch_size = 1
self.query_length = 1
self.d_model = 8
self.dim_feedforward = 8
class APITestStaticFusedFFN(unittest.TestCase):
def test_static(self):
paddle.enable_static()
dtype = "float32"
layer_norm_dtype = "float32"
batch_size = 1
d_model = 8
dim_feedforward = 8
x = paddle.static.data(
name='x', shape=[batch_size, d_model, dim_feedforward], dtype=dtype)
linear1_weight = paddle.static.data(
name='linear1_weight',
shape=[d_model, dim_feedforward],
dtype=dtype)
linear1_bias = paddle.static.data(
name='linear1_bias', shape=[dim_feedforward])
linear2_weight = paddle.static.data(
name='linear2_weight',
shape=[dim_feedforward, d_model],
dtype=dtype)
linear2_bias = paddle.static.data(name='linear2_bias', shape=[d_model])
ln1_scale = paddle.static.data(name='ln1_scale', shape=[d_model])
ln1_bias = paddle.static.data(name='ln1_scale', shape=[d_model])
ln2_scale = paddle.static.data(name='ln2_scale', shape=[d_model])
ln2_bias = paddle.static.data(name='ln2_scale', shape=[d_model])
fused_out = F.fused_feedforward(
x,
linear1_weight,
linear2_weight,
linear1_bias,
linear2_bias,
ln1_scale,
ln1_bias,
ln2_scale,
ln2_bias,
0.0,
0.0,
activation="relu",
pre_layer_norm=False)
######base ffn######
linear1_out = F.linear(x, linear1_weight, linear1_bias)
act_out = F.relu(linear1_out)
dropout1_out = F.dropout(x=act_out, p=0.0, training=False)
linear2_out = F.linear(dropout1_out, linear2_weight, linear2_bias)
dropout2_out = x + F.dropout(x=linear2_out, p=0.0, training=False)
ln_out = F.layer_norm(
dropout2_out,
normalized_shape=list([d_model]),
weight=ln2_scale,
bias=ln2_bias)
######base ffn######
exe = paddle.static.Executor(paddle.CUDAPlace(0))
x_data = np.random.random(
(batch_size, d_model, dim_feedforward)).astype(dtype)
linear1_weight_data = np.random.random(
(d_model, dim_feedforward)).astype(dtype)
linear1_bias_data = np.zeros((dim_feedforward)).astype(dtype)
linear2_weight_data = np.random.random(
(dim_feedforward, d_model)).astype(dtype)
linear2_bias_data = np.zeros((d_model)).astype(dtype)
ln1_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
ln1_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
ln2_scale_data = np.ones((d_model)).astype(layer_norm_dtype)
ln2_bias_data = np.zeros((d_model)).astype(layer_norm_dtype)
res_list = [fused_out, ln_out]
real_res = []
for res in res_list:
fetch = exe.run(feed={
'x': x_data,
'linear1_weight': linear1_weight_data,
'linear1_bias': linear1_bias_data,
'linear2_weight': linear2_weight_data,
'linear2_bias': linear2_bias_data,
'ln1_scale': ln1_scale_data,
'ln1_bias': ln1_bias_data,
'ln2_scale': ln2_scale_data,
'ln2_bias': ln2_bias_data
},
fetch_list=[res])
real_res.append(fetch)
self.assertTrue(
np.allclose(
real_res[0], real_res[1], atol=1e-5),
"two value is check diff")
class TestFusedFFNOpError(unittest.TestCase):
def test_errors(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
def test_dtype():
x = paddle.static.data(
name='x', shape=[1, 10, 10], dtype="int32")
linear1_weight = paddle.static.data(
name='linear1_weight', shape=[1, 10, 10], dtype="float32")
linear2_weight = paddle.static.data(
name='linear2_weight', shape=[1, 10, 10], dtype="float32")
paddle.nn.functional.fused_feedforward(x, linear1_weight,
linear2_weight)
self.assertRaises(TypeError, test_dtype)
def test_dropout_rate_type():
x = paddle.static.data(
name='x1', shape=[1, 10, 10], dtype="float32")
linear1_weight = paddle.static.data(
name='linear1_weight1', shape=[10, 10], dtype="float32")
linear2_weight = paddle.static.data(
name='linear2_weight1', shape=[10, 10], dtype="float32")
paddle.nn.functional.fused_feedforward(
x, linear1_weight, linear2_weight, dropout1_rate="a")
self.assertRaises(TypeError, test_dropout_rate_type)
def test_dropout_rate_value():
x = paddle.static.data(
name='x2', shape=[1, 10, 10], dtype="float32")
linear1_weight = paddle.static.data(
name='linear1_weight2', shape=[10, 10], dtype="float32")
linear2_weight = paddle.static.data(
name='linear2_weight2', shape=[10, 10], dtype="float32")
paddle.nn.functional.fused_feedforward(
x, linear1_weight, linear2_weight, dropout2_rate=-1)
self.assertRaises(ValueError, test_dropout_rate_value)
if __name__ == "__main__":
unittest.main()
......@@ -111,6 +111,7 @@ from .vision import grid_sample # noqa: F401
from .vision import pixel_shuffle # noqa: F401
from .input import one_hot # noqa: F401
from .input import embedding # noqa: F401
from .fused_transformer import fused_feedforward # noqa: F401
from ...fluid.layers import gather_tree # noqa: F401
from ...fluid.layers import temporal_shift # noqa: F401
......@@ -212,6 +213,7 @@ __all__ = [ #noqa
'layer_norm',
'instance_norm',
'class_center_sample',
'fused_feedforward',
'fused_multi_head_attention',
'sparse_attention',
]
......@@ -12,13 +12,166 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
from ...fluid.data_feeder import check_variable_and_dtype, check_dtype
from paddle import _C_ops
__all__ = []
def _verify_dropout_rate(dropout_rate):
if not isinstance(dropout_rate, (float, int)):
raise TypeError("dropout_rate argument should be a number")
if dropout_rate < 0 or dropout_rate > 1:
raise ValueError("dropout_rate argument should between 0 and 1")
def fused_feedforward(x,
linear1_weight,
linear2_weight,
linear1_bias=None,
linear2_bias=None,
ln1_scale=None,
ln1_bias=None,
ln2_scale=None,
ln2_bias=None,
dropout1_rate=0.5,
dropout2_rate=0.5,
activation="relu",
ln1_epsilon=1e-5,
ln2_epsilon=1e-5,
pre_layer_norm=False,
name=None):
"""
This is a fusion operator to compute feed forward layer in transformer model architecture.
This operator only supports running on GPU. The function of the operator is consistent with
the following pseudo code:
.. code-block:: python
residual = src;
if pre_layer_norm:
src = layer_norm(src)
src = linear(dropout(activation(dropout(linear(src)))))
if not pre_layer_norm:
src = layer_norm(out)
Args:
x (Tensor): the input tensor could be 3-D tensor, the input data type could be float16, float32 or float64, the shape is`[batch\_size, sequence\_length, d_model]`.
linear1_weight (Tensor): The weight of first linear, the data type is same as `x`, the shape is `[d\_model, dim\_feedforward]`.
linear2_weight (Tensor): The weight of second linear, the data type is same as `x`, the shape is `[dim\_feedforward, d\_model]`.
linear1_bias (Tensor, optional): The bias of first linear, the data type is same as `x`, the shape is `[dim_feedforward]`. Default None.
linear2_bias (Tensor, optional): The bias of second linear, the data type is same as `x`, the shape is `[d_model]`. Default None.
ln1_scale (Tensor, optional): the weight of first layer_norm, the data type is float32 or float64, the shape is same as `x`. Default None.
ln1_bias (Tensor, optional): The bias of first layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
ln2_scale (Tensor, optional): The weight of second layer_norm, the data type is float32 or float64, the shape is same as `x`. Default None.
ln2_bias (Tensor, optional): The bias of second layer_norm, the data type is float32 or float64, the shape is `[d\_model]`. Default None.
dropout1_rate (float, optional): The first dropout probability of setting units to zero. Default 0.5.
dropout2_rate (float, optional): The second dropout probability of setting units to zero. Default 0.5.
activation (str, optional): The activation. Default "relu".
ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5.
pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
Tensor: The output Tensor, the data type and shape is same as `x`.
Examples:
.. code-block:: python
# required: gpu
import paddle
import numpy as np
x_data = np.random.random((1, 8, 8)).astype("float32")
linear1_weight_data = np.random.random((8, 8)).astype("float32")
linear2_weight_data = np.random.random((8, 8)).astype("float32")
x = paddle.to_tensor(x_data)
linear1_weight = paddle.to_tensor(linear1_weight_data)
linear2_weight = paddle.to_tensor(linear2_weight_data)
out = paddle.nn.functional.fused_feedforward(x, linear1_weight, linear2_weight)
print(out.numpy().shape)
# (1, 8, 8)
"""
_verify_dropout_rate(dropout1_rate)
_verify_dropout_rate(dropout2_rate)
if in_dygraph_mode():
out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward(
x, None, None, linear1_weight, linear1_bias, linear2_weight,
linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias,
'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon,
'ln2_epsilon', ln2_epsilon, 'act_method', activation,
'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate)
return out
helper = LayerHelper("fused_feedforward")
dtype = x.dtype
check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'],
'fused_feedforward')
check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
'fused_feedforward')
out = helper.create_variable_for_type_inference(x.dtype)
dropout1_mask = helper.create_variable_for_type_inference(
'uint8', stop_gradient=True)
dropout2_mask = helper.create_variable_for_type_inference(
'uint8', stop_gradient=True)
ln1_mean = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
ln1_variance = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
ln2_mean = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
ln2_variance = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
linear1_out = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
ln1_out = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
dropout1_out = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
dropout2_out = helper.create_variable_for_type_inference(
x.dtype, stop_gradient=True)
helper.append_op(
type='fused_feedforward',
inputs={
'X': x,
'Linear1Weight': linear1_weight,
'Linear1Bias': linear1_bias,
'Linear2Weight': linear2_weight,
'Linear2Bias': linear2_bias,
'Ln1Scale': ln1_scale,
'Ln1Bias': ln1_bias,
'Ln2Scale': ln2_scale,
'Ln2Bias': ln2_bias,
},
outputs={
'Out': out,
'Dropout1Mask': dropout1_mask,
'Dropout2Mask': dropout2_mask,
'Ln1Mean': ln1_mean,
'Ln1Variance': ln1_variance,
'Ln2Mean': ln2_mean,
'Ln2Variance': ln2_variance,
'Linear1Out': linear1_out,
'Ln1Out': ln1_out,
'Dropout1Out': dropout1_out,
'Dropout2Out': dropout2_out,
},
attrs={
'dropout1_rate': dropout1_rate,
'dropout2_rate': dropout2_rate,
'act_method': activation,
'pre_layer_norm': pre_layer_norm,
'ln1_epsilon': ln1_epsilon,
'ln2_epsilon': ln2_epsilon,
})
return out
def fused_multi_head_attention(x,
qkv_weight,
linear_weight,
......@@ -38,7 +191,7 @@ def fused_multi_head_attention(x,
"""
Attention mapps queries and a set of key-value pairs to outputs, and
Multi-Head Attention performs multiple parallel attention to jointly attending
to information from different representation subspaces. This API only
to information from different representation subspaces. This API only
support self_attention. The pseudo code is as follows:
if pre_layer_norm:
out = layer_norm(x);
......@@ -55,41 +208,41 @@ def fused_multi_head_attention(x,
out = softmax(out);
out = dropout(out);
out = out * v;
out = transpose(out, perm=[0, 2, 1, 3]);
out = transpose(out, perm=[0, 2, 1, 3]);
out = out_linear(out);
out = layer_norm(x + dropout(linear_bias + out));
Parameters:
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
x (Tensor): The input tensor of fused_multi_head_attention. The shape is
`[batch\_size, sequence\_len, embed\_dim]`.
qkv_weight (Tensor): The qkv weight tensor. The shape is `[3, num_head, dim_head, dim_embed]`.
linear_weight (Tensor): The linear weight tensor. The shape is `[embed_dim, embed_dim]`.
pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture.
pre_layer_norm (bool, optional): whether it is pre_layer_norm or post_layer_norm architecture.
Default False.
pre_ln_scale (Tensor, optional): The weight tensor of pre layernorm. Default None.
pre_ln_bias (Tensor, optional): The bias tensor of pre layernorm. Default None.
ln_scale (Tensor, optional): The weight tensor of layernorm. Default None.
ln_bias (Tensor, optional): The bias tensor of layernorm. Default None.
pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
pre_ln_epsilon (float, optional): Small float value added to denominator of the pre layer_norm
to avoid dividing by zero. Default is 1e-5.
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
qkv_bias (Tensor, optional): The bias of qkv computation. The shape is `[3, num_head, dim_head]`.
Default None.
linear_bias (Tensor, optional): The bias of linear. The shape is `[embed_dim]`. Default None.
attn_mask (Tensor, optional):
dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout after attention.
weights to drop some attention targets for the dropout after attention.
0 for no dropout. Default 0.
attn_dropout_rate (float, optional): The dropout probability used on attention
weights to drop some attention targets for the dropout in attention.
weights to drop some attention targets for the dropout in attention.
0 for no dropout. Default 0.
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
ln_epsilon (float, optional): Small float value added to denominator of layer_norm
to avoid dividing by zero. Default is 1e-5.
Examples:
.. code-block:: python
# required: gpu
# required: gpu
import paddle
import paddle.nn.functional as F
......@@ -115,8 +268,8 @@ def fused_multi_head_attention(x,
print(output.shape)
"""
if in_dygraph_mode():
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out,
# qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out,
# linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, attn_mask,
......
......@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO: define normalization api
# TODO: define normalization api
import paddle
import paddle.fluid as fluid
from ...fluid.data_feeder import check_variable_and_dtype, check_type
......@@ -35,7 +35,7 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
.. math::
y = \frac{x}{ \max\left( \lvert \lvert x \rvert \rvert_p, epsilon\right) }
.. math::
\lvert \lvert x \rvert \rvert_p = \left( \sum_i {\lvert x_i \rvert^p} \right)^{1/p}
......@@ -45,7 +45,7 @@ def normalize(x, p=2, axis=1, epsilon=1e-12, name=None):
Parameters:
x (Tensor): The input tensor could be N-D tensor, and the input data type could be float32 or float64.
p (float|int, optional): The exponent value in the norm formulation. Default: 2
axis (int, optional): The axis on which to apply normalization. If `axis < 0`, the dimension to normalization is `x.ndim + axis`. -1 is the last dimension.
axis (int, optional): The axis on which to apply normalization. If `axis < 0`, the dimension to normalization is `x.ndim + axis`. -1 is the last dimension.
epsilon (float, optional): Small float added to denominator to avoid dividing by zero. Default is 1e-12.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
......@@ -123,13 +123,13 @@ def batch_norm(x,
Applies Batch Normalization as described in the paper Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift .
nn.functional.batch_norm is uesd for nn.BatchNorm1D, nn.BatchNorm2D, nn.BatchNorm3D. Please use above API for BatchNorm.
Parameters:
x(Tesnor): input value. It's data type should be float32, float64.
running_mean(Tensor): running mean.
running_var(Tensor): running variance.
weight(Tensor): The weight tensor of batch_norm, can not be None.
bias(Tensor): The bias tensor of batch_norm can not be None.
bias(Tensor): The bias tensor of batch_norm can not be None.
epsilon(float, optional): The small value added to the variance to prevent division by zero. Default: 1e-5.
momentum(float, optional): The value used for the moving_mean and moving_var computation. Default: 0.9.
training(bool, optional): True means train mode which compute by batch data and track global mean and var during train period. False means inference mode which compute by global mean and var which calculated by train period. Defalut False.
......@@ -252,7 +252,7 @@ def layer_norm(x,
name=None):
"""
see more detail in paddle.nn.LayerNorm
Parameters:
x(Tensor): Input Tensor. It's data type should be float32, float64.
normalized_shape(int|list|tuple): Input shape from an expected input of
......@@ -277,7 +277,7 @@ def layer_norm(x,
np.random.seed(123)
x_data = np.random.random(size=(2, 2, 2, 3)).astype('float32')
x = paddle.to_tensor(x_data)
x = paddle.to_tensor(x_data)
layer_norm_out = paddle.nn.functional.layer_norm(x, x.shape[1:])
print(layer_norm_out)
"""
......@@ -378,7 +378,7 @@ def instance_norm(x,
np.random.seed(123)
x_data = np.random.random(size=(2, 2, 2, 3)).astype('float32')
x = paddle.to_tensor(x_data)
x = paddle.to_tensor(x_data)
instance_norm_out = paddle.nn.functional.instance_norm(x)
print(instance_norm_out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册