未验证 提交 22255528 编写于 作者: P piotrekobi 提交者: GitHub

[PHI] Migrate reduce sum+grad, mean+grad, min and max oneDNN kernels (#45536)

* gaussian random

* mkldnn to onednn renaming

* fix merge conflicts

* Migrate reduce_op oneDNN kernels to phi

* Remove unnecessary header

* remove fluid code

* onednn renaming

* Change std::vector<int64_t> to IntArray

* Fix code style

* Move classes from mkldnn_reuse.h to onednn_reuse.h

* Move more functions from mkldnn_helper.h to onednn_helpper.h

* Change MKLDNN to OneDNN in VLOG message

* Implement reviewer suggestions
Co-authored-by: NSilv3S <slawomir.siwek@intel.com>
上级 352babaa
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ReduceMeanMKLDNNKernel : public ReduceMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx, dnnl::algorithm::reduction_mean);
}
};
template <typename T>
class ReduceMeanGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* input_x = ctx.Input<Tensor>("X");
auto input_dims = phi::vectorize(input_x->dims());
auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
int number_of_elements = 1;
if (!ctx.Attr<bool>("reduce_all")) {
for (size_t i = 0; i < reduce_dims.size(); ++i) {
reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i]
: input_dims.size() + reduce_dims[i];
number_of_elements *= input_dims[reduce_dims[i]];
}
} else {
number_of_elements = input_x->numel();
}
this->RunKernel(ctx,
dnnl::algorithm::binary_add,
dnnl::algorithm::reduction_mean,
0.0f,
1.0L / number_of_elements);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(reduce_mean,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReduceMeanMKLDNNKernel<float>,
ops::ReduceMeanMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(reduce_mean_grad,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReduceMeanGradMKLDNNKernel<float>,
ops::ReduceMeanGradMKLDNNKernel<paddle::platform::bfloat16>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using paddle::framework::LoDTensor;
using paddle::framework::Tensor;
using platform::to_void_cast;
inline std::vector<int64_t> CalculateReducedDims(
const Tensor* input,
const Tensor* output,
std::vector<int>& reduce_dims, // NOLINT
bool reduce_all,
bool keep_dim) {
if (keep_dim) return phi::vectorize(output->dims());
if (reduce_all) return std::vector<int64_t>(input->dims().size(), 1);
std::vector<int64_t> output_dims(phi::vectorize(input->dims()));
for (size_t i = 0; i < reduce_dims.size(); ++i) {
// handle negative dims, f.e. "-1" means rightmost dimension
reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i]
: input->dims().size() + reduce_dims[i];
output_dims[reduce_dims[i]] = 1;
}
return output_dims;
}
template <typename T>
class ReduceMKLDNNKernel : public framework::OpKernel<T> {
public:
void RunKernel(const framework::ExecutionContext& ctx,
dnnl::algorithm reduction_type) const {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
const auto* x = ctx.Input<LoDTensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
bool reduce_all = ctx.Attr<bool>("reduce_all");
bool keep_dim = ctx.Attr<bool>("keep_dim");
auto x_tz = phi::vectorize(x->dims());
auto out_tz =
CalculateReducedDims(x, out, reduce_dims, reduce_all, keep_dim);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
// oneDNN reduce op does not support edge case in which memory is being
// copied without actual reduction.
// In that case reorder must be executed to maintain compatibility with
// PaddlePaddle reduce op
if (x_tz == out_tz) {
dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(x->dtype()));
platform::ReorderMKLDNNHandler reorder_handler(
x_tz,
framework::TransToProtoVarType(x->dtype()),
x_type,
onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x->mem_desc(), platform::to_void_cast(x->data<T>()));
// reuse mem desc since it is a simple copy
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x->mem_desc(), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(out->dims())));
} else {
platform::ReductionMKLDNNHandler<T> handler(reduction_type,
0.0f,
0.0f,
onednn_engine,
ctx.GetPlace(),
x,
out,
out_tz);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(out->dims())));
}
}
};
template <typename T>
class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
public:
void RunKernel(const framework::ExecutionContext& ctx,
dnnl::algorithm binary_type,
dnnl::algorithm reduction_type,
float scale_x,
float scale_y) const {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();
bool keep_dim = ctx.Attr<bool>("keep_dim");
bool reduce_all = ctx.Attr<bool>("reduce_all");
auto dims = ctx.Attr<std::vector<int>>("dim");
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dout_tz = CalculateReducedDims(dx, dout, dims, reduce_all, keep_dim);
auto dx_tz = phi::vectorize(dx->dims());
platform::BroadcastDataMKLDNNHandler<T> handler(binary_type,
onednn_engine,
ctx.GetPlace(),
dout,
dx,
scale_x,
scale_y,
dout_tz);
const auto src_memory_p = handler.AcquireSrcMemory(dout);
const auto dst_memory_p = handler.AcquireZeroedDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *dst_memory_p},
{DNNL_ARG_SRC_1, *src_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
binary_prim->execute(astream, args);
astream.wait();
dx->set_mem_desc(dst_memory_p->get_desc());
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ReduceSumMKLDNNKernel : public ReduceMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx, dnnl::algorithm::reduction_sum);
}
};
template <typename T>
class ReduceSumGradMKLDNNKernel : public ReduceGradMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx,
dnnl::algorithm::binary_add,
dnnl::algorithm::reduction_sum,
0.0f,
1.0f);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(reduce_sum,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReduceSumMKLDNNKernel<float>,
ops::ReduceSumMKLDNNKernel<paddle::platform::bfloat16>);
REGISTER_OP_KERNEL(reduce_sum_grad,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReduceSumGradMKLDNNKernel<float>,
ops::ReduceSumGradMKLDNNKernel<paddle::platform::bfloat16>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/backends/onednn/onednn_reuse.h"
namespace phi {
inline std::vector<int64_t> CalculateReducedDims(
const DenseTensor* input,
const DenseTensor* output,
const std::vector<int64_t>& reduce_dims, // NOLINT
bool reduce_all,
bool keep_dim) {
if (keep_dim) return vectorize(output->dims());
if (reduce_all && reduce_dims.size() > 0)
return std::vector<int64_t>(input->dims().size(), 1);
std::vector<int64_t> output_dims(vectorize(input->dims()));
for (size_t i = 0; i < reduce_dims.size(); ++i) {
// handle negative dims, f.e. "-1" means rightmost dimension
int index = (reduce_dims[i] >= 0) ? reduce_dims[i]
: input->dims().size() + reduce_dims[i];
output_dims[index] = 1;
}
return output_dims;
}
template <typename T, typename Context>
void ReduceKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out,
dnnl::algorithm reduction_type) {
const auto& onednn_engine = dev_ctx.GetEngine();
auto x_tz = vectorize(x.dims());
auto out_tz =
CalculateReducedDims(&x, out, dims.GetData(), reduce_all, keep_dim);
auto& astream = OneDNNContext::tls().get_stream();
// oneDNN reduce op does not support edge case in which memory is being
// copied without actual reduction.
// In that case reorder must be executed to maintain compatibility with
// PaddlePaddle reduce op
if (x_tz == out_tz) {
dnnl::memory::data_type x_type = funcs::ToOneDNNDataType((x.dtype()));
funcs::ReorderOneDNNHandler reorder_handler(
x_tz, x.dtype(), x_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
x.mem_desc(), funcs::to_void_cast(x.data<T>()));
// reuse mem desc since it is a simple copy
auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x.mem_desc(), dev_ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p);
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait();
out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(
vectorize<int64_t>(out->dims())));
} else {
funcs::ReductionOneDNNHandler<T> handler(reduction_type,
0.0f,
0.0f,
onednn_engine,
dev_ctx.GetPlace(),
&x,
out,
out_tz);
auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(out);
std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
auto reduction_p = handler.AcquireForwardPrimitive();
reduction_p->execute(astream, reduction_args);
astream.wait();
out->set_mem_desc(
dst_memory_p->get_desc().reshape(vectorize<int64_t>(out->dims())));
}
}
template <typename T, typename Context>
void ReduceGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad,
dnnl::algorithm binary_type,
dnnl::algorithm reduction_type,
float scale_x,
float scale_y) {
const auto& onednn_engine = dev_ctx.GetEngine();
auto out_grad_tz = CalculateReducedDims(
x_grad, &out_grad, dims.GetData(), reduce_all, keep_dim);
auto x_grad_tz = vectorize(x_grad->dims());
funcs::BroadcastDataOneDNNHandler<T> handler(binary_type,
onednn_engine,
dev_ctx.GetPlace(),
&out_grad,
x_grad,
scale_x,
scale_y,
out_grad_tz);
const auto src_memory_p = handler.AcquireSrcMemory(&out_grad);
const auto dst_memory_p = handler.AcquireZeroedDstMemory(x_grad);
const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = {
{DNNL_ARG_SRC_0, *dst_memory_p},
{DNNL_ARG_SRC_1, *src_memory_p},
{DNNL_ARG_DST, *dst_memory_p}};
auto& astream = OneDNNContext::tls().get_stream();
binary_prim->execute(astream, args);
astream.wait();
x_grad->set_mem_desc(dst_memory_p->get_desc());
}
} // namespace phi
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,25 +12,31 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ReduceMinMKLDNNKernel : public ReduceMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx, dnnl::algorithm::reduction_min);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(reduce_min,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReduceMinMKLDNNKernel<float>,
ops::ReduceMinMKLDNNKernel<paddle::platform::bfloat16>);
#include "paddle/phi/kernels/reduce_max_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/reduce_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void MaxRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
keep_dim,
reduce_all,
out,
dnnl::algorithm::reduction_max);
}
} // namespace phi
PD_REGISTER_KERNEL(max_raw,
OneDNN,
ALL_LAYOUT,
phi::MaxRawKernel,
float,
phi::dtype::bfloat16) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/reduce_mean_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/reduce_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void MeanGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
auto input_dims = phi::vectorize(x.dims());
std::vector<int64_t> reduce_dims = dims.GetData();
int number_of_elements = 1;
if (reduce_all == false) {
for (size_t i = 0; i < dims.size(); ++i) {
reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i]
: input_dims.size() + reduce_dims[i];
number_of_elements *= input_dims[reduce_dims[i]];
}
} else {
number_of_elements = x.numel();
}
const IntArray new_dims = IntArray(reduce_dims);
ReduceGradKernel<T, Context>(dev_ctx,
x,
out_grad,
new_dims,
keep_dim,
reduce_all,
x_grad,
dnnl::algorithm::binary_add,
dnnl::algorithm::reduction_mean,
0.0f,
1.0L / number_of_elements);
}
} // namespace phi
PD_REGISTER_KERNEL(mean_grad,
OneDNN,
ALL_LAYOUT,
phi::MeanGradKernel,
float,
phi::dtype::bfloat16) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/reduce_mean_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/reduce_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void MeanRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
keep_dim,
reduce_all,
out,
dnnl::algorithm::reduction_mean);
}
} // namespace phi
PD_REGISTER_KERNEL(mean_raw,
OneDNN,
ALL_LAYOUT,
phi::MeanRawKernel,
float,
phi::dtype::bfloat16) {}
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -12,25 +12,32 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reduce_ops/mkldnn/reduce_mkldnn_op.h"
namespace paddle {
namespace operators {
template <typename T>
class ReduceMaxMKLDNNKernel : public ReduceMKLDNNKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
this->RunKernel(ctx, dnnl::algorithm::reduction_max);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(reduce_max,
MKLDNN,
paddle::platform::CPUPlace,
ops::ReduceMaxMKLDNNKernel<float>,
ops::ReduceMaxMKLDNNKernel<paddle::platform::bfloat16>);
#include "paddle/phi/kernels/reduce_min_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/reduce_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void MinRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* out) {
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
keep_dim,
reduce_all,
out,
dnnl::algorithm::reduction_min);
}
} // namespace phi
PD_REGISTER_KERNEL(min_raw,
OneDNN,
ALL_LAYOUT,
phi::MinRawKernel,
float,
phi::dtype::bfloat16) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/reduce_sum_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/reduce_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void SumGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& out_grad,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DenseTensor* x_grad) {
ReduceGradKernel<T, Context>(dev_ctx,
x,
out_grad,
dims,
keep_dim,
reduce_all,
x_grad,
dnnl::algorithm::binary_add,
dnnl::algorithm::reduction_sum,
0.0f,
1.0f);
}
} // namespace phi
PD_REGISTER_KERNEL(sum_grad,
OneDNN,
ALL_LAYOUT,
phi::SumGradKernel,
float,
phi::dtype::bfloat16) {}
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/reduce_sum_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/onednn/reduce_kernel_impl.h"
namespace phi {
template <typename T, typename Context>
void SumRawKernel(const Context& dev_ctx,
const DenseTensor& x,
const IntArray& dims,
bool keep_dim,
bool reduce_all,
DataType out_dtype,
DenseTensor* out) {
ReduceKernel<T, Context>(dev_ctx,
x,
dims,
keep_dim,
reduce_all,
out,
dnnl::algorithm::reduction_sum);
}
} // namespace phi
PD_REGISTER_KERNEL(sum_raw,
OneDNN,
ALL_LAYOUT,
phi::SumRawKernel,
float,
phi::dtype::bfloat16) {}
......@@ -45,3 +45,8 @@ PD_REGISTER_KERNEL(
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(max, KPS, ALL_LAYOUT, phi::MaxKernel, float) {}
#endif
#if defined(PADDLE_WITH_MKLDNN)
PD_REGISTER_KERNEL(
max, OneDNN, ALL_LAYOUT, phi::MaxKernel, float, phi::dtype::bfloat16) {}
#endif
......@@ -50,3 +50,8 @@ PD_REGISTER_KERNEL(mean,
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(mean, KPS, ALL_LAYOUT, phi::MeanKernel, float) {}
#endif
#if defined(PADDLE_WITH_MKLDNN)
PD_REGISTER_KERNEL(
mean, OneDNN, ALL_LAYOUT, phi::MeanKernel, float, phi::dtype::bfloat16) {}
#endif
......@@ -45,3 +45,8 @@ PD_REGISTER_KERNEL(
#if defined(PADDLE_WITH_XPU_KP)
PD_REGISTER_KERNEL(min, KPS, ALL_LAYOUT, phi::MinKernel, float) {}
#endif
#if defined(PADDLE_WITH_MKLDNN)
PD_REGISTER_KERNEL(
min, OneDNN, ALL_LAYOUT, phi::MinKernel, float, phi::dtype::bfloat16) {}
#endif
......@@ -78,3 +78,8 @@ PD_REGISTER_KERNEL(sum, KPS, ALL_LAYOUT, phi::SumKernel, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif
#if defined(PADDLE_WITH_MKLDNN)
PD_REGISTER_KERNEL(
sum, OneDNN, ALL_LAYOUT, phi::SumKernel, float, phi::dtype::bfloat16) {}
#endif
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册