未验证 提交 d73eb38c 编写于 作者: A Aganlengzi 提交者: GitHub

[NPU] support model PPO (#42484)

上级 1588e7e7
......@@ -118,20 +118,45 @@ class ElementwiseDivGradNPUKernel : public framework::OpKernel<T> {
if (dy) {
dy->mutable_data<T>(place);
Tensor neg_out(y->type());
neg_out.mutable_data<T>(y->dims(), place);
Tensor neg_out(out->type());
neg_out.mutable_data<T>(out->dims(), place);
const auto& runner_neg_out = NpuOpRunner("Neg", {*out}, {neg_out}, {});
runner_neg_out.Run(stream);
Tensor y_grad_w(y->type());
y_grad_w.mutable_data<T>(y->dims(), place);
const auto& runner_y_grad_w =
NpuOpRunner("Div", {neg_out, *y}, {y_grad_w}, {});
runner_y_grad_w.Run(stream);
Tensor tmp_mul(out->type());
tmp_mul.mutable_data<T>(out->dims(), place);
const auto& runner_mul =
NpuOpRunner("Mul", {neg_out, *dout}, {tmp_mul}, {});
runner_mul.Run(stream);
if (dy->dims() != dout->dims()) {
Tensor reduced_tmp_mul(y->type());
reduced_tmp_mul.mutable_data<T>(y->dims(), place);
std::vector<int64_t> axes;
int64_t diff = dout->dims().size() - dy->dims().size();
for (int64_t i = 0; i < dout->dims().size(); ++i) {
if (i < diff) {
axes.push_back(i);
continue;
}
if (dout->dims()[i] > dy->dims()[i - diff]) {
axes.push_back(i);
}
}
const auto& runner_reduce =
NpuOpRunner("ReduceSumD", {tmp_mul}, {reduced_tmp_mul},
{{"axes", axes}, {"keep_dims", false}});
runner_reduce.Run(stream);
const auto& runner_y_grad =
NpuOpRunner("Mul", {y_grad_w, *dout}, {*dy}, {});
NpuOpRunner("Div", {reduced_tmp_mul, *y}, {*dy}, {});
runner_y_grad.Run(stream);
} else {
const auto& runner_y_grad =
NpuOpRunner("Div", {tmp_mul, *y}, {*dy}, {});
runner_y_grad.Run(stream);
}
}
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册