未验证 提交 c6c65c65 编写于 作者: J Jacek Czaja 提交者: GitHub

[DNNL] Added elementwise_add mkl-dnn inplace (#23477)

上级 9ff558a4
...@@ -86,7 +86,7 @@ endif() ...@@ -86,7 +86,7 @@ endif()
if(WITH_MKLDNN) if(WITH_MKLDNN)
pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn) pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry softmax_op softmax DIR mkldnn) pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry elementwise_add_op activation_op softmax_op softmax DIR mkldnn)
pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn) pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn) pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
......
...@@ -1892,30 +1892,30 @@ PDNode *patterns::MultipleQuantize::operator()() { ...@@ -1892,30 +1892,30 @@ PDNode *patterns::MultipleQuantize::operator()() {
} }
PDNode *patterns::MKLDNNInPlace::operator()() { PDNode *patterns::MKLDNNInPlace::operator()() {
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
// batch_norm....
auto possible_inplace_op = auto possible_inplace_op =
pattern->NewNode(inplace_to_be_op_repr())->assert_is_ops({"softmax"}); pattern->NewNode(inplace_to_be_op_repr())
->assert_is_ops({"elementwise_add", "softmax"});
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add, // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
// batch_norm....
auto input = pattern->NewNode(inplace_to_be_op_in_repr()) auto input = pattern->NewNode(inplace_to_be_op_in_repr())
->assert_is_ops_input({"softmax"}) ->assert_is_ops_input({"elementwise_add", "softmax"})
->AsInput(); ->AsInput();
// TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add, // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
// batch_norm....
auto output = pattern->NewNode(inplace_to_be_op_out_repr()) auto output = pattern->NewNode(inplace_to_be_op_out_repr())
->assert_is_ops_output({"softmax"}) ->assert_is_ops_output({"elementwise_add", "softmax"})
->AsIntermediate(); ->AsOutput();
auto next_op = pattern->NewNode(next_op_repr())->assert_is_op(); auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
auto next_output = pattern->NewNode(next_op_out_repr())->AsOutput();
// Check if op is MKL-DNN enabled // Check if op is MKL-DNN enabled
possible_inplace_op->assert_op_attr("use_mkldnn", true); possible_inplace_op->assert_op_attr("use_mkldnn", true);
// linked structure
possible_inplace_op->LinksTo({output}); possible_inplace_op->LinksTo({output});
possible_inplace_op->LinksFrom({input}); possible_inplace_op->LinksFrom({input});
next_op->LinksFrom({output}); next_op->LinksFrom({output});
next_op->LinksTo({next_output});
return possible_inplace_op; return possible_inplace_op;
} }
......
...@@ -1140,11 +1140,12 @@ struct MKLDNNInPlace : public PatternBase { ...@@ -1140,11 +1140,12 @@ struct MKLDNNInPlace : public PatternBase {
: PatternBase(pattern, name_scope, "mkldnn_inplace") {} : PatternBase(pattern, name_scope, "mkldnn_inplace") {}
PDNode* operator()(); PDNode* operator()();
// MKL-DNN's in-place ops: BatchNorm, Softmax, Layer Norm // MKL-DNN's in-place ops: BatchNorm, Softmax, Elementwise_add
PATTERN_DECL_NODE(inplace_to_be_op); PATTERN_DECL_NODE(inplace_to_be_op);
PATTERN_DECL_NODE(inplace_to_be_op_in); PATTERN_DECL_NODE(inplace_to_be_op_in);
PATTERN_DECL_NODE(inplace_to_be_op_out); PATTERN_DECL_NODE(inplace_to_be_op_out);
PATTERN_DECL_NODE(next_op); PATTERN_DECL_NODE(next_op);
PATTERN_DECL_NODE(next_op_out);
}; };
struct TransposeFlattenConcat : public PatternBase { struct TransposeFlattenConcat : public PatternBase {
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/lod_tensor.h"
...@@ -30,6 +31,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -30,6 +31,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(graph, PADDLE_ENFORCE_NOT_NULL(graph,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Pointer to graph argument should not be NULL.")); "Pointer to graph argument should not be NULL."));
std::unordered_map<std::string, std::string> original_output_names;
GraphPatternDetector gpd; GraphPatternDetector gpd;
patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(), patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
"mkldnn_inplace"}; "mkldnn_inplace"};
...@@ -40,72 +42,136 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const { ...@@ -40,72 +42,136 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
Graph* g) { Graph* g) {
VLOG(3) << "Start to handle MKL-DNN In-Place pass"; VLOG(3) << "Start to handle MKL-DNN In-Place pass";
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op, inplace_to_be_op, GET_IR_NODE_FROM_SUBGRAPH(current_op, inplace_to_be_op, mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(current_op_in, inplace_to_be_op_in,
mkldnn_inplace); mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_in, inplace_to_be_op_in, GET_IR_NODE_FROM_SUBGRAPH(current_op_out, inplace_to_be_op_out,
mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_out, inplace_to_be_op_out,
mkldnn_inplace); mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, mkldnn_inplace); GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, mkldnn_inplace);
GET_IR_NODE_FROM_SUBGRAPH(next_op_out, next_op_out, mkldnn_inplace);
if ((inplace_to_be_op->Op()->HasAttr("use_mkldnn") == false) || if ((current_op->Op()->HasAttr("use_mkldnn") == false) ||
(boost::get<bool>(inplace_to_be_op->Op()->GetAttr("use_mkldnn")) == (boost::get<bool>(current_op->Op()->GetAttr("use_mkldnn")) == false)) {
false)) {
VLOG(3) << "do not perform mkl-dnn inplace: use_mkldnn missing or set to " VLOG(3) << "do not perform mkl-dnn inplace: use_mkldnn missing or set to "
"false"; "false";
return; return;
} }
auto& infer_inplace = OpInfoMap::Instance() auto& infer_inplace =
.Get(inplace_to_be_op->Op()->Type()) OpInfoMap::Instance().Get(current_op->Op()->Type()).infer_inplace_;
.infer_inplace_;
if (!infer_inplace) { if (!infer_inplace) {
VLOG(3) << "do not perform mkl-dnn inplace: missing InplaceInferer"; VLOG(3) << "do not perform mkl-dnn inplace: missing InplaceInferer";
return; return;
} }
// TODO(jczaja): Enable more ops VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") "
if (inplace_to_be_op->Op()->Type() != "softmax") { << "Curr Node In: " << current_op_in->Name()
VLOG(3) << " Curr Node out: " << current_op_out->Name();
<< "Curently works for softmax only. TODO(jczaja): support other ops";
VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") "
<< " next Node out: " << next_op_out->Name();
auto inputs = current_op->Op()->Inputs();
auto outputs = current_op->Op()->Outputs();
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") "
<< in_to_outs.begin()->first << ": "
<< inputs[in_to_outs.begin()->first][0] << " "
<< in_to_outs.begin()->second << ": "
<< outputs[in_to_outs.begin()->second][0];
// If InferInplace pattern does not contain input node then skip
auto inplace_input_vec = inputs[in_to_outs.begin()->first];
if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
current_op_in->Name()) == inplace_input_vec.end()) {
VLOG(3) << "DNNL in-place pass SKIP pattern ";
return; return;
} }
// Iterate over all nodes that are ops // Checking if this particular node (to be inplaced, overwritten)
// and check if in-place to be var is part of inputs // is used anywhere else apart from inplaced op
// if positive then do not perform inplace auto input_consumers = current_op_in->outputs;
for (const Node* n : graph->Nodes()) { if (input_consumers.size() > 1) {
if (n->IsOp()) { VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
// Avoid searchin in op that is to be inplace "be an input to multiple operators";
if ((n->id() != inplace_to_be_op->id())) { return;
auto* op = n->Op(); }
auto inputs = op->Inputs();
auto in_place_input = inplace_to_be_op_in->Name(); // If this op was alrady inplaced in previous pass placements
for (auto& it : inputs) { // then we need to update input of next op
for (auto& var_name : it.second) { // but original name to be changed is gone, so we need to remember it
if (var_name == in_place_input) { // on first time given op is to be inplaced
VLOG(3) << "MKL-DNN in-place pass: in-place var cannot be an " if (current_op_in->Name() != current_op_out->Name()) {
"input to more than one operator"; original_output_names[current_op->Name() + current_op_in->Name()] =
return; current_op_out->Name();
} } else {
} VLOG(3) << "DNNL Inplace: Current op already inplaced! ";
}
// It may be that next op is reusing some of vars, we need to
// make sure that unwanted inplace is not created
// TODO(jczaja): Make UT for that one
for (auto& n : current_op_out->outputs) {
auto& n_op_infer_inplace =
OpInfoMap::Instance().Get(n->Op()->Type()).infer_inplace_;
if ((n_op_infer_inplace == nullptr)) {
for (auto& m : n->outputs) {
if (m->Name() == current_op_in->Name()) {
VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
"be an output to non-inplaced next op";
return;
} }
} }
} }
} }
auto original_name = inplace_to_be_op_out->Name(); auto original_name =
inplace_to_be_op_out->RenameVar(inplace_to_be_op_in->Name()); original_output_names[current_op->Name() + current_op_in->Name()];
current_op_out->RenameVar(current_op_in->Name());
// Get mapping of input to output // Get mapping of input to output
auto in_to_outs = infer_inplace(false); // strictly no CUDA for MKL-DNN
// TODO(jczaja): Support more complex situations
auto out_name = in_to_outs.begin()->second; auto out_name = in_to_outs.begin()->second;
inplace_to_be_op->Op()->SetOutput( current_op->Op()->SetOutput(
out_name, std::vector<std::string>({inplace_to_be_op_out->Name()})); out_name, std::vector<std::string>({current_op_out->Name()}));
next_op->Op()->RenameInput(original_name, inplace_to_be_op_out->Name());
// If next op in a line is doing inplace
// then we need to update its output as well
// Get inferer of next op
// If no inferer then we are done
auto& next_op_infer_inplace =
OpInfoMap::Instance().Get(next_op->Op()->Type()).infer_inplace_;
if (next_op_infer_inplace) {
auto in_to_outs = next_op_infer_inplace(false);
auto out_name = in_to_outs.begin()->second;
auto* op = next_op->Op();
auto inputs = op->Inputs();
auto outputs = op->Outputs();
// Check if in-place happened
// for variable we changed (original name)
// TODO(jczaja): make recursive propagation of inplace
auto next_op_inplace_inputs = inputs[in_to_outs.begin()->first];
if ((next_op_inplace_inputs == outputs[in_to_outs.begin()->second]) &&
(std::find(next_op_inplace_inputs.begin(),
next_op_inplace_inputs.end(),
original_name) != next_op_inplace_inputs.end())) {
VLOG(3) << "DNNL InPlace: Next Op is in-placed , updating its "
"input "
"and output var!";
next_op->Op()->SetOutput(
out_name, std::vector<std::string>({current_op_out->Name()}));
next_op_out->RenameVar(current_op_in->Name());
// Get ops that next_op_out is linked to and update their input
auto next_op_out_consumers = next_op_out->outputs; // Has to be ops
for (auto& c : next_op_out_consumers) {
c->Op()->RenameInput(original_name, current_op_out->Name());
}
}
}
next_op->Op()->RenameInput(original_name, current_op_out->Name());
found_inplace_count++; found_inplace_count++;
VLOG(3) << "MKL-DNN InPlace applied!"; VLOG(3) << "DNNL InPlace applied!";
}; };
gpd(graph, handler); gpd(graph, handler);
......
...@@ -21,6 +21,9 @@ ...@@ -21,6 +21,9 @@
USE_OP(softmax); USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
USE_OP(relu);
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -62,8 +65,9 @@ class MKLDNNInplacePassTest { ...@@ -62,8 +65,9 @@ class MKLDNNInplacePassTest {
bool branched) { bool branched) {
ProgramDesc prog; ProgramDesc prog;
for (auto& v : std::vector<std::string>( for (auto& v :
{"a", "weights", "bias", "f", "g", "h", "i", "j", "k"})) { std::vector<std::string>({"a", "weights", "bias", "f", "g", "h", "i",
"j", "k", "l", "m", "z"})) {
auto* var = prog.MutableBlock(0)->Var(v); auto* var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS); var->SetType(proto::VarType::SELECTED_ROWS);
if (v == "weights" || v == "bias") { if (v == "weights" || v == "bias") {
...@@ -83,9 +87,12 @@ class MKLDNNInplacePassTest { ...@@ -83,9 +87,12 @@ class MKLDNNInplacePassTest {
SetOp(&prog, "elementwise_add", "elementwise_add1", SetOp(&prog, "elementwise_add", "elementwise_add1",
std::vector<std::string>({"h", "i"}), std::vector<std::string>({"j"}), std::vector<std::string>({"h", "i"}), std::vector<std::string>({"j"}),
mkldnn_enabled_op.compare("elementwise_add") == 0); mkldnn_enabled_op.compare("elementwise_add") == 0);
SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
std::vector<std::string>({"k"}),
mkldnn_enabled_op.compare("softmax") == 0);
if (branched == true) { if (branched == true) {
SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}), SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
std::vector<std::string>({"k"}), std::vector<std::string>({"z"}),
mkldnn_enabled_op.compare("softmax") == 0); mkldnn_enabled_op.compare("softmax") == 0);
} }
...@@ -105,12 +112,11 @@ class MKLDNNInplacePassTest { ...@@ -105,12 +112,11 @@ class MKLDNNInplacePassTest {
unsigned use_mkldnn_true_count = 0; unsigned use_mkldnn_true_count = 0;
std::unordered_map<std::string, std::string> input_names; std::unordered_map<std::string, std::string> input_names;
std::unordered_map<std::string, std::string> output_names; std::unordered_map<std::string, std::string> output_names;
input_names["softmax"] = "X"; input_names["softmax"] = "X";
output_names["softmax"] = "Out"; output_names["softmax"] = "Out";
input_names["batch_norm"] = "X"; input_names["elementwise_add"] = "X";
output_names["batch_norm"] = "Y"; output_names["elementwise_add"] = "Out";
input_names["layer_norm"] = "X";
output_names["layer_norm"] = "Y";
VLOG(3) << DebugString(graph); VLOG(3) << DebugString(graph);
...@@ -135,15 +141,18 @@ class MKLDNNInplacePassTest { ...@@ -135,15 +141,18 @@ class MKLDNNInplacePassTest {
TEST(MKLDNNInplacePass, inplace_softmax) { TEST(MKLDNNInplacePass, inplace_softmax) {
// softmax to be mkl-dnn enabled and made in-place // softmax to be mkl-dnn enabled and made in-place
MKLDNNInplacePassTest().MainTest("softmax", false, 1); MKLDNNInplacePassTest().MainTest("softmax", false, 1);
} }
TEST(MKLDNNInplacePass, inplace_softmax_branched) { TEST(MKLDNNInplacePass, inplace_softmax_branched) {
// softmax to be mkl-dnn enabled and made in-place // softmax's input is shared by two branches. so no in-place
MKLDNNInplacePassTest().MainTest("softmax", true, 0); MKLDNNInplacePassTest().MainTest("softmax", true, 0);
} }
TEST(MKLDNNInplacePass, inplace_elementwise_add) {
// Two elementwise_add mkl-dnn enabled op instances to be made inplace
MKLDNNInplacePassTest().MainTest("elementwise_add", false, 1);
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
......
...@@ -56,39 +56,34 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -56,39 +56,34 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
y->format(), MKLDNNMemoryFormat::undef, y->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Y tensor")); platform::errors::InvalidArgument("Wrong format set for Y tensor"));
const T* x_data = x->data<T>();
const T* y_data = y->data<T>();
auto src_x_tz = framework::vectorize<int64_t>(x->dims()); auto src_x_tz = framework::vectorize<int64_t>(x->dims());
auto src_y_tz = framework::vectorize<int64_t>(y->dims()); auto src_y_tz = framework::vectorize<int64_t>(y->dims());
auto dst_tz = framework::vectorize<int64_t>(z->dims()); auto dst_tz = framework::vectorize<int64_t>(z->dims());
std::vector<float> scales = {1.0f, 1.0f}; // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
// TODO(jczaja): Binary primitive support broadcasting, so we can support
// this in kernel
platform::BinaryMKLDNNHandler<T> handler(
dnnl::algorithm::binary_add, src_x_tz, x->format(), y->format(),
dev_ctx, ctx.GetPlace(), ctx.OutputName("Out"));
const std::string key = auto src_x_memory = handler.AcquireSrcMemory(x);
platform::CreateKey(src_x_tz, ctx.OutputName("Out")); auto src_y_memory = handler.AcquireSecondSrcMemory(y);
platform::SumMKLDNNHandler handler(dev_ctx, mkldnn_engine, key); // For Inplace src and and dst are the same memory object
auto dst_memory =
x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z);
auto src_x_memory = handler.AcquireSrcMemory( auto binary_prim = handler.AcquireForwardPrimitive();
{{src_x_tz}, platform::MKLDNNGetDataType<T>(), x->format()},
paddle::platform::to_void_cast(x_data));
auto src_y_memory = handler.AcquireSecondSrcMemory(
{{src_y_tz}, platform::MKLDNNGetDataType<T>(), y->format()},
paddle::platform::to_void_cast(y_data));
auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
{src_x_memory, src_y_memory}, scales, dst_md);
T* z_data =
z->mutable_data<T>(ctx.GetPlace(), sum_pd->dst_desc().get_size());
auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
auto sum_prim = handler.AcquireSum();
mkldnn::stream astream(mkldnn_engine); mkldnn::stream astream(mkldnn_engine);
sum_prim->execute(astream, {{MKLDNN_ARG_MULTIPLE_SRC, *src_x_memory},
{MKLDNN_ARG_MULTIPLE_SRC + 1, *src_y_memory}, std::unordered_map<int, dnnl::memory> args = {
{MKLDNN_ARG_DST, *dst_memory}}); {DNNL_ARG_SRC_0, *src_x_memory},
{DNNL_ARG_SRC_1, *src_y_memory},
{DNNL_ARG_DST, *dst_memory}};
binary_prim->execute(astream, args);
astream.wait(); astream.wait();
z->set_layout(DataLayout::kMKLDNN); z->set_layout(DataLayout::kMKLDNN);
......
cc_test(test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry softmax_op softmax scope device_context enforce executor) cc_test(test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry elementwise_add_op softmax_op softmax scope device_context enforce executor)
...@@ -45,7 +45,8 @@ class SoftmaxMKLDNNHandler ...@@ -45,7 +45,8 @@ class SoftmaxMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) { // Softmax may be inplace then uniq_name is no longer unique
platform::CreateKey(dims, axis, uniq_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md, this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
...@@ -60,7 +61,7 @@ class SoftmaxMKLDNNHandler ...@@ -60,7 +61,7 @@ class SoftmaxMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) { platform::CreateKey(dims, axis, uniq_name)) {
auto data_softmax_md = auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md = auto diff_softmax_md =
...@@ -95,13 +96,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -95,13 +96,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
auto softmax_src_memory_p = handler.AcquireSrcMemory(input); auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
auto softmax_p = handler.AcquireForwardPrimitive(); auto softmax_p = handler.AcquireForwardPrimitive();
// For Inplace src and and dst are the same memory object // For Inplace src and and dst are the same memory object
auto softmax_dst_memory_p = input->Holder() == output->Holder() auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
? softmax_src_memory_p ? softmax_src_memory_p
: handler.AcquireDstMemory(output); : handler.AcquireDstMemory(output);
mkldnn::stream astream(dev_ctx.GetEngine()); mkldnn::stream astream(dev_ctx.GetEngine());
softmax_p->execute(astream, {{MKLDNN_ARG_SRC, *softmax_src_memory_p}, softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
{MKLDNN_ARG_DST, *softmax_dst_memory_p}}); {DNNL_ARG_DST, *softmax_dst_memory_p}});
astream.wait(); astream.wait();
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
......
...@@ -27,38 +27,68 @@ ...@@ -27,38 +27,68 @@
USE_OP(softmax); USE_OP(softmax);
USE_OP_DEVICE_KERNEL(softmax, MKLDNN); USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
USE_OP(elementwise_add);
USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
namespace paddle { namespace paddle {
namespace operators { namespace operators {
struct InputVars {
std::string name;
framework::LoDTensor *tensor;
};
template <typename T> template <typename T>
bool TestMain(const platform::Place &place, const framework::DDim &dims) { bool TestMain(const platform::Place &place, const std::string &op_type,
const framework::DDim &dims, const int num_inputs) {
framework::Scope scope; framework::Scope scope;
auto *x = scope.Var("x")->GetMutable<framework::LoDTensor>();
auto *y = scope.Var("y")->GetMutable<framework::LoDTensor>();
x->Resize(dims); std::vector<InputVars> input_names = {
y->Resize(dims); {"x", scope.Var("x")->GetMutable<framework::LoDTensor>()},
{"x1", num_inputs > 1
size_t numel = static_cast<size_t>(framework::product(dims)); ? scope.Var("x1")->GetMutable<framework::LoDTensor>()
: nullptr},
auto x_ptr = x->mutable_data<T>(place); {"x2", num_inputs > 2
auto y_ptr = y->mutable_data<T>(place); ? scope.Var("x2")->GetMutable<framework::LoDTensor>()
: nullptr},
{"x3", num_inputs > 3
? scope.Var("x3")->GetMutable<framework::LoDTensor>()
: nullptr},
{"x4", num_inputs > 4
? scope.Var("x4")->GetMutable<framework::LoDTensor>()
: nullptr}};
auto *y = scope.Var("y")->GetMutable<framework::LoDTensor>();
// Initialize input data
std::uniform_real_distribution<T> dist(static_cast<T>(10.0), std::uniform_real_distribution<T> dist(static_cast<T>(10.0),
static_cast<T>(20.0)); static_cast<T>(20.0));
std::mt19937 engine; std::mt19937 engine;
size_t numel = static_cast<size_t>(framework::product(dims));
for (int i = 0; i < num_inputs; ++i) {
input_names[i].tensor->Resize(dims);
auto data_ptr = input_names[i].tensor->mutable_data<T>(place);
for (size_t i = 0; i < numel; ++i) {
data_ptr[i] = dist(engine);
}
}
// Initialize output
y->Resize(dims);
auto y_ptr = y->mutable_data<T>(place);
for (size_t i = 0; i < numel; ++i) { for (size_t i = 0; i < numel; ++i) {
x_ptr[i] = dist(engine);
y_ptr[i] = static_cast<T>(0); y_ptr[i] = static_cast<T>(0);
} }
auto &pool = platform::DeviceContextPool::Instance(); auto &pool = platform::DeviceContextPool::Instance();
// Out of place (reference) computation // Out of place (reference) computation
auto op_ref = framework::OpRegistry::CreateOp( auto op_ref = num_inputs > 1 ? framework::OpRegistry::CreateOp(
"softmax", {{"X", {"x"}}}, {{"Out", {"y"}}}, {{"use_mkldnn", {true}}}); op_type, {{"X", {"x"}}, {"Y", {"x1"}}},
{{"Out", {"y"}}}, {{"use_mkldnn", {true}}})
: framework::OpRegistry::CreateOp(
op_type, {{"X", {"x"}}}, {{"Out", {"y"}}},
{{"use_mkldnn", {true}}});
op_ref->Run(scope, place); op_ref->Run(scope, place);
pool.Get(place)->Wait(); pool.Get(place)->Wait();
...@@ -66,15 +96,20 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) { ...@@ -66,15 +96,20 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
auto &ref_tensor = scope.FindVar("y")->Get<framework::LoDTensor>(); auto &ref_tensor = scope.FindVar("y")->Get<framework::LoDTensor>();
// In-place (to be tested) computation // In-place (to be tested) computation
auto op = framework::OpRegistry::CreateOp( auto op = num_inputs > 1 ? framework::OpRegistry::CreateOp(
"softmax", {{"X", {"x"}}}, {{"Out", {"x"}}}, {{"use_mkldnn", {true}}}); op_type, {{"X", {"x"}}, {"Y", {"x1"}}},
{{"Out", {"x"}}}, {{"use_mkldnn", {true}}})
: framework::OpRegistry::CreateOp(
op_type, {{"X", {"x"}}}, {{"Out", {"x"}}},
{{"use_mkldnn", {true}}});
op->Run(scope, place); op->Run(scope, place);
platform::DeviceContextPool::Instance().Get(place)->Wait(); platform::DeviceContextPool::Instance().Get(place)->Wait();
// Get in-place result // Get in-place result
auto &out_tensor = scope.FindVar("x")->Get<framework::LoDTensor>(); auto &out_tensor = scope.FindVar("x")->Get<framework::LoDTensor>();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
&out_tensor, x, &out_tensor, input_names[0].tensor,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Input and output vars should share tensor for In-place test")); "Input and output vars should share tensor for In-place test"));
...@@ -88,7 +123,13 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) { ...@@ -88,7 +123,13 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
TEST(test_softmax_inplace, cpu_place) { TEST(test_softmax_inplace, cpu_place) {
framework::DDim dims({32, 64}); framework::DDim dims({32, 64});
platform::CPUPlace p; platform::CPUPlace p;
ASSERT_TRUE(TestMain<float>(p, dims)); ASSERT_TRUE(TestMain<float>(p, "softmax", dims, 1));
}
TEST(test_elementwise_add_inplace, cpu_place) {
framework::DDim dims({1, 12, 20, 20});
platform::CPUPlace p;
ASSERT_TRUE(TestMain<float>(p, "elementwise_add", dims, 2));
} }
} // namespace operators } // namespace operators
......
...@@ -101,6 +101,11 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in, ...@@ -101,6 +101,11 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
} }
} }
struct mkldnn_dummy_primitive {
struct primitive_desc {};
struct desc {};
};
inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims, inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
mkldnn::memory::data_type data_type, mkldnn::memory::data_type data_type,
MKLDNNMemoryFormat format) { MKLDNNMemoryFormat format) {
......
...@@ -30,7 +30,8 @@ namespace platform { ...@@ -30,7 +30,8 @@ namespace platform {
using user_function = std::function<std::shared_ptr<float>(const float*)>; using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory; using memory = mkldnn::memory;
template <typename T, typename TForward, typename TBackward> template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive>
class MKLDNNHandlerT { class MKLDNNHandlerT {
public: public:
MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
...@@ -351,6 +352,35 @@ class MKLDNNHandler { ...@@ -351,6 +352,35 @@ class MKLDNNHandler {
std::string key_common_; std::string key_common_;
}; };
template <typename T>
class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
public:
BinaryMKLDNNHandler(const dnnl::algorithm algo,
const std::vector<int64_t>& dims,
const MKLDNNMemoryFormat src0_fmt,
const MKLDNNMemoryFormat src1_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
// TODO(jczaja): Add function checking if data already exists
auto src0_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src0_fmt);
auto src1_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src1_fmt);
auto dst_md =
memory::desc(dims, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
this->AcquireForwardPrimitiveDescriptor(algo, src0_md, src1_md, dst_md);
}
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
}
};
class SumMKLDNNHandler : public MKLDNNHandler { class SumMKLDNNHandler : public MKLDNNHandler {
public: public:
SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx, SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
...@@ -419,7 +449,7 @@ class ActivationMKLDNNHandler ...@@ -419,7 +449,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>( mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, unique_name)) { platform::CreateKey(dims, "a", algorithm, unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
...@@ -437,7 +467,7 @@ class ActivationMKLDNNHandler ...@@ -437,7 +467,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>( mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, unique_name)) { platform::CreateKey(dims, "a", algorithm, unique_name)) {
auto diff_dst_md = platform::MKLDNNMemDesc( auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt); dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = auto src_md =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册