未验证 提交 83f81eb5 编写于 作者: L Leo Chen 提交者: GitHub

Fix pow, refine code (#31440)

上级 5fe3d596
......@@ -73,24 +73,19 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
runner_pow.Run(stream);
// Step 2: Construct a broadcast factor, which has the same shape with x.
// 2.1 Get the shape of x
Tensor x_shape(framework::proto::VarType::INT32);
x_shape.mutable_data<int32_t>({x_dims.size()}, place);
TensorFromVector(framework::vectorize<int32_t>(x_dims),
ctx.device_context(), &x_shape);
// 2.2 Get a factor tensor with shape [1].
// 2.1 Get a factor tensor with shape [1].
Tensor factor_tensor(framework::proto::VarType::FP32);
factor_tensor.mutable_data<float>({1}, place);
TensorFromVector(std::vector<float>{factor}, ctx.device_context(),
&factor_tensor);
// 2.3 Get the factor which has the shape with x and the same value with
// 2.2 Get the factor which has the shape with x and the same value with
// factor.
Tensor factor_bc_tensor(framework::proto::VarType::FP32);
factor_bc_tensor.mutable_data<float>(x_dims, place);
auto runner_bc = NpuOpRunner("FillD", {factor_tensor}, {factor_bc_tensor},
{{"dims", x_dims}});
{{"dims", framework::vectorize(x_dims)}});
runner_bc.Run(stream);
// Step 3: Compute x_power_mul_factor = factor * x.pow(factor-1)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册