未验证 提交 13baef48 编写于 作者: Z ZhangDY-6483 提交者: GitHub

edit format of MEA (#52147)

上级 134c9c0c
...@@ -633,6 +633,91 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x, ...@@ -633,6 +633,91 @@ void MaxPoolWithIndexGradInferMeta(const MetaTensor& x,
dx->share_meta(x); dx->share_meta(x);
} }
// Infers meta (dims / dtype / layout) for the gradients of the
// memory-efficient attention backward op.
//
// query/key/value (and bias, when present) are 4-D tensors laid out as
// [batch, seq_len, num_head, head_size]; each grad output mirrors the
// dims, dtype, lod and layout of its corresponding forward input.
// cu_seqlens_q/cu_seqlens_k, logsumexp, seed_and_offset and the scalar
// attributes are part of the op signature but not needed for shape
// inference here.
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
                                           const MetaTensor& key,
                                           const MetaTensor& value,
                                           const MetaTensor& bias,
                                           const MetaTensor& cu_seqlens_q,
                                           const MetaTensor& cu_seqlens_k,
                                           const MetaTensor& output,
                                           const MetaTensor& logsumexp,
                                           const MetaTensor& seed_and_offset,
                                           const MetaTensor& output_grad,
                                           const Scalar& max_seqlen_q,
                                           const Scalar& max_seqlen_k,
                                           const bool causal,
                                           const double dropout_p,
                                           const float scale,
                                           MetaTensor* query_grad,
                                           MetaTensor* key_grad,
                                           MetaTensor* value_grad,
                                           MetaTensor* bias_grad) {
  // Fixed: messages previously said "Key" for both checks, and the second
  // check reported output_grad's rank while validating output's rank.
  PADDLE_ENFORCE_EQ(
      output_grad.dims().size(),
      4,
      phi::errors::InvalidArgument("OutputGrad should be a 4-D tensor. "
                                   "But received OutputGrad dimension(%s)",
                                   output_grad.dims().size()));
  PADDLE_ENFORCE_EQ(
      output.dims().size(),
      4,
      phi::errors::InvalidArgument("Output should be a 4-D tensor. "
                                   "But received Output dimension(%s)",
                                   output.dims().size()));
  // [batch, seq_len, num_head, head_size] for each of query/key/value.
  const int64_t query_batch_size = query.dims()[0];
  const int64_t query_seq_length = query.dims()[1];
  const int64_t query_num_head = query.dims()[2];
  const int64_t query_head_size = query.dims()[3];
  const int64_t key_batch_size = key.dims()[0];
  const int64_t key_seq_length = key.dims()[1];
  const int64_t key_num_head = key.dims()[2];
  const int64_t key_head_size = key.dims()[3];
  const int64_t value_batch_size = value.dims()[0];
  const int64_t value_seq_length = value.dims()[1];
  const int64_t value_num_head = value.dims()[2];
  const int64_t value_head_size = value.dims()[3];
  std::vector<int64_t> query_grad_dims(
      {query_batch_size, query_seq_length, query_num_head, query_head_size});
  std::vector<int64_t> key_grad_dims(
      {key_batch_size, key_seq_length, key_num_head, key_head_size});
  std::vector<int64_t> value_grad_dims(
      {value_batch_size, value_seq_length, value_num_head, value_head_size});
  // Grad outputs may be omitted by callers; guard against null pointers.
  if (query_grad) {
    query_grad->set_dims(phi::make_ddim(query_grad_dims));
    query_grad->share_lod(query);
    query_grad->set_dtype(query.dtype());
    query_grad->set_layout(query.layout());
  }
  if (key_grad) {
    key_grad->set_dims(phi::make_ddim(key_grad_dims));
    key_grad->share_lod(key);
    key_grad->set_dtype(key.dtype());
    key_grad->set_layout(key.layout());
  }
  if (value_grad) {
    value_grad->set_dims(phi::make_ddim(value_grad_dims));
    value_grad->share_lod(value);
    value_grad->set_dtype(value.dtype());
    value_grad->set_layout(value.layout());
  }
  // bias is optional: only propagate its grad meta when it is provided.
  if (bias && bias_grad) {
    const int64_t bias_batch_size = bias.dims()[0];
    const int64_t bias_seq_length = bias.dims()[1];
    const int64_t bias_num_head = bias.dims()[2];
    const int64_t bias_head_size = bias.dims()[3];
    std::vector<int64_t> bias_grad_dims(
        {bias_batch_size, bias_seq_length, bias_num_head, bias_head_size});
    bias_grad->set_dims(phi::make_ddim(bias_grad_dims));
    bias_grad->share_lod(bias);
    bias_grad->set_dtype(bias.dtype());
    bias_grad->set_layout(bias.layout());
  }
}
void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs, void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<const MetaTensor*>& outputs_grad, const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad) { std::vector<MetaTensor*> inputs_grad) {
...@@ -1052,89 +1137,4 @@ void IndexAddGradInferMeta(const MetaTensor& index, ...@@ -1052,89 +1137,4 @@ void IndexAddGradInferMeta(const MetaTensor& index,
} }
} }
// Infers meta (dims / dtype / layout) for the gradients of the
// memory-efficient attention backward op.
//
// query/key/value (and bias, when present) are 4-D tensors laid out as
// [batch, seq_len, num_head, head_size]; each grad output mirrors the
// dims, dtype, lod and layout of its corresponding forward input.
// cu_seqlens_q/cu_seqlens_k, logsumexp, seed_and_offset and the scalar
// attributes are part of the op signature but not needed for shape
// inference here.
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
                                           const MetaTensor& key,
                                           const MetaTensor& value,
                                           const MetaTensor& bias,
                                           const MetaTensor& cu_seqlens_q,
                                           const MetaTensor& cu_seqlens_k,
                                           const MetaTensor& output,
                                           const MetaTensor& logsumexp,
                                           const MetaTensor& seed_and_offset,
                                           const MetaTensor& output_grad,
                                           const Scalar& max_seqlen_q,
                                           const Scalar& max_seqlen_k,
                                           const bool causal,
                                           const double dropout_p,
                                           const float scale,
                                           MetaTensor* query_grad,
                                           MetaTensor* key_grad,
                                           MetaTensor* value_grad,
                                           MetaTensor* bias_grad) {
  // Fixed: messages previously said "Key" for both checks, and the second
  // check reported output_grad's rank while validating output's rank.
  PADDLE_ENFORCE_EQ(
      output_grad.dims().size(),
      4,
      phi::errors::InvalidArgument("OutputGrad should be a 4-D tensor. "
                                   "But received OutputGrad dimension(%s)",
                                   output_grad.dims().size()));
  PADDLE_ENFORCE_EQ(
      output.dims().size(),
      4,
      phi::errors::InvalidArgument("Output should be a 4-D tensor. "
                                   "But received Output dimension(%s)",
                                   output.dims().size()));
  // [batch, seq_len, num_head, head_size] for each of query/key/value.
  const int64_t query_batch_size = query.dims()[0];
  const int64_t query_seq_length = query.dims()[1];
  const int64_t query_num_head = query.dims()[2];
  const int64_t query_head_size = query.dims()[3];
  const int64_t key_batch_size = key.dims()[0];
  const int64_t key_seq_length = key.dims()[1];
  const int64_t key_num_head = key.dims()[2];
  const int64_t key_head_size = key.dims()[3];
  const int64_t value_batch_size = value.dims()[0];
  const int64_t value_seq_length = value.dims()[1];
  const int64_t value_num_head = value.dims()[2];
  const int64_t value_head_size = value.dims()[3];
  std::vector<int64_t> query_grad_dims(
      {query_batch_size, query_seq_length, query_num_head, query_head_size});
  std::vector<int64_t> key_grad_dims(
      {key_batch_size, key_seq_length, key_num_head, key_head_size});
  std::vector<int64_t> value_grad_dims(
      {value_batch_size, value_seq_length, value_num_head, value_head_size});
  // Grad outputs may be omitted by callers; guard against null pointers.
  if (query_grad) {
    query_grad->set_dims(phi::make_ddim(query_grad_dims));
    query_grad->share_lod(query);
    query_grad->set_dtype(query.dtype());
    query_grad->set_layout(query.layout());
  }
  if (key_grad) {
    key_grad->set_dims(phi::make_ddim(key_grad_dims));
    key_grad->share_lod(key);
    key_grad->set_dtype(key.dtype());
    key_grad->set_layout(key.layout());
  }
  if (value_grad) {
    value_grad->set_dims(phi::make_ddim(value_grad_dims));
    value_grad->share_lod(value);
    value_grad->set_dtype(value.dtype());
    value_grad->set_layout(value.layout());
  }
  // bias is optional: only propagate its grad meta when it is provided.
  if (bias && bias_grad) {
    const int64_t bias_batch_size = bias.dims()[0];
    const int64_t bias_seq_length = bias.dims()[1];
    const int64_t bias_num_head = bias.dims()[2];
    const int64_t bias_head_size = bias.dims()[3];
    std::vector<int64_t> bias_grad_dims(
        {bias_batch_size, bias_seq_length, bias_num_head, bias_head_size});
    bias_grad->set_dims(phi::make_ddim(bias_grad_dims));
    bias_grad->share_lod(bias);
    bias_grad->set_dtype(bias.dtype());
    bias_grad->set_layout(bias.layout());
  }
}
} // namespace phi } // namespace phi
...@@ -294,6 +294,26 @@ void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs, ...@@ -294,6 +294,26 @@ void MeshgridGradInferMeta(const std::vector<const MetaTensor*>& inputs,
const std::vector<const MetaTensor*>& outputs_grad, const std::vector<const MetaTensor*>& outputs_grad,
std::vector<MetaTensor*> inputs_grad); std::vector<MetaTensor*> inputs_grad);
// Shape/dtype inference for the memory-efficient attention backward op.
// query/key/value (and optionally bias) are 4-D inputs; each *_grad output
// receives the dims/dtype/layout of its corresponding forward input.
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
                                           const MetaTensor& key,
                                           const MetaTensor& value,
                                           const MetaTensor& bias,
                                           const MetaTensor& cu_seqlens_q,
                                           const MetaTensor& cu_seqlens_k,
                                           const MetaTensor& output,
                                           const MetaTensor& logsumexp,
                                           const MetaTensor& seed_and_offset,
                                           const MetaTensor& output_grad,
                                           const Scalar& max_seqlen_q,
                                           const Scalar& max_seqlen_k,
                                           const bool causal,
                                           const double dropout_p,
                                           const float scale,
                                           MetaTensor* query_grad,
                                           MetaTensor* key_grad,
                                           MetaTensor* value_grad,
                                           MetaTensor* bias_grad);
void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x, void MultiDotGradInferMeta(const std::vector<const MetaTensor*>& x,
const MetaTensor& out_grad, const MetaTensor& out_grad,
std::vector<MetaTensor*> x_grad); std::vector<MetaTensor*> x_grad);
...@@ -418,24 +438,4 @@ void IndexAddGradInferMeta(const MetaTensor& index, ...@@ -418,24 +438,4 @@ void IndexAddGradInferMeta(const MetaTensor& index,
MetaTensor* x_grad, MetaTensor* x_grad,
MetaTensor* add_tensor_grad); MetaTensor* add_tensor_grad);
// Shape/dtype inference for the memory-efficient attention backward op.
// query/key/value (and optionally bias) are 4-D inputs; each *_grad output
// receives the dims/dtype/layout of its corresponding forward input.
void MemoryEfficientAttentionGradInferMeta(const MetaTensor& query,
                                           const MetaTensor& key,
                                           const MetaTensor& value,
                                           const MetaTensor& bias,
                                           const MetaTensor& cu_seqlens_q,
                                           const MetaTensor& cu_seqlens_k,
                                           const MetaTensor& output,
                                           const MetaTensor& logsumexp,
                                           const MetaTensor& seed_and_offset,
                                           const MetaTensor& output_grad,
                                           const Scalar& max_seqlen_q,
                                           const Scalar& max_seqlen_k,
                                           const bool causal,
                                           const double dropout_p,
                                           const float scale,
                                           MetaTensor* query_grad,
                                           MetaTensor* key_grad,
                                           MetaTensor* value_grad,
                                           MetaTensor* bias_grad);
} // namespace phi } // namespace phi
...@@ -2112,6 +2112,95 @@ void MergedMomentumInferMeta( ...@@ -2112,6 +2112,95 @@ void MergedMomentumInferMeta(
std::vector<MetaTensor*> velocity_out, std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out) {} std::vector<MetaTensor*> master_param_out) {}
// Infers meta (dims / dtype / layout) for the memory-efficient attention
// forward op.
//
// query/key/value are 4-D tensors laid out as
// [batch, seq_len, num_head, head_size]. Validates rank and cross-tensor
// consistency (batch, head count, Q/K head size, K/V seq length), then
// sets:
//   output          -> [q_batch, q_seq_len, q_num_head, v_head_size],
//                      same dtype/lod/layout as query;
//   logsumexp       -> [q_num_head, q_batch], FLOAT32;
//   seed_and_offset -> [2], INT64 (dropout RNG state).
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
                                       const MetaTensor& key,
                                       const MetaTensor& value,
                                       const MetaTensor& bias,
                                       const MetaTensor& cu_seqlens_q,
                                       const MetaTensor& cu_seqlens_k,
                                       const MetaTensor& causal_diagonal,
                                       const MetaTensor& seqlen_k,
                                       const Scalar& max_seqlen_q,
                                       const Scalar& max_seqlen_k,
                                       const bool causal,
                                       const double dropout_p,
                                       const float scale,
                                       const bool is_test,
                                       MetaTensor* output,
                                       MetaTensor* logsumexp,
                                       MetaTensor* seed_and_offset) {
  // Fixed: adjacent string literals previously concatenated without any
  // separator ("...tensorBut received...").
  PADDLE_ENFORCE_EQ(
      query.dims().size(),
      4,
      phi::errors::InvalidArgument("Query should be a 4-D tensor. "
                                   "But received Query dimension(%s)",
                                   query.dims().size()));
  PADDLE_ENFORCE_EQ(
      key.dims().size(),
      4,
      phi::errors::InvalidArgument("Key should be a 4-D tensor. "
                                   "But received Key dimension(%s)",
                                   key.dims().size()));
  PADDLE_ENFORCE_EQ(
      value.dims().size(),
      4,
      phi::errors::InvalidArgument("Value should be a 4-D tensor. "
                                   "But received Value dimension(%s)",
                                   value.dims().size()));
  // [batch, seq_len, num_head, head_size] for each of query/key/value.
  const int64_t query_batch_size = query.dims()[0];
  const int64_t query_seq_length = query.dims()[1];
  const int64_t query_num_head = query.dims()[2];
  const int64_t query_head_size = query.dims()[3];
  const int64_t key_batch_size = key.dims()[0];
  const int64_t key_seq_length = key.dims()[1];
  const int64_t key_num_head = key.dims()[2];
  const int64_t key_head_size = key.dims()[3];
  const int64_t value_batch_size = value.dims()[0];
  const int64_t value_seq_length = value.dims()[1];
  const int64_t value_num_head = value.dims()[2];
  const int64_t value_head_size = value.dims()[3];
  PADDLE_ENFORCE_EQ(((query_batch_size == key_batch_size) &&
                     (key_batch_size == value_batch_size)),
                    true,
                    phi::errors::InvalidArgument(
                        "The batchsize of Query, Key, Value should be equal."));
  PADDLE_ENFORCE_EQ(
      ((query_num_head == key_num_head) && (key_num_head == value_num_head)),
      true,
      phi::errors::InvalidArgument(
          "The head number of Query, Key, Value should be equal."));
  // Q and K multiply along head_size; K and V share a sequence axis.
  PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
                    true,
                    phi::errors::InvalidArgument(
                        "The head size of Query, Key should be equal."));
  PADDLE_ENFORCE_EQ(key_seq_length == value_seq_length,
                    true,
                    phi::errors::InvalidArgument(
                        "The seq length of Key, Value should be equal."));
  std::vector<int64_t> out_dims(
      {query_batch_size, query_seq_length, query_num_head, value_head_size});
  std::vector<int64_t> logsumexp_dims({query_num_head, query_batch_size});
  std::vector<int64_t> seed_and_offset_dims({2});
  output->set_dims(phi::make_ddim(out_dims));
  output->share_lod(query);
  output->set_dtype(query.dtype());
  output->set_layout(query.layout());
  logsumexp->set_dims(phi::make_ddim(logsumexp_dims));
  logsumexp->set_dtype(phi::DataType::FLOAT32);
  seed_and_offset->set_dims(phi::make_ddim(seed_and_offset_dims));
  seed_and_offset->set_dtype(phi::DataType::INT64);
}
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs, void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs) { std::vector<MetaTensor*> outputs) {
const size_t inputs_num = inputs.size(); const size_t inputs_num = inputs.size();
...@@ -3129,94 +3218,5 @@ void MoeInferMeta(const MetaTensor& x, ...@@ -3129,94 +3218,5 @@ void MoeInferMeta(const MetaTensor& x,
out->set_layout(x.layout()); out->set_layout(x.layout());
} }
// Infers meta (dims / dtype / layout) for the memory-efficient attention
// forward op.
//
// query/key/value are 4-D tensors laid out as
// [batch, seq_len, num_head, head_size]. Validates rank and cross-tensor
// consistency (batch, head count, Q/K head size, K/V seq length), then
// sets:
//   output          -> [q_batch, q_seq_len, q_num_head, v_head_size],
//                      same dtype/lod/layout as query;
//   logsumexp       -> [q_num_head, q_batch], FLOAT32;
//   seed_and_offset -> [2], INT64 (dropout RNG state).
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
                                       const MetaTensor& key,
                                       const MetaTensor& value,
                                       const MetaTensor& bias,
                                       const MetaTensor& cu_seqlens_q,
                                       const MetaTensor& cu_seqlens_k,
                                       const MetaTensor& causal_diagonal,
                                       const MetaTensor& seqlen_k,
                                       const Scalar& max_seqlen_q,
                                       const Scalar& max_seqlen_k,
                                       const bool causal,
                                       const double dropout_p,
                                       const float scale,
                                       const bool is_test,
                                       MetaTensor* output,
                                       MetaTensor* logsumexp,
                                       MetaTensor* seed_and_offset) {
  // Fixed: adjacent string literals previously concatenated without any
  // separator ("...tensorBut received...").
  PADDLE_ENFORCE_EQ(
      query.dims().size(),
      4,
      phi::errors::InvalidArgument("Query should be a 4-D tensor. "
                                   "But received Query dimension(%s)",
                                   query.dims().size()));
  PADDLE_ENFORCE_EQ(
      key.dims().size(),
      4,
      phi::errors::InvalidArgument("Key should be a 4-D tensor. "
                                   "But received Key dimension(%s)",
                                   key.dims().size()));
  PADDLE_ENFORCE_EQ(
      value.dims().size(),
      4,
      phi::errors::InvalidArgument("Value should be a 4-D tensor. "
                                   "But received Value dimension(%s)",
                                   value.dims().size()));
  // [batch, seq_len, num_head, head_size] for each of query/key/value.
  const int64_t query_batch_size = query.dims()[0];
  const int64_t query_seq_length = query.dims()[1];
  const int64_t query_num_head = query.dims()[2];
  const int64_t query_head_size = query.dims()[3];
  const int64_t key_batch_size = key.dims()[0];
  const int64_t key_seq_length = key.dims()[1];
  const int64_t key_num_head = key.dims()[2];
  const int64_t key_head_size = key.dims()[3];
  const int64_t value_batch_size = value.dims()[0];
  const int64_t value_seq_length = value.dims()[1];
  const int64_t value_num_head = value.dims()[2];
  const int64_t value_head_size = value.dims()[3];
  PADDLE_ENFORCE_EQ(((query_batch_size == key_batch_size) &&
                     (key_batch_size == value_batch_size)),
                    true,
                    phi::errors::InvalidArgument(
                        "The batchsize of Query, Key, Value should be equal."));
  PADDLE_ENFORCE_EQ(
      ((query_num_head == key_num_head) && (key_num_head == value_num_head)),
      true,
      phi::errors::InvalidArgument(
          "The head number of Query, Key, Value should be equal."));
  // Q and K multiply along head_size; K and V share a sequence axis.
  PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
                    true,
                    phi::errors::InvalidArgument(
                        "The head size of Query, Key should be equal."));
  PADDLE_ENFORCE_EQ(key_seq_length == value_seq_length,
                    true,
                    phi::errors::InvalidArgument(
                        "The seq length of Key, Value should be equal."));
  std::vector<int64_t> out_dims(
      {query_batch_size, query_seq_length, query_num_head, value_head_size});
  std::vector<int64_t> logsumexp_dims({query_num_head, query_batch_size});
  std::vector<int64_t> seed_and_offset_dims({2});
  output->set_dims(phi::make_ddim(out_dims));
  output->share_lod(query);
  output->set_dtype(query.dtype());
  output->set_layout(query.layout());
  logsumexp->set_dims(phi::make_ddim(logsumexp_dims));
  logsumexp->set_dtype(phi::DataType::FLOAT32);
  seed_and_offset->set_dims(phi::make_ddim(seed_and_offset_dims));
  seed_and_offset->set_dtype(phi::DataType::INT64);
}
} // namespace phi } // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta); PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
...@@ -398,6 +398,24 @@ void MergedMomentumInferMeta( ...@@ -398,6 +398,24 @@ void MergedMomentumInferMeta(
std::vector<MetaTensor*> velocity_out, std::vector<MetaTensor*> velocity_out,
std::vector<MetaTensor*> master_param_out); std::vector<MetaTensor*> master_param_out);
// Shape/dtype inference for the memory-efficient attention forward op.
// query/key/value are 4-D [batch, seq_len, num_head, head_size] tensors;
// produces output, logsumexp (FLOAT32) and seed_and_offset (INT64) metas.
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
                                       const MetaTensor& key,
                                       const MetaTensor& value,
                                       const MetaTensor& bias,
                                       const MetaTensor& cu_seqlens_q,
                                       const MetaTensor& cu_seqlens_k,
                                       const MetaTensor& causal_diagonal,
                                       const MetaTensor& seqlen_k,
                                       const Scalar& max_seqlen_q,
                                       const Scalar& max_seqlen_k,
                                       const bool causal,
                                       const double dropout_p,
                                       const float scale,
                                       const bool is_test,
                                       MetaTensor* output,
                                       MetaTensor* logsumexp,
                                       MetaTensor* seed_and_offset);
void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs, void MeshgridInferMeta(const std::vector<const MetaTensor*>& inputs,
std::vector<MetaTensor*> outputs); std::vector<MetaTensor*> outputs);
...@@ -587,22 +605,4 @@ void MoeInferMeta(const MetaTensor& x, ...@@ -587,22 +605,4 @@ void MoeInferMeta(const MetaTensor& x,
const std::string& act_type, const std::string& act_type,
MetaTensor* out); MetaTensor* out);
// Shape/dtype inference for the memory-efficient attention forward op.
// query/key/value are 4-D [batch, seq_len, num_head, head_size] tensors;
// produces output, logsumexp (FLOAT32) and seed_and_offset (INT64) metas.
void MemoryEfficientAttentionInferMeta(const MetaTensor& query,
                                       const MetaTensor& key,
                                       const MetaTensor& value,
                                       const MetaTensor& bias,
                                       const MetaTensor& cu_seqlens_q,
                                       const MetaTensor& cu_seqlens_k,
                                       const MetaTensor& causal_diagonal,
                                       const MetaTensor& seqlen_k,
                                       const Scalar& max_seqlen_q,
                                       const Scalar& max_seqlen_k,
                                       const bool causal,
                                       const double dropout_p,
                                       const float scale,
                                       const bool is_test,
                                       MetaTensor* output,
                                       MetaTensor* logsumexp,
                                       MetaTensor* seed_and_offset);
} // namespace phi } // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册