未验证 提交 5df1296d 编写于 作者: C Chitsing KUI 提交者: GitHub

fix fused_dropout_add bug (#52644)

上级 61ca8b39
......@@ -9,8 +9,9 @@
args : (Tensor seed_offset, Tensor out_grad, Scalar p, bool is_test, str mode, bool fix_seed)
output : Tensor(x_grad), Tensor(y_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
param : [out_grad, out_grad]
func : FusedDropoutAddGradInferMeta
param : [seed_offset, out_grad]
kernel :
func : fused_dropout_add_grad
data_type : out_grad
support_dygraph_mode : true
......@@ -215,6 +215,19 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
}
}
void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
MetaTensor* y_grad) {
if (x_grad != nullptr) {
x_grad->share_meta(out_grad);
}
if (y_grad != nullptr) {
y_grad->share_meta(out_grad);
}
}
void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
const MetaTensor& softmax,
const MetaTensor& loss_grad,
......
......@@ -179,6 +179,11 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
MetaTensor* dk,
MetaTensor* dv);
void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
const MetaTensor& out_grad,
MetaTensor* x_grad,
MetaTensor* y_grad);
void GatherNdGradInferMeta(const MetaTensor& x,
const MetaTensor& index,
const MetaTensor& out_grad,
......
......@@ -20,7 +20,7 @@ namespace phi {
namespace fusion {
template <typename Context>
static inline std::vector<size_t> GetRandomCudaProp(int numel,
static inline std::vector<size_t> GetRandomCudaProp(int64_t numel,
const Context& dev_ctx) {
constexpr int kVecSize = funcs::uniform_distribution<float>::kReturnsCount;
auto gpu_config =
......
......@@ -1260,6 +1260,7 @@ def set_grad_var_shape(program, dist_context):
"exp_grad",
"sigmoid_grad",
"unsqueeze2_grad",
"fused_dropout_add_grad",
]
forward_list = [
"reshape2",
......@@ -1281,6 +1282,7 @@ def set_grad_var_shape(program, dist_context):
"exp",
"sigmoid",
"unsqueeze2",
"fused_dropout_add",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册