Unverified commit c6c65c65 authored by Jacek Czaja, committed by GitHub

[DNNL] Added elementwise_add mkl-dnn inplace (#23477)

Parent 9ff558a4
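
In outline, the change has two halves: a graph pass (mkldnn_inplace_pass) that, when it is safe, renames an MKL-DNN op's output variable to its input variable, and kernels (elementwise_add, softmax) that detect the shared buffer and write the result in place. A comment-only sketch of the graph-level effect, with hypothetical variable names:

    // Before the pass:  x -> elementwise_add(use_mkldnn=true) -> y -> relu -> z
    // After the pass:   x -> elementwise_add(use_mkldnn=true) -> x -> relu -> z
    // The op's output var "y" is renamed to its input "x" and relu's input
    // list is patched; the kernel then writes into x's buffer instead of
    // allocating a separate y.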
@@ -86,7 +86,7 @@ endif()
 if(WITH_MKLDNN)
   pass_library(mkldnn_placement_pass base DEPS placement_pass_base DIR mkldnn)
-  pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry softmax_op softmax DIR mkldnn)
+  pass_library(mkldnn_inplace_pass inference DEPS mkldnn_placement_pass op_registry elementwise_add_op activation_op softmax_op softmax DIR mkldnn)
   pass_library(depthwise_conv_mkldnn_pass base DIR mkldnn)
   pass_library(conv_bias_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(conv_activation_mkldnn_fuse_pass inference DIR mkldnn)
@@ -1892,30 +1892,30 @@ PDNode *patterns::MultipleQuantize::operator()() {
 }
 
 PDNode *patterns::MKLDNNInPlace::operator()() {
-  // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
-  // batch_norm....
   auto possible_inplace_op =
-      pattern->NewNode(inplace_to_be_op_repr())->assert_is_ops({"softmax"});
+      pattern->NewNode(inplace_to_be_op_repr())
+          ->assert_is_ops({"elementwise_add", "softmax"});
 
-  // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
-  // batch_norm....
+  // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
   auto input = pattern->NewNode(inplace_to_be_op_in_repr())
-                   ->assert_is_ops_input({"softmax"})
+                   ->assert_is_ops_input({"elementwise_add", "softmax"})
                    ->AsInput();
 
-  // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, elementwise_add,
-  // batch_norm....
+  // TODO(jczaja): Enable more mkl-dnn ops e.g. activation, batch_norm....
   auto output = pattern->NewNode(inplace_to_be_op_out_repr())
-                    ->assert_is_ops_output({"softmax"})
-                    ->AsIntermediate();
+                    ->assert_is_ops_output({"elementwise_add", "softmax"})
+                    ->AsOutput();
 
   auto next_op = pattern->NewNode(next_op_repr())->assert_is_op();
+  auto next_output = pattern->NewNode(next_op_out_repr())->AsOutput();
 
   // Check if op is MKL-DNN enabled
   possible_inplace_op->assert_op_attr("use_mkldnn", true);
 
   // linked structure
   possible_inplace_op->LinksTo({output});
   possible_inplace_op->LinksFrom({input});
   next_op->LinksFrom({output});
+  next_op->LinksTo({next_output});
   return possible_inplace_op;
 }
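
The subgraph this pattern matches, and the detector loop that feeds each match to a handler, look roughly like this (a fragment mirroring the pass code further down; the ASCII chain is just orientation, not code from this PR):

    // inplace_to_be_op_in -> [elementwise_add|softmax, use_mkldnn=true]
    //                     -> inplace_to_be_op_out -> next_op -> next_op_out
    GraphPatternDetector gpd;
    patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
                                           "mkldnn_inplace"};
    mkldnn_inplace();  // instantiate the PDNodes declared above
    gpd(graph, [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
      // rewrite the matched nodes here (see mkldnn_inplace_pass.cc below)
    });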
@@ -1140,11 +1140,12 @@ struct MKLDNNInPlace : public PatternBase {
       : PatternBase(pattern, name_scope, "mkldnn_inplace") {}
 
   PDNode* operator()();
 
-  // MKL-DNN's in-place ops: BatchNorm, Softmax, Layer Norm
+  // MKL-DNN's in-place ops: BatchNorm, Softmax, Elementwise_add
   PATTERN_DECL_NODE(inplace_to_be_op);
   PATTERN_DECL_NODE(inplace_to_be_op_in);
   PATTERN_DECL_NODE(inplace_to_be_op_out);
   PATTERN_DECL_NODE(next_op);
+  PATTERN_DECL_NODE(next_op_out);
 };
 
 struct TransposeFlattenConcat : public PatternBase {
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <memory>
 #include <string>
+#include <unordered_map>
 #include <vector>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -30,6 +31,7 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
   PADDLE_ENFORCE_NOT_NULL(graph,
                           platform::errors::InvalidArgument(
                               "Pointer to graph argument should not be NULL."));
+  std::unordered_map<std::string, std::string> original_output_names;
   GraphPatternDetector gpd;
   patterns::MKLDNNInPlace mkldnn_inplace{gpd.mutable_pattern(),
                                          "mkldnn_inplace"};
@@ -40,72 +42,136 @@ void MKLDNNInPlacePass::ApplyImpl(ir::Graph* graph) const {
                      Graph* g) {
     VLOG(3) << "Start to handle MKL-DNN In-Place pass";
 
-    GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op, inplace_to_be_op,
-                              mkldnn_inplace);
-    GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_in, inplace_to_be_op_in,
-                              mkldnn_inplace);
-    GET_IR_NODE_FROM_SUBGRAPH(inplace_to_be_op_out, inplace_to_be_op_out,
-                              mkldnn_inplace);
+    GET_IR_NODE_FROM_SUBGRAPH(current_op, inplace_to_be_op, mkldnn_inplace);
+    GET_IR_NODE_FROM_SUBGRAPH(current_op_in, inplace_to_be_op_in,
+                              mkldnn_inplace);
+    GET_IR_NODE_FROM_SUBGRAPH(current_op_out, inplace_to_be_op_out,
+                              mkldnn_inplace);
     GET_IR_NODE_FROM_SUBGRAPH(next_op, next_op, mkldnn_inplace);
+    GET_IR_NODE_FROM_SUBGRAPH(next_op_out, next_op_out, mkldnn_inplace);
 
-    if ((inplace_to_be_op->Op()->HasAttr("use_mkldnn") == false) ||
-        (boost::get<bool>(inplace_to_be_op->Op()->GetAttr("use_mkldnn")) ==
-         false)) {
+    if ((current_op->Op()->HasAttr("use_mkldnn") == false) ||
+        (boost::get<bool>(current_op->Op()->GetAttr("use_mkldnn")) == false)) {
       VLOG(3) << "do not perform mkl-dnn inplace: use_mkldnn missing or set to "
                  "false";
       return;
     }
 
-    auto& infer_inplace = OpInfoMap::Instance()
-                              .Get(inplace_to_be_op->Op()->Type())
-                              .infer_inplace_;
+    auto& infer_inplace =
+        OpInfoMap::Instance().Get(current_op->Op()->Type()).infer_inplace_;
     if (!infer_inplace) {
       VLOG(3) << "do not perform mkl-dnn inplace: missing InplaceInferer";
       return;
     }
 
-    // TODO(jczaja): Enable more ops
-    if (inplace_to_be_op->Op()->Type() != "softmax") {
-      VLOG(3)
-          << "Curently works for softmax only. TODO(jczaja): support other ops";
-      return;
-    }
+    VLOG(3) << "DNNL Inplace op(" << current_op->id() << ") "
+            << "Curr Node In: " << current_op_in->Name()
+            << " Curr Node out: " << current_op_out->Name();
+    VLOG(3) << "DNNL Inplace next op(" << next_op->id() << ") "
+            << " next Node out: " << next_op_out->Name();
+
+    auto inputs = current_op->Op()->Inputs();
+    auto outputs = current_op->Op()->Outputs();
+    auto in_to_outs = infer_inplace(false);  // strictly no CUDA for MKL-DNN
+    VLOG(3) << "DNNL InplaceInferer op(" << current_op->id() << ") "
+            << in_to_outs.begin()->first << ": "
+            << inputs[in_to_outs.begin()->first][0] << " "
+            << in_to_outs.begin()->second << ": "
+            << outputs[in_to_outs.begin()->second][0];
+
+    // If InferInplace pattern does not contain input node then skip
+    auto inplace_input_vec = inputs[in_to_outs.begin()->first];
+    if (std::find(inplace_input_vec.begin(), inplace_input_vec.end(),
+                  current_op_in->Name()) == inplace_input_vec.end()) {
+      VLOG(3) << "DNNL in-place pass SKIP pattern ";
+      return;
+    }
 
-    // Iterate over all nodes that are ops
-    // and check if in-place to be var is part of inputs
-    // if positive then do not perform inplace
-    for (const Node* n : graph->Nodes()) {
-      if (n->IsOp()) {
-        // Avoid searching in op that is to be inplace
-        if ((n->id() != inplace_to_be_op->id())) {
-          auto* op = n->Op();
-          auto inputs = op->Inputs();
-          auto in_place_input = inplace_to_be_op_in->Name();
-          for (auto& it : inputs) {
-            for (auto& var_name : it.second) {
-              if (var_name == in_place_input) {
-                VLOG(3) << "MKL-DNN in-place pass: in-place var cannot be an "
-                           "input to more than one operator";
-                return;
-              }
-            }
-          }
-        }
-      }
-    }
+    // Checking if this particular node (to be inplaced, overwritten)
+    // is used anywhere else apart from inplaced op
+    auto input_consumers = current_op_in->outputs;
+    if (input_consumers.size() > 1) {
+      VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
+                 "be an input to multiple operators";
+      return;
+    }
+
+    // If this op was already inplaced in previous pass placements
+    // then we need to update input of next op,
+    // but the original name to be changed is gone, so we need to remember it
+    // the first time a given op is to be inplaced
+    if (current_op_in->Name() != current_op_out->Name()) {
+      original_output_names[current_op->Name() + current_op_in->Name()] =
+          current_op_out->Name();
+    } else {
+      VLOG(3) << "DNNL Inplace: Current op already inplaced! ";
+    }
+
+    // It may be that next op is reusing some of vars, we need to
+    // make sure that unwanted inplace is not created
+    // TODO(jczaja): Make UT for that one
+    for (auto& n : current_op_out->outputs) {
+      auto& n_op_infer_inplace =
+          OpInfoMap::Instance().Get(n->Op()->Type()).infer_inplace_;
+      if ((n_op_infer_inplace == nullptr)) {
+        for (auto& m : n->outputs) {
+          if (m->Name() == current_op_in->Name()) {
+            VLOG(3) << "DNNL in-place pass FAIL: in-place var cannot "
+                       "be an output to non-inplaced next op";
+            return;
+          }
+        }
+      }
+    }
 
-    auto original_name = inplace_to_be_op_out->Name();
-    inplace_to_be_op_out->RenameVar(inplace_to_be_op_in->Name());
+    auto original_name =
+        original_output_names[current_op->Name() + current_op_in->Name()];
+    current_op_out->RenameVar(current_op_in->Name());
 
-    // Get mapping of input to output
-    auto in_to_outs = infer_inplace(false);  // strictly no CUDA for MKL-DNN
     // TODO(jczaja): Support more complex situations
     auto out_name = in_to_outs.begin()->second;
-    inplace_to_be_op->Op()->SetOutput(
-        out_name, std::vector<std::string>({inplace_to_be_op_out->Name()}));
-    next_op->Op()->RenameInput(original_name, inplace_to_be_op_out->Name());
+    current_op->Op()->SetOutput(
+        out_name, std::vector<std::string>({current_op_out->Name()}));
+
+    // If next op in a line is doing inplace
+    // then we need to update its output as well
+    // Get inferer of next op
+    // If no inferer then we are done
+    auto& next_op_infer_inplace =
+        OpInfoMap::Instance().Get(next_op->Op()->Type()).infer_inplace_;
+    if (next_op_infer_inplace) {
+      auto in_to_outs = next_op_infer_inplace(false);
+      auto out_name = in_to_outs.begin()->second;
+      auto* op = next_op->Op();
+      auto inputs = op->Inputs();
+      auto outputs = op->Outputs();
+      // Check if in-place happened
+      // for variable we changed (original name)
+      // TODO(jczaja): make recursive propagation of inplace
+      auto next_op_inplace_inputs = inputs[in_to_outs.begin()->first];
+      if ((next_op_inplace_inputs == outputs[in_to_outs.begin()->second]) &&
+          (std::find(next_op_inplace_inputs.begin(),
+                     next_op_inplace_inputs.end(),
+                     original_name) != next_op_inplace_inputs.end())) {
+        VLOG(3) << "DNNL InPlace: Next Op is in-placed, updating its "
+                   "input and output var!";
+        next_op->Op()->SetOutput(
+            out_name, std::vector<std::string>({current_op_out->Name()}));
+        next_op_out->RenameVar(current_op_in->Name());
+        // Get ops that next_op_out is linked to and update their input
+        auto next_op_out_consumers = next_op_out->outputs;  // Has to be ops
+        for (auto& c : next_op_out_consumers) {
+          c->Op()->RenameInput(original_name, current_op_out->Name());
+        }
+      }
+    }
+
+    next_op->Op()->RenameInput(original_name, current_op_out->Name());
 
     found_inplace_count++;
-    VLOG(3) << "MKL-DNN InPlace applied!";
+    VLOG(3) << "DNNL InPlace applied!";
   };
 
   gpd(graph, handler);
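
Stripped of the safety checks, the rewrite performed per match boils down to three calls on the IR. A condensed sketch (MakeInplace is a hypothetical helper for illustration, not part of this PR; the individual calls are the ones the handler above actually makes):

    // Illustrative-only condensation: rename the op's output variable to its
    // input variable, then patch the producing op's output list and the
    // consuming op's input list to use the new name.
    void MakeInplace(ir::Node* op, ir::Node* in, ir::Node* out,
                     ir::Node* next_op, const std::string& out_arg /* "Out" */) {
      const std::string original_name = out->Name();
      out->RenameVar(in->Name());  // out var node now carries the input's name
      op->Op()->SetOutput(out_arg, {out->Name()});
      next_op->Op()->RenameInput(original_name, out->Name());
    }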
@@ -21,6 +21,9 @@
 USE_OP(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
+USE_OP(elementwise_add);
+USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
+USE_OP(relu);
 
 namespace paddle {
 namespace framework {
@@ -62,8 +65,9 @@ class MKLDNNInplacePassTest {
                 bool branched) {
     ProgramDesc prog;
-    for (auto& v : std::vector<std::string>(
-             {"a", "weights", "bias", "f", "g", "h", "i", "j", "k"})) {
+    for (auto& v :
+         std::vector<std::string>({"a", "weights", "bias", "f", "g", "h", "i",
+                                   "j", "k", "l", "m", "z"})) {
       auto* var = prog.MutableBlock(0)->Var(v);
       var->SetType(proto::VarType::SELECTED_ROWS);
       if (v == "weights" || v == "bias") {
@@ -83,9 +87,12 @@ class MKLDNNInplacePassTest {
     SetOp(&prog, "elementwise_add", "elementwise_add1",
           std::vector<std::string>({"h", "i"}), std::vector<std::string>({"j"}),
           mkldnn_enabled_op.compare("elementwise_add") == 0);
+    SetOp(&prog, "relu", "relu2", std::vector<std::string>({"j"}),
+          std::vector<std::string>({"k"}),
+          mkldnn_enabled_op.compare("softmax") == 0);
     if (branched == true) {
       SetOp(&prog, "softmax", "softmax2", std::vector<std::string>({"g"}),
-            std::vector<std::string>({"k"}),
+            std::vector<std::string>({"z"}),
             mkldnn_enabled_op.compare("softmax") == 0);
     }
@@ -105,12 +112,11 @@ class MKLDNNInplacePassTest {
     unsigned use_mkldnn_true_count = 0;
     std::unordered_map<std::string, std::string> input_names;
     std::unordered_map<std::string, std::string> output_names;
     input_names["softmax"] = "X";
     output_names["softmax"] = "Out";
-    input_names["batch_norm"] = "X";
-    output_names["batch_norm"] = "Y";
-    input_names["layer_norm"] = "X";
-    output_names["layer_norm"] = "Y";
+    input_names["elementwise_add"] = "X";
+    output_names["elementwise_add"] = "Out";
 
     VLOG(3) << DebugString(graph);
@@ -135,15 +141,18 @@ class MKLDNNInplacePassTest {
 TEST(MKLDNNInplacePass, inplace_softmax) {
   // softmax to be mkl-dnn enabled and made in-place
   MKLDNNInplacePassTest().MainTest("softmax", false, 1);
 }
 
 TEST(MKLDNNInplacePass, inplace_softmax_branched) {
-  // softmax to be mkl-dnn enabled and made in-place
+  // softmax's input is shared by two branches, so no in-place
   MKLDNNInplacePassTest().MainTest("softmax", true, 0);
 }
 
+TEST(MKLDNNInplacePass, inplace_elementwise_add) {
+  // Two elementwise_add mkl-dnn enabled op instances to be made inplace
+  MKLDNNInplacePassTest().MainTest("elementwise_add", false, 1);
+}
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
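
From the visible SetOp calls, the tail of the generated test program is the chain below (a sketch; the ops ahead of elementwise_add1 lie outside the shown hunks). MainTest(op, branched, n) then runs the pass and appears to check that exactly n ops were rewritten to in-place form:

    // h, i -> elementwise_add1 -> j -> relu2 -> k
    // branched == true additionally wires:  g -> softmax2 -> z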
@@ -56,39 +56,34 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
         y->format(), MKLDNNMemoryFormat::undef,
         platform::errors::InvalidArgument("Wrong format set for Y tensor"));
 
-    const T* x_data = x->data<T>();
-    const T* y_data = y->data<T>();
 
     auto src_x_tz = framework::vectorize<int64_t>(x->dims());
     auto src_y_tz = framework::vectorize<int64_t>(y->dims());
-    auto dst_tz = framework::vectorize<int64_t>(z->dims());
-    std::vector<float> scales = {1.0f, 1.0f};
 
-    const std::string key =
-        platform::CreateKey(src_x_tz, ctx.OutputName("Out"));
-    platform::SumMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
+    // Currently MKL-DNN kernel supports only Z <- X + Y, shape(X) == shape(Y)
+    // TODO(jczaja): Binary primitive support broadcasting, so we can support
+    // this in kernel
+    platform::BinaryMKLDNNHandler<T> handler(
+        dnnl::algorithm::binary_add, src_x_tz, x->format(), y->format(),
+        dev_ctx, ctx.GetPlace(), ctx.OutputName("Out"));
 
-    auto src_x_memory = handler.AcquireSrcMemory(
-        {{src_x_tz}, platform::MKLDNNGetDataType<T>(), x->format()},
-        paddle::platform::to_void_cast(x_data));
-    auto src_y_memory = handler.AcquireSecondSrcMemory(
-        {{src_y_tz}, platform::MKLDNNGetDataType<T>(), y->format()},
-        paddle::platform::to_void_cast(y_data));
+    auto src_x_memory = handler.AcquireSrcMemory(x);
+    auto src_y_memory = handler.AcquireSecondSrcMemory(y);
 
-    auto dst_md = memory::desc({dst_tz}, platform::MKLDNNGetDataType<T>(),
-                               MKLDNNMemoryFormat::any);
-    auto sum_pd = handler.AcquireSumPrimitiveDescriptor(
-        {src_x_memory, src_y_memory}, scales, dst_md);
-    T* z_data =
-        z->mutable_data<T>(ctx.GetPlace(), sum_pd->dst_desc().get_size());
-    auto dst_memory = handler.AcquireDstMemoryFromPrimitive(z_data);
-    auto sum_prim = handler.AcquireSum();
+    // For Inplace src and dst are the same memory object
+    auto dst_memory =
+        x->IsSharedBufferWith(*z) ? src_x_memory : handler.AcquireDstMemory(z);
+    auto binary_prim = handler.AcquireForwardPrimitive();
 
     mkldnn::stream astream(mkldnn_engine);
-    sum_prim->execute(astream, {{MKLDNN_ARG_MULTIPLE_SRC, *src_x_memory},
-                                {MKLDNN_ARG_MULTIPLE_SRC + 1, *src_y_memory},
-                                {MKLDNN_ARG_DST, *dst_memory}});
+    std::unordered_map<int, dnnl::memory> args = {
+        {DNNL_ARG_SRC_0, *src_x_memory},
+        {DNNL_ARG_SRC_1, *src_y_memory},
+        {DNNL_ARG_DST, *dst_memory}};
+    binary_prim->execute(astream, args);
     astream.wait();
 
     z->set_layout(DataLayout::kMKLDNN);
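
For orientation, here is a minimal standalone sketch of an in-place binary_add against the oneDNN 1.x API this kernel targets. The dims and variable names are illustrative only; the in-place trick is simply passing the SRC_0 memory object as DST:

    // Minimal in-place binary_add sketch (assumes oneDNN 1.x headers/linking).
    #include <vector>
    #include "dnnl.hpp"

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream s(eng);

      dnnl::memory::dims dims = {1, 12, 20, 20};
      auto md = dnnl::memory::desc(dims, dnnl::memory::data_type::f32,
                                   dnnl::memory::format_tag::nchw);

      std::vector<float> x(1 * 12 * 20 * 20, 1.0f), y(x.size(), 2.0f);
      dnnl::memory src0(md, eng, x.data());  // doubles as DST below
      dnnl::memory src1(md, eng, y.data());

      auto d = dnnl::binary::desc(dnnl::algorithm::binary_add, md, md, md);
      auto pd = dnnl::binary::primitive_desc(d, eng);
      dnnl::binary(pd).execute(s, {{DNNL_ARG_SRC_0, src0},
                                   {DNNL_ARG_SRC_1, src1},
                                   {DNNL_ARG_DST, src0}});  // DST == SRC_0
      s.wait();  // x now holds x + y
      return 0;
    }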
-cc_test(test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry softmax_op softmax scope device_context enforce executor)
+cc_test(test_mkldnn_op_inplace SRCS mkldnn/test_mkldnn_op_inplace.cc DEPS op_registry elementwise_add_op softmax_op softmax scope device_context enforce executor)
@@ -45,7 +45,8 @@ class SoftmaxMKLDNNHandler
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            // Softmax may be inplace then uniq_name is no longer unique
+            platform::CreateKey(dims, axis, uniq_name)) {
     auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
 
     this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring, md,
@@ -60,7 +61,7 @@ class SoftmaxMKLDNNHandler
       : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
                                  mkldnn::softmax_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, uniq_name)) {
+            platform::CreateKey(dims, axis, uniq_name)) {
     auto data_softmax_md =
         mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
     auto diff_softmax_md =
@@ -95,13 +96,13 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     auto softmax_src_memory_p = handler.AcquireSrcMemory(input);
     auto softmax_p = handler.AcquireForwardPrimitive();
 
     // For Inplace src and dst are the same memory object
-    auto softmax_dst_memory_p = input->Holder() == output->Holder()
+    auto softmax_dst_memory_p = input->IsSharedBufferWith(*output)
                                     ? softmax_src_memory_p
                                     : handler.AcquireDstMemory(output);
 
     mkldnn::stream astream(dev_ctx.GetEngine());
-    softmax_p->execute(astream, {{MKLDNN_ARG_SRC, *softmax_src_memory_p},
-                                 {MKLDNN_ARG_DST, *softmax_dst_memory_p}});
+    softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
+                                 {DNNL_ARG_DST, *softmax_dst_memory_p}});
     astream.wait();
 
     const bool is_test = ctx.Attr<bool>("is_test");
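
Tensor::IsSharedBufferWith replaces the raw Holder() comparison and makes the intent explicit; conceptually the check reduces to allocation identity, roughly like the following sketch (SharesBuffer is a hypothetical free function, assuming the Tensor API of this branch):

    // Illustrative only: two tensors are "shared" when they point at the same
    // underlying allocation, which is exactly the state the in-place pass
    // creates by renaming the output var to the input var.
    bool SharesBuffer(const framework::Tensor& a, const framework::Tensor& b) {
      return a.Holder() != nullptr && a.Holder() == b.Holder();
    }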
@@ -27,38 +27,68 @@
 USE_OP(softmax);
 USE_OP_DEVICE_KERNEL(softmax, MKLDNN);
+USE_OP(elementwise_add);
+USE_OP_DEVICE_KERNEL(elementwise_add, MKLDNN);
 
 namespace paddle {
 namespace operators {
 
+struct InputVars {
+  std::string name;
+  framework::LoDTensor *tensor;
+};
+
 template <typename T>
-bool TestMain(const platform::Place &place, const framework::DDim &dims) {
+bool TestMain(const platform::Place &place, const std::string &op_type,
+              const framework::DDim &dims, const int num_inputs) {
   framework::Scope scope;
 
-  auto *x = scope.Var("x")->GetMutable<framework::LoDTensor>();
-  auto *y = scope.Var("y")->GetMutable<framework::LoDTensor>();
-  x->Resize(dims);
-  y->Resize(dims);
-  size_t numel = static_cast<size_t>(framework::product(dims));
-  auto x_ptr = x->mutable_data<T>(place);
-  auto y_ptr = y->mutable_data<T>(place);
+  std::vector<InputVars> input_names = {
+      {"x", scope.Var("x")->GetMutable<framework::LoDTensor>()},
+      {"x1", num_inputs > 1
+                 ? scope.Var("x1")->GetMutable<framework::LoDTensor>()
+                 : nullptr},
+      {"x2", num_inputs > 2
+                 ? scope.Var("x2")->GetMutable<framework::LoDTensor>()
+                 : nullptr},
+      {"x3", num_inputs > 3
+                 ? scope.Var("x3")->GetMutable<framework::LoDTensor>()
+                 : nullptr},
+      {"x4", num_inputs > 4
+                 ? scope.Var("x4")->GetMutable<framework::LoDTensor>()
+                 : nullptr}};
+  auto *y = scope.Var("y")->GetMutable<framework::LoDTensor>();
 
   // Initialize input data
   std::uniform_real_distribution<T> dist(static_cast<T>(10.0),
                                          static_cast<T>(20.0));
   std::mt19937 engine;
+  size_t numel = static_cast<size_t>(framework::product(dims));
+  for (int i = 0; i < num_inputs; ++i) {
+    input_names[i].tensor->Resize(dims);
+    auto data_ptr = input_names[i].tensor->mutable_data<T>(place);
+    for (size_t i = 0; i < numel; ++i) {
+      data_ptr[i] = dist(engine);
+    }
+  }
+
+  // Initialize output
+  y->Resize(dims);
+  auto y_ptr = y->mutable_data<T>(place);
   for (size_t i = 0; i < numel; ++i) {
-    x_ptr[i] = dist(engine);
     y_ptr[i] = static_cast<T>(0);
   }
 
   auto &pool = platform::DeviceContextPool::Instance();
 
   // Out of place (reference) computation
-  auto op_ref = framework::OpRegistry::CreateOp(
-      "softmax", {{"X", {"x"}}}, {{"Out", {"y"}}}, {{"use_mkldnn", {true}}});
+  auto op_ref = num_inputs > 1 ? framework::OpRegistry::CreateOp(
+                                     op_type, {{"X", {"x"}}, {"Y", {"x1"}}},
+                                     {{"Out", {"y"}}}, {{"use_mkldnn", {true}}})
+                               : framework::OpRegistry::CreateOp(
+                                     op_type, {{"X", {"x"}}}, {{"Out", {"y"}}},
+                                     {{"use_mkldnn", {true}}});
   op_ref->Run(scope, place);
   pool.Get(place)->Wait();
@@ -66,15 +96,20 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
   auto &ref_tensor = scope.FindVar("y")->Get<framework::LoDTensor>();
 
   // In-place (to be tested) computation
-  auto op = framework::OpRegistry::CreateOp(
-      "softmax", {{"X", {"x"}}}, {{"Out", {"x"}}}, {{"use_mkldnn", {true}}});
+  auto op = num_inputs > 1 ? framework::OpRegistry::CreateOp(
+                                 op_type, {{"X", {"x"}}, {"Y", {"x1"}}},
+                                 {{"Out", {"x"}}}, {{"use_mkldnn", {true}}})
+                           : framework::OpRegistry::CreateOp(
+                                 op_type, {{"X", {"x"}}}, {{"Out", {"x"}}},
+                                 {{"use_mkldnn", {true}}});
   op->Run(scope, place);
   platform::DeviceContextPool::Instance().Get(place)->Wait();
 
   // Get in-place result
   auto &out_tensor = scope.FindVar("x")->Get<framework::LoDTensor>();
   PADDLE_ENFORCE_EQ(
-      &out_tensor, x,
+      &out_tensor, input_names[0].tensor,
       platform::errors::InvalidArgument(
           "Input and output vars should share tensor for In-place test"));
@@ -88,7 +123,13 @@ bool TestMain(const platform::Place &place, const framework::DDim &dims) {
 TEST(test_softmax_inplace, cpu_place) {
   framework::DDim dims({32, 64});
   platform::CPUPlace p;
-  ASSERT_TRUE(TestMain<float>(p, dims));
+  ASSERT_TRUE(TestMain<float>(p, "softmax", dims, 1));
 }
 
+TEST(test_elementwise_add_inplace, cpu_place) {
+  framework::DDim dims({1, 12, 20, 20});
+  platform::CPUPlace p;
+  ASSERT_TRUE(TestMain<float>(p, "elementwise_add", dims, 2));
+}
 }  // namespace operators
@@ -101,6 +101,11 @@ inline void MatchShapeToLayout(framework::Tensor* tensor_in,
   }
 }
 
+struct mkldnn_dummy_primitive {
+  struct primitive_desc {};
+  struct desc {};
+};
+
 inline mkldnn::memory::desc MKLDNNMemDesc(const std::vector<int64_t>& dims,
                                           mkldnn::memory::data_type data_type,
                                           MKLDNNMemoryFormat format) {
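
mkldnn_dummy_primitive only supplies empty desc/primitive_desc member types so that MKLDNNHandlerT (next hunk) can default its TBackward parameter for forward-only primitives such as dnnl::binary, which has no backward counterpart. A sketch of the resulting pattern, assuming this header's conventions (MyForwardOnlyHandler is a hypothetical name):

    // TBackward silently defaults to mkldnn_dummy_primitive, whose nested
    // desc/primitive_desc types are never instantiated at runtime.
    template <typename T>
    class MyForwardOnlyHandler : public MKLDNNHandlerT<T, dnnl::binary> {
      // no backward primitive type spelled out
    };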
@@ -30,7 +30,8 @@ namespace platform {
 using user_function = std::function<std::shared_ptr<float>(const float*)>;
 using memory = mkldnn::memory;
 
-template <typename T, typename TForward, typename TBackward>
+template <typename T, typename TForward,
+          typename TBackward = mkldnn_dummy_primitive>
 class MKLDNNHandlerT {
  public:
   MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
@@ -351,6 +352,35 @@ class MKLDNNHandler {
   std::string key_common_;
 };
 
+template <typename T>
+class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
+ public:
+  BinaryMKLDNNHandler(const dnnl::algorithm algo,
+                      const std::vector<int64_t>& dims,
+                      const MKLDNNMemoryFormat src0_fmt,
+                      const MKLDNNMemoryFormat src1_fmt,
+                      const platform::MKLDNNDeviceContext& dev_ctx,
+                      platform::Place cpu_place, const std::string& uniq_name)
+      : platform::MKLDNNHandlerT<T, dnnl::binary>(
+            dev_ctx, dev_ctx.GetEngine(), cpu_place,
+            platform::CreateKey(dims, uniq_name)) {
+    // TODO(jczaja): Add function checking if data already exists
+    auto src0_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src0_fmt);
+    auto src1_md = dnnl::memory::desc(dims, MKLDNNGetDataType<T>(), src1_fmt);
+    auto dst_md =
+        memory::desc(dims, MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any);
+
+    this->AcquireForwardPrimitiveDescriptor(algo, src0_md, src1_md, dst_md);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    return this->AcquireMemoryFromPrimitive(
+        this->fwd_pd_->src_desc(), to_void_cast<T>(input_data), "@src1_mem_p");
+  }
+};
+
 class SumMKLDNNHandler : public MKLDNNHandler {
  public:
   SumMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
@@ -419,7 +449,7 @@ class ActivationMKLDNNHandler
       : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                  mkldnn::eltwise_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, unique_name)) {
+            platform::CreateKey(dims, "a", algorithm, unique_name)) {
     auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
 
     this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
@@ -437,7 +467,7 @@ class ActivationMKLDNNHandler
       : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
                                  mkldnn::eltwise_backward>(
             dev_ctx, dev_ctx.GetEngine(), cpu_place,
-            platform::CreateKey(dims, unique_name)) {
+            platform::CreateKey(dims, "a", algorithm, unique_name)) {
     auto diff_dst_md = platform::MKLDNNMemDesc(
         dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
     auto src_md =