未验证 提交 2cff0e8a 编写于 作者: J Jacek Czaja 提交者: GitHub

slice & mul & requantize tensors to use mem_desc (#47617)

* slice & mul & requantize

* - Fix to requantize test
上级 7c62d2ab
...@@ -199,15 +199,17 @@ class MulPrimitiveFactory { ...@@ -199,15 +199,17 @@ class MulPrimitiveFactory {
const ExecutionContext &ctx) { const ExecutionContext &ctx) {
Tensor x_tmp; Tensor x_tmp;
Tensor data_matrix; Tensor data_matrix;
MKLDNNMemoryFormat src_fmt = data->format(); // This code is enforcing plain (non-blocked) memory arrangement
MKLDNNMemoryFormat dst_fmt; // in order to flatten (reduce dimensionality) of Tensor later
auto src_mdesc = CreateMemDescriptor<T>(data, src_fmt); auto src_mdesc = data->mem_desc();
auto dst_mdesc =
if ((data->dims().size() == 4 && data->dims().size() >= 4
src_fmt != (dst_fmt = MKLDNNMemoryFormat::nchw)) || ? (data->dims().size() == 5
(data->dims().size() == 5 && ? CreateMemDescriptor<T>(data, MKLDNNMemoryFormat::ncdhw)
src_fmt != (dst_fmt = MKLDNNMemoryFormat::ncdhw))) { : CreateMemDescriptor<T>(data, MKLDNNMemoryFormat::nchw))
auto dst_mdesc = CreateMemDescriptor<T>(data, dst_fmt); : src_mdesc;
if (src_mdesc != dst_mdesc) {
x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size()); x_tmp.mutable_data<T>(ctx.GetPlace(), data->memory_size());
Reorder(src_mdesc, Reorder(src_mdesc,
......
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <iterator> // NOLINT
#include "dnnl.hpp" // NOLINT #include "dnnl.hpp" // NOLINT
#include "paddle/fluid/framework/data_layout_transform.h" #include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
...@@ -85,15 +86,19 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -85,15 +86,19 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
if (reorder_p == nullptr) { if (reorder_p == nullptr) {
auto dst_tz = phi::vectorize(output->dims());
auto src_dt = framework::ToMKLDNNDataType( auto src_dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype())); framework::TransToProtoVarType(input->dtype()));
auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt; auto dst_dt = with_shift ? framework::MKLDNNDataType::u8 : src_dt;
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
src_memory = std::make_shared<dnnl::memory>( src_memory = std::make_shared<dnnl::memory>(
src_md, engine, to_void_cast<T>(input_data)); input->mem_desc(), engine, to_void_cast<T>(input_data));
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, input->format());
auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
std::vector<dnnl_dim_t> vstrides(xstrides,
xstrides + input->mem_desc().data.ndims);
auto dst_md = dnnl::memory::desc({src_tz}, dst_dt, vstrides);
dnnl::primitive_attr attri; dnnl::primitive_attr attri;
int mask = 0; int mask = 0;
......
...@@ -162,11 +162,9 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -162,11 +162,9 @@ class SliceOp : public framework::OperatorWithKernel {
// reorders, because if blocked dimension is not divisible by 8 or // reorders, because if blocked dimension is not divisible by 8 or
// 16(depending on which blocking format is used) submemory cannot be // 16(depending on which blocking format is used) submemory cannot be
// created, so in that scenario a fallback is needed // created, so in that scenario a fallback is needed
auto tmp_md = dnnl::memory::desc( if (ctx.Input<phi::DenseTensor>("Input")
phi::vectorize(ctx.Input<phi::DenseTensor>("Input")->dims()), ->mem_desc()
dnnl::memory::data_type::f32, .data.format_desc.blocking.inner_nblks == 0)
ctx.Input<phi::DenseTensor>("Input")->format());
if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type, return framework::OpKernelType(input_data_type,
ctx.GetPlace(), ctx.GetPlace(),
phi::DataLayout::kMKLDNN, phi::DataLayout::kMKLDNN,
...@@ -337,13 +335,9 @@ class SliceOpGrad : public framework::OperatorWithKernel { ...@@ -337,13 +335,9 @@ class SliceOpGrad : public framework::OperatorWithKernel {
// reorders, because if blocked dimension is not divisible by 8 or // reorders, because if blocked dimension is not divisible by 8 or
// 16(depending on which blocking format is used) submemory cannot be // 16(depending on which blocking format is used) submemory cannot be
// created, so in that scenario a fallback is needed // created, so in that scenario a fallback is needed
auto tmp_md = dnnl::memory::desc( if (ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))
phi::vectorize( ->mem_desc()
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out")) .data.format_desc.blocking.inner_nblks == 0)
->dims()),
dnnl::memory::data_type::f32,
ctx.Input<phi::DenseTensor>(framework::GradVarName("Out"))->format());
if (tmp_md.data.format_desc.blocking.inner_nblks == 0)
return framework::OpKernelType(input_data_type, return framework::OpKernelType(input_data_type,
ctx.GetPlace(), ctx.GetPlace(),
phi::DataLayout::kMKLDNN, phi::DataLayout::kMKLDNN,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册