Commit d83d0f33, authored by xiaoli.liu@intel.com

extract templated function

test=develop
Parent 019dbf7f
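
Note: this commit replaces the duplicated int8/uint8 destination-memory setup in the quantize kernel with a single templated helper, SetDstMemory<M>, shared through platform/mkldnn_reuse.h. The standalone C++ sketch below illustrates that refactoring pattern only; DataType, DataTypeTag, Memory, and the simplified SetDstMemory signature are hypothetical stand-ins, not the Paddle or MKL-DNN APIs that appear in the diff further down.

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

// Stand-in for mkldnn::memory::data_type (assumed, not the real enum).
enum class DataType { S8, U8 };

// Stand-in for framework::DataTypeTrait<M> + ToMKLDNNDataType: maps the
// element type to its data-type tag at compile time.
template <typename M> struct DataTypeTag;
template <> struct DataTypeTag<int8_t> {
  static constexpr DataType value = DataType::S8;
};
template <> struct DataTypeTag<uint8_t> {
  static constexpr DataType value = DataType::U8;
};

// Stand-in for mkldnn::memory: just remembers the buffer and its data type.
struct Memory {
  void* data;
  DataType type;
};

// The extracted helper: one templated function binds the destination buffer
// and records its data type, so callers no longer duplicate this per type.
template <typename M>
void SetDstMemory(std::vector<M>* output,
                  std::shared_ptr<Memory>* dst_memory) {
  M* output_data = output->data();
  dst_memory->reset(
      new Memory{static_cast<void*>(output_data), DataTypeTag<M>::value});
}

int main() {
  bool is_negative = true;  // mirrors the "is_negative_input" attribute
  std::vector<int8_t> s8_out(16);
  std::vector<uint8_t> u8_out(16);
  std::shared_ptr<Memory> dst_memory;
  if (is_negative) {
    SetDstMemory<int8_t>(&s8_out, &dst_memory);   // signed quantized output
  } else {
    SetDstMemory<uint8_t>(&u8_out, &dst_memory);  // unsigned quantized output
  }
  std::cout << "dst data type: "
            << (dst_memory->type == DataType::S8 ? "s8" : "u8") << "\n";
  return 0;
}
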
@@ -16,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/quantize_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
 
 namespace paddle {
 namespace operators {
@@ -59,37 +60,25 @@ class QuantOpKernel : public framework::OpKernel<T> {
         std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
     bool is_negative = ctx.Attr<bool>("is_negative_input");
-    mkldnn::memory::primitive_desc dst_pd;
+    std::shared_ptr<mkldnn::memory::primitive_desc> dst_pd;
     std::shared_ptr<mkldnn::memory> dst_memory;
     if (is_negative) {
-      int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
-      auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::s8,
-                                            memory::format::nhwc);
-      dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
-      dst_memory.reset(
-          new mkldnn::memory(dst_pd, to_void_cast<int8_t>(output_data)));
+      platform::ConvMKLDNNHandler::SetDstMemory<int8_t>(
+          ctx, output, dst_tz, engine, dst_pd, dst_memory);
     } else {
-      uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-      auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::u8,
-                                            memory::format::nhwc);
-      dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
-      dst_memory.reset(
-          new mkldnn::memory(dst_pd, to_void_cast<uint8_t>(output_data)));
+      platform::ConvMKLDNNHandler::SetDstMemory<uint8_t>(
+          ctx, output, dst_tz, engine, dst_pd, dst_memory);
     }
     auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
-        new reorder::primitive_desc(src_pd, dst_pd, attri));
+        new reorder::primitive_desc(src_pd, *dst_pd, attri));
     auto reorder_p = std::shared_ptr<reorder>(
         new reorder(*reorder_pd, *src_memory_p, *dst_memory));
     pipeline.push_back(*reorder_p);
     stream(stream::kind::eager).submit(pipeline).wait();
     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory));
   }
 };
 }  // namespace operators
 }  // namespace paddle
 namespace ops = paddle::operators;
......
@@ -15,6 +15,7 @@ limitations under the License. */
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/data_layout_transform.h"
+#include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
 #include "paddle/fluid/platform/place.h"
@@ -181,6 +182,21 @@ class MKLDNNHandler {
     return dims2str(operand_dims) + suffix;
   }
 
+  template <typename M>
+  static void SetDstMemory(
+      const framework::ExecutionContext& ctx, framework::Tensor* output,
+      std::vector<int> dst_tz, const mkldnn::engine& engine,
+      std::shared_ptr<mkldnn::memory::primitive_desc>& dst_pd,  // NOLINT
+      std::shared_ptr<mkldnn::memory>& dst_memory) {  // NOLINT
+    M* output_data = output->mutable_data<M>(ctx.GetPlace());
+    auto dst_md = platform::MKLDNNMemDesc(
+        {dst_tz}, paddle::framework::ToMKLDNNDataType(
+                      framework::DataTypeTrait<M>::DataType),
+        mkldnn::memory::format::nhwc);
+    dst_pd.reset(new mkldnn::memory::primitive_desc(dst_md, engine));
+    dst_memory.reset(new mkldnn::memory(*dst_pd, to_void_cast<M>(output_data)));
+  }
+
 protected:
  static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
    std::string dstr = "";
......