未验证 提交 413d6e1b 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix FillD not work on cann512, test=develop (#45586)

上级 f41b8566
......@@ -65,21 +65,11 @@ class FillConstantNPUKernel : public framework::OpKernel<T> {
tensor_value.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&tensor_value, value);
NpuOpRunner runner;
#if (CANN_VERSION_CODE >= 503003 && CANN_VERSION_CODE < 504000)
runner.SetType("FillD")
.AddInput(tensor_value)
.AddOutput(*out_var)
.AddAttrs(
{{ "dims",
phi::vectorize(shape) }})
.Run(stream);
#else
runner.SetType("Fill")
.AddInput(phi::vectorize(shape))
.AddInput(tensor_value)
.AddOutput(*out_var)
.Run(stream);
#endif
} else {
const auto &dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
......
......@@ -30,22 +30,25 @@ class NPUReduceMeanOpKernel : public framework::OpKernel<T> {
auto dims = ctx.Attr<std::vector<int>>("dim");
bool keep_dim = ctx.Attr<bool>("keep_dim");
auto input_dims_vec = phi::vectorize(input->dims());
auto input_dims = input->dims();
if (reduce_all) {
dims.clear();
for (size_t i = 0; i < input_dims_vec.size(); i++) {
for (int i = 0; i < input_dims.size(); i++) {
dims.push_back(static_cast<int>(i));
}
}
const auto& runner = NpuOpRunner("ReduceMeanD",
{*input},
{*output},
{{"axes", dims}, {"keep_dims", keep_dim}});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("ReduceMean")
.AddInput(*input)
.AddInput(std::move(dims))
.AddOutput(*output)
.AddAttrs({{"keep_dims", keep_dim}})
.Run(stream);
}
};
......@@ -60,41 +63,45 @@ class NPUReduceMeanGradOpKernel : public framework::OpKernel<T> {
bool reduce_all = ctx.Attr<bool>("reduce_all");
auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
auto input_dims_vec = phi::vectorize(input->dims());
auto input_dims = input->dims();
int reduce_numel = 1;
if (reduce_all) {
reduce_dims.clear();
for (size_t d = 0; d < input_dims_vec.size(); ++d) {
for (int d = 0; d < input_dims.size(); ++d) {
reduce_dims.push_back(static_cast<int>(d));
}
}
for (auto& d : reduce_dims) {
if (d < 0) {
d = d + input_dims_vec.size();
d = d + input_dims.size();
}
reduce_numel *= input_dims_vec[d];
reduce_numel *= input_dims[d];
}
const auto& runner =
NpuOpRunner("FillV2D",
{},
{*input_grad},
{{"value", 1.0f / static_cast<float>(reduce_numel)},
{"dims", input_dims_vec}});
Tensor tensor_value(input_grad->dtype());
tensor_value.mutable_data<T>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<T>(
&tensor_value, static_cast<T>(1.0f / static_cast<T>(reduce_numel)));
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
NpuOpRunner runner;
runner.SetType("Fill")
.AddInput(phi::vectorize(input_dims))
.AddInput(tensor_value)
.AddOutput(*input_grad)
.Run(stream);
Tensor transformed_input_grad, transformed_out_grad;
Tensor tmp_output_grad;
auto tmp_output_dims_vec = input_dims_vec;
auto tmp_output_dims = input_dims;
for (auto d : reduce_dims) {
tmp_output_dims_vec[d] = 1;
tmp_output_dims[d] = 1;
}
tmp_output_grad.ShareDataWith(*output_grad);
tmp_output_grad.Resize(phi::make_ddim(tmp_output_dims_vec));
tmp_output_grad.Resize(tmp_output_dims);
auto& dev_ctx =
ctx.template device_context<paddle::platform::NPUDeviceContext>();
NpuElementWiseOpBroadcast<T>(dev_ctx,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册