未验证 提交 12d8a567 编写于 作者: J jakpiase 提交者: GitHub

OneDNN md-in-tensor refactoring part 5: Memory descriptor enabled for...

OneDNN md-in-tensor refactoring part 5: Memory descriptor enabled for elementwises, reductions and expand_v2 ops (#43036)

* enabled md in elementwises, reductions and expand_v2

* CI fix for invalid numpy copy

* fixed formatting

* CI rerun

* changes after review
上级 13a21cf7
...@@ -145,8 +145,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> { ...@@ -145,8 +145,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {
binary_prim->execute(astream, args); binary_prim->execute(astream, args);
astream.wait(); astream.wait();
z->set_layout(DataLayout::kMKLDNN); z->set_mem_desc(dst_memory->get_desc());
z->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
}; };
...@@ -179,7 +178,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -179,7 +178,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
onednn_engine); onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>())); dout->mem_desc(), platform::to_void_cast(dout->data<T>()));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
...@@ -189,7 +188,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -189,7 +188,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
// elementwise_add & elementwise_sub // elementwise_add & elementwise_sub
if (BINARY_OP == dnnl::algorithm::binary_add || if (BINARY_OP == dnnl::algorithm::binary_add ||
BINARY_OP == dnnl::algorithm::binary_sub) { BINARY_OP == dnnl::algorithm::binary_sub) {
dst_memory = reorder_handler.AcquireDstMemory(dx, dout->format(), dst_memory = reorder_handler.AcquireDstMemory(dx, dout->mem_desc(),
ctx.GetPlace()); ctx.GetPlace());
auto reorder_p = auto reorder_p =
reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p); reorder_handler.AcquireReorder(dst_memory, reorder_src_memory_p);
...@@ -218,8 +217,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -218,8 +217,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
} }
astream.wait(); astream.wait();
dx->set_layout(framework::DataLayout::kMKLDNN); dx->set_mem_desc(dst_memory->get_desc());
dx->set_format(platform::GetMKLDNNFormat(*dst_memory));
} }
if (dy) { if (dy) {
...@@ -232,7 +230,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -232,7 +230,7 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
BINARY_OP == dnnl::algorithm::binary_sub) { BINARY_OP == dnnl::algorithm::binary_sub) {
if (dout->dims() == dy->dims()) { if (dout->dims() == dy->dims()) {
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dy, dout->format(), ctx.GetPlace()); dy, dout->mem_desc(), ctx.GetPlace());
dnnl::primitive_attr reorder_attr; dnnl::primitive_attr reorder_attr;
std::vector<float> scales(1); std::vector<float> scales(1);
...@@ -301,7 +299,6 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -301,7 +299,6 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
dst_memory = dst_dy_memory; dst_memory = dst_dy_memory;
} }
astream.wait(); astream.wait();
dy->set_layout(DataLayout::kMKLDNN);
if (dout->dims() != dy->dims()) { if (dout->dims() != dy->dims()) {
// Broadcasting // Broadcasting
...@@ -324,10 +321,10 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> { ...@@ -324,10 +321,10 @@ class EltwiseMKLDNNGradKernel : public ElemwiseGradKernel<T> {
{DNNL_ARG_DST, *dst_memory}, {DNNL_ARG_DST, *dst_memory},
}); });
astream.wait(); astream.wait();
dy->set_format(platform::GetMKLDNNFormat(dst_memory->get_desc().reshape( dy->set_mem_desc(dst_memory->get_desc().reshape(
phi::vectorize<int64_t>(dy->dims())))); phi::vectorize<int64_t>(dy->dims())));
} else { } else {
dy->set_format(platform::GetMKLDNNFormat(*dst_memory)); dy->set_mem_desc(dst_memory->get_desc());
} }
} }
} }
......
...@@ -45,19 +45,17 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -45,19 +45,17 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i]; out_new_dims[i] = out_new_dims[i] > 0 ? out_new_dims[i] : x_vec_dims[i];
} }
dnnl::memory::desc x_mem_desc = x->mem_desc();
if (x_vec_dims.size() != out_new_dims.size()) { if (x_vec_dims.size() != out_new_dims.size()) {
x_mem_desc = GetExtendedMemoryDescriptor(x_mem_desc, x_vec_dims, x_vec_dims = GetExtendedXDims(x_vec_dims, out_new_dims.size());
out_new_dims.size());
} }
out->Resize(phi::make_ddim(out_new_dims)); out->Resize(phi::make_ddim(out_new_dims));
paddle::platform::BroadcastDataMKLDNNHandler<T> handler( paddle::platform::BroadcastDataMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), out, x, dnnl::algorithm::binary_add, onednn_engine, ctx.GetPlace(), x, out,
0.0f, 1.0f, x_mem_desc); 0.0f, 1.0f, x_vec_dims);
auto src_memory_p = handler.AcquireSrcMemory(x); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out); // acquires zeroed mem auto dst_memory_p = handler.AcquireZeroedDstMemory(out);
auto binary_p = handler.AcquireForwardPrimitive(); auto binary_p = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
...@@ -73,14 +71,13 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -73,14 +71,13 @@ class ExpandMKLDNNKernel : public paddle::framework::OpKernel<T> {
} }
private: private:
dnnl::memory::desc GetExtendedMemoryDescriptor( std::vector<int64_t> GetExtendedXDims(const std::vector<int64_t>& x_vec_dims,
const dnnl::memory::desc& x_mem_desc, int new_size) const {
const std::vector<int64_t>& x_vec_dims, int new_size) const { std::vector<int64_t> extended_x_dims(new_size, 1);
std::vector<int64_t> new_dims(new_size, 1);
std::copy(x_vec_dims.begin(), x_vec_dims.end(), std::copy(x_vec_dims.begin(), x_vec_dims.end(),
new_dims.begin() + new_size - x_vec_dims.size()); extended_x_dims.begin() + new_size - x_vec_dims.size());
return x_mem_desc.reshape(new_dims); return extended_x_dims;
} }
}; };
......
...@@ -29,11 +29,11 @@ inline std::vector<int64_t> CalculateReducedDims( ...@@ -29,11 +29,11 @@ inline std::vector<int64_t> CalculateReducedDims(
bool reduce_all, bool keep_dim) { bool reduce_all, bool keep_dim) {
if (keep_dim) return phi::vectorize(output->dims()); if (keep_dim) return phi::vectorize(output->dims());
if (reduce_all) if (reduce_all) return std::vector<int64_t>(input->dims().size(), 1);
return std::vector<int64_t>(phi::vectorize(input->dims()).size(), 1);
std::vector<int64_t> output_dims(phi::vectorize(input->dims())); std::vector<int64_t> output_dims(phi::vectorize(input->dims()));
for (size_t i = 0; i < reduce_dims.size(); ++i) { for (size_t i = 0; i < reduce_dims.size(); ++i) {
// handle negative dims, f.e. "-1" means rightmost dimension
reduce_dims[i] = (reduce_dims[i] >= 0) reduce_dims[i] = (reduce_dims[i] >= 0)
? reduce_dims[i] ? reduce_dims[i]
: input->dims().size() + reduce_dims[i]; : input->dims().size() + reduce_dims[i];
...@@ -52,16 +52,16 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -52,16 +52,16 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine(); const auto& onednn_engine = dev_ctx.GetEngine();
const auto* input = ctx.Input<LoDTensor>("X"); const auto* x = ctx.Input<LoDTensor>("X");
auto* output = ctx.Output<Tensor>("Out"); auto* out = ctx.Output<Tensor>("Out");
auto reduce_dims = ctx.Attr<std::vector<int>>("dim"); auto reduce_dims = ctx.Attr<std::vector<int>>("dim");
bool reduce_all = ctx.Attr<bool>("reduce_all"); bool reduce_all = ctx.Attr<bool>("reduce_all");
bool keep_dim = ctx.Attr<bool>("keep_dim"); bool keep_dim = ctx.Attr<bool>("keep_dim");
auto output_dims = auto x_tz = phi::vectorize(x->dims());
CalculateReducedDims(input, output, reduce_dims, reduce_all, keep_dim); auto out_tz =
auto input_dims = phi::vectorize(input->dims()); CalculateReducedDims(x, out, reduce_dims, reduce_all, keep_dim);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
...@@ -69,18 +69,19 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -69,18 +69,19 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
// copied without actual reduction. // copied without actual reduction.
// In that case reorder must be executed to maintain compatibility with // In that case reorder must be executed to maintain compatibility with
// PaddlePaddle reduce op // PaddlePaddle reduce op
if (input_dims == output_dims) { if (x_tz == out_tz) {
dnnl::memory::data_type input_type = framework::ToMKLDNNDataType( dnnl::memory::data_type x_type = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype())); framework::TransToProtoVarType(x->dtype()));
platform::ReorderMKLDNNHandler reorder_handler( platform::ReorderMKLDNNHandler reorder_handler(
input_dims, framework::TransToProtoVarType(input->dtype()), x_tz, framework::TransToProtoVarType(x->dtype()), x_type,
input_type, onednn_engine); onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory( auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
input->mem_desc(), platform::to_void_cast(input->data<T>())); x->mem_desc(), platform::to_void_cast(x->data<T>()));
auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory( // reuse mem desc since it is a simple copy
output, input->mem_desc(), ctx.GetPlace()); auto reorder_dst_memory_p =
reorder_handler.AcquireDstMemory(out, x->mem_desc(), ctx.GetPlace());
auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p, auto reorder_p = reorder_handler.AcquireReorder(reorder_src_memory_p,
reorder_dst_memory_p); reorder_dst_memory_p);
...@@ -88,15 +89,15 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -88,15 +89,15 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p); reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);
astream.wait(); astream.wait();
output->set_mem_desc(reorder_dst_memory_p->get_desc().reshape( out->set_mem_desc(reorder_dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(output->dims()))); phi::vectorize<int64_t>(out->dims())));
} else { } else {
platform::ReductionMKLDNNHandler<T> handler(reduction_type, 0.0f, 0.0f, platform::ReductionMKLDNNHandler<T> handler(reduction_type, 0.0f, 0.0f,
onednn_engine, ctx.GetPlace(), onednn_engine, ctx.GetPlace(),
input, output, output_dims); x, out, out_tz);
auto src_memory_p = handler.AcquireSrcMemory(input); auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(output); auto dst_memory_p = handler.AcquireDstMemory(out);
std::unordered_map<int, dnnl::memory> reduction_args = { std::unordered_map<int, dnnl::memory> reduction_args = {
{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}}; {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
...@@ -105,8 +106,9 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> { ...@@ -105,8 +106,9 @@ class ReduceMKLDNNKernel : public framework::OpKernel<T> {
reduction_p->execute(astream, reduction_args); reduction_p->execute(astream, reduction_args);
astream.wait(); astream.wait();
output->set_mem_desc(dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(output->dims()))); out->set_mem_desc(dst_memory_p->get_desc().reshape(
phi::vectorize<int64_t>(out->dims())));
} }
} }
}; };
...@@ -127,22 +129,15 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> { ...@@ -127,22 +129,15 @@ class ReduceGradMKLDNNKernel : public framework::OpKernel<T> {
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out")); const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const auto input_dims = auto dout_tz = CalculateReducedDims(dx, dout, dims, reduce_all, keep_dim);
CalculateReducedDims(dx, dout, dims, reduce_all, keep_dim); auto dx_tz = phi::vectorize(dx->dims());
const auto output_dims = phi::vectorize(dx->dims());
auto dout_mem_desc = dout->mem_desc();
if (input_dims != output_dims) {
dout_mem_desc = dout_mem_desc.reshape(input_dims);
}
platform::BroadcastDataMKLDNNHandler<T> handler( platform::BroadcastDataMKLDNNHandler<T> handler(binary_type, onednn_engine,
binary_type, onednn_engine, ctx.GetPlace(), dx, dout, scale_x, scale_y, ctx.GetPlace(), dout, dx,
dout_mem_desc); scale_x, scale_y, dout_tz);
const auto src_memory_p = handler.AcquireSrcMemory(dout); const auto src_memory_p = handler.AcquireSrcMemory(dout);
const auto dst_memory_p = handler.AcquireDstMemory(dx); const auto dst_memory_p = handler.AcquireZeroedDstMemory(dx);
const auto binary_prim = handler.AcquireForwardPrimitive(); const auto binary_prim = handler.AcquireForwardPrimitive();
const std::unordered_map<int, dnnl::memory> args = { const std::unordered_map<int, dnnl::memory> args = {
......
...@@ -616,29 +616,17 @@ class BinaryMKLDNNHandler ...@@ -616,29 +616,17 @@ class BinaryMKLDNNHandler
public: public:
BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis, BinaryMKLDNNHandler(const dnnl::algorithm algo, const int axis,
const dnnl::engine engine, platform::Place cpu_place, const dnnl::engine engine, platform::Place cpu_place,
const Tensor* x, const Tensor* y, Tensor* z, const Tensor* x, const Tensor* y, Tensor* out,
float scale_x, float scale_y, float scale_z, float scale_x, float scale_y, float scale_out,
const dnnl::post_ops& post_ops = dnnl::post_ops{}) const dnnl::post_ops& post_ops = dnnl::post_ops{})
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) { : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for X tensor. Expected: %d (kMKLDNN), Actual: %d",
DataLayout::kMKLDNN, x->layout()));
PADDLE_ENFORCE_EQ(
y->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"Wrong layout set for Y tensor. Expected: %d (kMKLDNN), Actual: %d",
DataLayout::kMKLDNN, y->layout()));
const auto src_x_tz = phi::vectorize(x->dims()); const auto src_x_tz = phi::vectorize(x->dims());
const auto src_y_tz = phi::vectorize(y->dims()); const auto src_y_tz = phi::vectorize(y->dims());
// if output tensor(z) is nullptr then we are computing into oneDNN // if output tensor(z) is nullptr then we are computing into oneDNN
// managed buffer // managed buffer
auto rankdiff = x->dims().size() - y->dims().size(); auto rankdiff = x->dims().size() - y->dims().size();
const auto dst_tz = (z == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz) const auto dst_tz = (out == nullptr) ? (rankdiff > 0 ? src_x_tz : src_y_tz)
: phi::vectorize(z->dims()); : phi::vectorize(out->dims());
auto src0_md = x->mem_desc(); auto src0_md = x->mem_desc();
auto src1_md = y->mem_desc(); auto src1_md = y->mem_desc();
...@@ -667,7 +655,7 @@ class BinaryMKLDNNHandler ...@@ -667,7 +655,7 @@ class BinaryMKLDNNHandler
MKLDNNMemoryFormat::any); MKLDNNMemoryFormat::any);
auto attributes = auto attributes =
CreateAttributes(algo, scale_x, scale_y, scale_z, post_ops); CreateAttributes(algo, scale_x, scale_y, scale_out, post_ops);
this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md, this->AcquireForwardPrimitiveDescriptor(attributes, algo, src0_md, src1_md,
dst_md); dst_md);
...@@ -681,7 +669,7 @@ class BinaryMKLDNNHandler ...@@ -681,7 +669,7 @@ class BinaryMKLDNNHandler
private: private:
static inline dnnl::primitive_attr CreateAttributes( static inline dnnl::primitive_attr CreateAttributes(
dnnl::algorithm op, float scale_x, float scale_y, float scale_z, dnnl::algorithm op, float scale_x, float scale_y, float scale_out,
dnnl::post_ops post_ops = dnnl::post_ops{}) { dnnl::post_ops post_ops = dnnl::post_ops{}) {
// Scales set in attributes for inputs contibute to the output equation // Scales set in attributes for inputs contibute to the output equation
// in the following way (assuming no broadcasting takes place): // in the following way (assuming no broadcasting takes place):
...@@ -699,9 +687,9 @@ class BinaryMKLDNNHandler ...@@ -699,9 +687,9 @@ class BinaryMKLDNNHandler
// For mul operation on the other hand // For mul operation on the other hand
// output = (scale_out / scale_x) * x * (1.0 / scale_y) * y // output = (scale_out / scale_x) * x * (1.0 / scale_y) * y
// <scale_0> <scale_1> // <scale_0> <scale_1>
float scale_0 = scale_z / scale_x; float scale_0 = scale_out / scale_x;
float scale_1 = float scale_1 =
op == dnnl::algorithm::binary_add ? scale_z / scale_y : 1.0 / scale_y; op == dnnl::algorithm::binary_add ? scale_out / scale_y : 1.0 / scale_y;
dnnl::primitive_attr attributes; dnnl::primitive_attr attributes;
attributes.set_scales(/* input_x_id = */ DNNL_ARG_SRC_0, /* mask = */ 0, attributes.set_scales(/* input_x_id = */ DNNL_ARG_SRC_0, /* mask = */ 0,
{scale_0}); {scale_0});
...@@ -718,21 +706,15 @@ class BroadcastDataMKLDNNHandler ...@@ -718,21 +706,15 @@ class BroadcastDataMKLDNNHandler
public: public:
BroadcastDataMKLDNNHandler(const dnnl::algorithm algo, BroadcastDataMKLDNNHandler(const dnnl::algorithm algo,
const dnnl::engine engine, const dnnl::engine engine,
platform::Place cpu_place, const Tensor* out, platform::Place cpu_place, const Tensor* x,
const Tensor* x, float scale_x, float scale_y, Tensor* out, float scale_x, float scale_y,
const dnnl::memory::desc& x_mem_desc) const std::vector<int64_t>& extended_x_dims)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) { : platform::MKLDNNHandlerNoCachingT<T, dnnl::binary>(engine, cpu_place) {
PADDLE_ENFORCE_EQ(
x->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
const auto src0_tz = phi::vectorize(out->dims()); const auto src0_tz = phi::vectorize(out->dims());
const auto src0_md = const auto src0_md =
dnnl::memory::desc(src0_tz, platform::MKLDNNGetDataType<T>(), dnnl::memory::desc(src0_tz, platform::MKLDNNGetDataType<T>(),
platform::GetPlainMKLDNNFormat(src0_tz.size())); platform::GetPlainMKLDNNFormat(src0_tz.size()));
const auto src1_md = x->mem_desc().reshape(extended_x_dims);
const auto src1_md = x_mem_desc;
dnnl::primitive_attr attributes; dnnl::primitive_attr attributes;
attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x}); attributes.set_scales(DNNL_ARG_SRC_0, 0, {scale_x});
...@@ -743,9 +725,9 @@ class BroadcastDataMKLDNNHandler ...@@ -743,9 +725,9 @@ class BroadcastDataMKLDNNHandler
} }
template <typename T_out = T> template <typename T_out = T>
std::shared_ptr<dnnl::memory> AcquireDstMemory(framework::Tensor* output) { std::shared_ptr<dnnl::memory> AcquireZeroedDstMemory(framework::Tensor* out) {
T_out* ptr = output->mutable_data<T_out>( T_out* ptr = out->mutable_data<T_out>(this->place_,
this->place_, this->fwd_pd_->dst_desc().get_size()); this->fwd_pd_->dst_desc().get_size());
memset(ptr, 0, this->fwd_pd_->dst_desc().get_size()); memset(ptr, 0, this->fwd_pd_->dst_desc().get_size());
return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr); return this->AcquireMemoryFromPrimitive(this->fwd_pd_->dst_desc(), ptr);
} }
...@@ -758,22 +740,18 @@ class ReductionMKLDNNHandler ...@@ -758,22 +740,18 @@ class ReductionMKLDNNHandler
ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p, ReductionMKLDNNHandler(const dnnl::algorithm algo, const float p,
const float eps, const dnnl::engine engine, const float eps, const dnnl::engine engine,
platform::Place cpu_place, const Tensor* x, platform::Place cpu_place, const Tensor* x,
const Tensor* y, std::vector<int64_t> y_tz, const Tensor* out, std::vector<int64_t> out_tz,
const dnnl::primitive_attr& attr = NULL) const dnnl::primitive_attr& attrs = NULL)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine, : platform::MKLDNNHandlerNoCachingT<T, dnnl::reduction>(engine,
cpu_place) { cpu_place) {
PADDLE_ENFORCE_EQ( const auto out_md = memory::desc(out_tz, platform::MKLDNNGetDataType<T>(),
x->layout(), DataLayout::kMKLDNN, dnnl::memory::format_tag::any);
platform::errors::InvalidArgument("Wrong layout set for X tensor."));
const auto y_md = memory::desc(y_tz, platform::MKLDNNGetDataType<T>(),
dnnl::memory::format_tag::any);
if (attr) if (attrs)
this->AcquireForwardPrimitiveDescriptor(attr, algo, x->mem_desc(), y_md, this->AcquireForwardPrimitiveDescriptor(attrs, algo, x->mem_desc(),
p, eps); out_md, p, eps);
else else
this->AcquireForwardPrimitiveDescriptor(algo, x->mem_desc(), y_md, p, this->AcquireForwardPrimitiveDescriptor(algo, x->mem_desc(), out_md, p,
eps); eps);
} }
}; };
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
import unittest import unittest
import numpy as np import numpy as np
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci from paddle.fluid.tests.unittests.op_test import OpTest, OpTestTool, skip_check_grad_ci
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle import paddle
...@@ -92,6 +92,17 @@ class TestReduceSum4DReduceAllOneDNNOp(TestReduceDefaultWithGradOneDNNOp): ...@@ -92,6 +92,17 @@ class TestReduceSum4DReduceAllOneDNNOp(TestReduceDefaultWithGradOneDNNOp):
self.outputs = {'Out': self.inputs['X'].sum()} self.outputs = {'Out': self.inputs['X'].sum()}
@OpTestTool.skip_if_not_cpu()
class TestReduceSum4DNoReduceSimpleCopyOneDNNOp(
TestReduceDefaultWithGradOneDNNOp):
def setUp(self):
self.op_type = "reduce_sum"
self.use_mkldnn = True
self.inputs = {'X': np.random.random((5, 6, 2, 10)).astype("float32")}
self.attrs = {'dim': tuple(), 'use_mkldnn': self.use_mkldnn}
self.outputs = {'Out': np.copy(self.inputs['X'])}
@skip_check_grad_ci( @skip_check_grad_ci(
reason="reduce_max is discontinuous non-derivable function," reason="reduce_max is discontinuous non-derivable function,"
" its gradient check is not supported by unittest framework.") " its gradient check is not supported by unittest framework.")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册