Unverified commit 5df1296d authored by Chitsing KUI, committed by GitHub

fix fused_dropout_add bug (#52644)

Parent 61ca8b39
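The commit replaces the grad op's generic shape inference (GeneralBinaryGradInferMeta over [out_grad, out_grad]) with a dedicated FusedDropoutAddGradInferMeta bound to [seed_offset, out_grad], pins the grad kernel's data_type to out_grad, widens the kernel's element count to int64_t, and registers the op pair with the auto-parallel grad-shape pass. A minimal sketch of the code path this touches follows; it assumes a GPU build of Paddle that exposes paddle.incubate.nn.functional.fused_dropout_add(x, y, p=..., training=...) and is illustrative, not part of the commit.

# Illustrative sketch (not part of this commit). Assumes a GPU build of Paddle
# exposing paddle.incubate.nn.functional.fused_dropout_add.
import paddle
from paddle.incubate.nn.functional import fused_dropout_add

x = paddle.randn([4, 1024], dtype="float32")
y = paddle.randn([4, 1024], dtype="float32")
x.stop_gradient = False
y.stop_gradient = False

# out = dropout(x, p) + y, computed by the fused kernel
out = fused_dropout_add(x, y, p=0.5, training=True)
out.sum().backward()

# The backward pass is the path whose infermeta this commit fixes:
# x.grad and y.grad take their shape/dtype from the incoming out_grad.
print(x.grad.shape, y.grad.shape)  # both [4, 1024]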
@@ -9,8 +9,9 @@
   args : (Tensor seed_offset, Tensor out_grad, Scalar p, bool is_test, str mode, bool fix_seed)
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
-    func : GeneralBinaryGradInferMeta
-    param : [out_grad, out_grad]
+    func : FusedDropoutAddGradInferMeta
+    param : [seed_offset, out_grad]
   kernel :
     func : fused_dropout_add_grad
+    data_type : out_grad
   support_dygraph_mode : true
@@ -215,6 +215,19 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
   }
 }
 
+void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
+                                  const MetaTensor& out_grad,
+                                  MetaTensor* x_grad,
+                                  MetaTensor* y_grad) {
+  if (x_grad != nullptr) {
+    x_grad->share_meta(out_grad);
+  }
+  if (y_grad != nullptr) {
+    y_grad->share_meta(out_grad);
+  }
+}
+
 void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
                                           const MetaTensor& softmax,
                                           const MetaTensor& loss_grad,
......
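In plain terms, the new infermeta gives both gradients the metadata of the incoming gradient: dropout(x) + y is elementwise and shape-preserving, so x_grad and y_grad must match out_grad, while seed_offset only carries the RNG state for mask regeneration. A small Python model of that behavior, purely illustrative (the function name and dict-based "meta" are assumptions for this sketch, not Paddle internals):

# Illustrative model of FusedDropoutAddGradInferMeta; names and the dict-based
# "meta" are assumptions for this sketch, not Paddle code.
def fused_dropout_add_grad_infer_meta(seed_offset_meta, out_grad_meta):
    x_grad_meta = dict(out_grad_meta)  # share_meta(out_grad): same shape/dtype
    y_grad_meta = dict(out_grad_meta)  # share_meta(out_grad): same shape/dtype
    return x_grad_meta, y_grad_meta

x_meta, y_meta = fused_dropout_add_grad_infer_meta(
    {"shape": [2], "dtype": "int64"},          # seed_offset: not used for shapes
    {"shape": [4, 1024], "dtype": "float32"},  # out_grad
)
print(x_meta, y_meta)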
@@ -179,6 +179,11 @@ void FlashAttnGradInferMeta(const MetaTensor& q,
                             MetaTensor* dk,
                             MetaTensor* dv);
 
+void FusedDropoutAddGradInferMeta(const MetaTensor& seed_offset,
+                                  const MetaTensor& out_grad,
+                                  MetaTensor* x_grad,
+                                  MetaTensor* y_grad);
+
 void GatherNdGradInferMeta(const MetaTensor& x,
                            const MetaTensor& index,
                            const MetaTensor& out_grad,
......
@@ -20,7 +20,7 @@ namespace phi {
 namespace fusion {
 
 template <typename Context>
-static inline std::vector<size_t> GetRandomCudaProp(int numel,
+static inline std::vector<size_t> GetRandomCudaProp(int64_t numel,
                                                     const Context& dev_ctx) {
   constexpr int kVecSize = funcs::uniform_distribution<float>::kReturnsCount;
   auto gpu_config =
......
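The GetRandomCudaProp change widens the element count from int to int64_t. Presumably this guards against overflow when the fused op runs on tensors with more than 2^31 - 1 elements; a small stand-alone illustration of the magnitude involved (the shape is made up):

# Stand-alone illustration of why a 32-bit numel can overflow (shape is made up).
import numpy as np

shape = (8, 4096, 4096, 32)                  # hypothetical large activation
numel = int(np.prod(shape, dtype=np.int64))  # 4_294_967_296 elements
print(numel > np.iinfo(np.int32).max)        # True -> needs a 64-bit count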
@@ -1260,6 +1260,7 @@ def set_grad_var_shape(program, dist_context):
             "exp_grad",
             "sigmoid_grad",
             "unsqueeze2_grad",
+            "fused_dropout_add_grad",
         ]
         forward_list = [
             "reshape2",
@@ -1281,6 +1282,7 @@ def set_grad_var_shape(program, dist_context):
             "exp",
             "sigmoid",
             "unsqueeze2",
+            "fused_dropout_add",
         ]
         if op.type in need_set_shape_list:
             for forward_op in block.ops:
......
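These list entries register the fused op pair with the auto-parallel set_grad_var_shape pass, so the shape of fused_dropout_add_grad's output variables can be recovered from the matching fused_dropout_add forward op. The sketch below only illustrates the pairing idea (grad op name minus the _grad suffix); it is a simplification, not the actual pass logic:

# Simplified sketch of the grad/forward op pairing; not the real pass logic.
need_set_shape_list = ["unsqueeze2_grad", "fused_dropout_add_grad"]
forward_list = ["unsqueeze2", "fused_dropout_add"]

def matching_forward_op(grad_op_type):
    """Map a grad op type back to its forward op type, if registered."""
    if grad_op_type not in need_set_shape_list:
        return None
    fwd = grad_op_type[: -len("_grad")]
    return fwd if fwd in forward_list else None

print(matching_forward_op("fused_dropout_add_grad"))  # fused_dropout_add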