提交 792d3b24 编写于 作者: M mozga-intel

MKLDNN layout: Support for activation operator

上级 d7345959
...@@ -12,16 +12,20 @@ ...@@ -12,16 +12,20 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "mkldnn.hpp"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using paddle::framework::Tensor; using framework::DataLayout;
using paddle::platform::MKLDNNDeviceContext; using framework::Tensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::MKLDNNDeviceContext;
using platform::to_void_cast;
namespace { namespace {
std::string gethash(const mkldnn::memory::dims &operand_dims, std::string gethash(const mkldnn::memory::dims &operand_dims,
...@@ -35,188 +39,260 @@ std::string gethash(const mkldnn::memory::dims &operand_dims, ...@@ -35,188 +39,260 @@ std::string gethash(const mkldnn::memory::dims &operand_dims,
}; };
return dim2str(operand_dims) + std::to_string(algorithm); return dim2str(operand_dims) + std::to_string(algorithm);
} }
} // namespace
template <typename Functor>
class MKLDNNActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *x = ctx.Input<Tensor>("X");
PADDLE_ENFORCE(x->layout() == DataLayout::kMKLDNN &&
x->format() != memory::format::format_undef,
"Wrong layout/format set for Input x tensor");
Functor functor;
auto attrs = functor.GetAttrs();
for (auto &attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(ctx);
}
};
template <typename T, typename ExecContext> template <typename Functor>
void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, class MKLDNNActivationGradKernel
const T alpha = 0, const T beta = 0) { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
PADDLE_ENFORCE(diff_y->layout() == DataLayout::kMKLDNN &&
diff_y->format() != memory::format::format_undef,
"Wrong layout/format set for Input OutGrad tensor");
Functor functor;
auto attrs = functor.GetAttrs();
for (auto &attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(ctx);
}
};
template <typename T>
void eltwise_forward(const framework::ExecutionContext &ctx,
mkldnn::algorithm algorithm, const T alpha = 0,
const T beta = 0) {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()), PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace."); "It must use CPUPlace.");
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine(); const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers const auto *x = ctx.Input<Tensor>("X");
const auto *src = ctx.template Input<Tensor>("X"); auto *y = ctx.Output<Tensor>("Out");
const auto *src_data = src->template data<T>();
auto *dst = ctx.template Output<Tensor>("Out"); const T *x_data = x->data<T>();
T *dst_data = dst->template mutable_data<T>(ctx.GetPlace()); T *y_data = y->mutable_data<T>(ctx.GetPlace());
// get memory dim PADDLE_ENFORCE(x->dims().size() == 2 || x->dims().size() == 4,
PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
"Input dim must be with 2 or 4"); "Input dim must be with 2 or 4");
std::vector<int> src_tz = framework::vectorize2int(src->dims());
std::vector<int> src_tz = framework::vectorize2int(x->dims());
auto src_format =
src_tz.size() == 2 ? mkldnn::memory::format::nc : x->format();
const std::string key = gethash(src_tz, algorithm); const std::string key = gethash(src_tz, algorithm);
const std::string key_src_data = const std::string key_src_data =
key + ctx.op().Output("Out") + "@eltwise_fwd_src_data"; key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
const std::string key_src_mem = key + "@eltwise_fwd_src_mem"; const std::string key_src_layout =
const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem"; key + ctx.op().Output("Out") + "@eltwise_fwd_src_layout";
const std::string key_fwd = key + "@eltwise_fwd"; const std::string key_with_layout = key + std::to_string(src_format);
const std::string key_src_mem = key_with_layout + "@eltwise_fwd_src_mem";
const std::string key_dst_mem = key_with_layout + "@eltwise_fwd_dst_mem";
const std::string key_fwd = key_with_layout + "@eltwise_fwd";
const std::string key_fwd_pd = key_with_layout + "@eltwise_fwd_pd";
// save input data and layout to be referred in backward path
auto p_src_data = std::make_shared<const T *>(x_data);
dev_ctx.SetBlob(key_src_data, p_src_data);
auto p_src_layout = std::make_shared<memory::format>(src_format);
dev_ctx.SetBlob(key_src_layout, p_src_layout);
auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>( auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
dev_ctx.GetBlob(key_fwd)); dev_ctx.GetBlob(key_fwd));
// save input data to be referred in backward path std::shared_ptr<memory> dst_memory;
auto p_src_data = std::make_shared<const T *>(src_data);
dev_ctx.SetBlob(key_src_data, p_src_data);
if (p_fwd == nullptr) { if (p_fwd == nullptr) {
// create memory description // create mkldnn memory for input X
auto data_md = src_tz.size() == 2 auto src_md = platform::MKLDNNMemDesc(
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, src_tz, platform::MKLDNNGetDataType<T>(), src_format);
mkldnn::memory::format::nc) auto src_memory = std::shared_ptr<memory>(
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, new memory({src_md, mkldnn_engine}, to_void_cast(x_data)));
mkldnn::memory::format::nchw); // save src_memory to be referred in backward path
dev_ctx.SetBlob(key_src_mem, src_memory);
// create memory primitives
auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory( // create primitive descriptor for activation forward and save it
{data_md, mkldnn_engine}, platform::to_void_cast(src_data))); auto forward_desc = mkldnn::eltwise_forward::desc(
dev_ctx.SetBlob(key_src_mem, p_src_mem); mkldnn::prop_kind::forward_training, algorithm,
src_memory->get_primitive_desc().desc(), alpha, beta);
auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory( auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
{data_md, mkldnn_engine}, platform::to_void_cast(dst_data))); forward_desc, mkldnn_engine);
dev_ctx.SetBlob(key_dst_mem, p_dst_mem);
// save prim desc into global device context to be referred in backward path
auto fwd_desc = mkldnn::eltwise_forward::desc( dev_ctx.SetBlob(key_fwd_pd, forward_pd);
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>( // create mkldnn memory for output y
fwd_desc, mkldnn_engine); dst_memory =
const std::string key_fwd_pd = key + "eltwise_fwd_pd"; std::make_shared<memory>(forward_pd->dst_primitive_desc(), y_data);
dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
p_fwd = std::make_shared<mkldnn::eltwise_forward>( dev_ctx.SetBlob(key_dst_mem, dst_memory);
*p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
// create activation primitive
p_fwd = std::make_shared<mkldnn::eltwise_forward>(*forward_pd, *src_memory,
*dst_memory);
dev_ctx.SetBlob(key_fwd, p_fwd); dev_ctx.SetBlob(key_fwd, p_fwd);
} else { } else {
// primitives already exist // primitives already exist
auto p_src_mem = auto src_memory =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
PADDLE_ENFORCE(p_src_mem != nullptr, PADDLE_ENFORCE(src_memory != nullptr,
"Fail to find eltwise p_src_mem in device context."); "Fail to find eltwise src_memory in device context.");
auto p_dst_mem = dst_memory =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
PADDLE_ENFORCE(p_dst_mem != nullptr, PADDLE_ENFORCE(dst_memory != nullptr,
"Fail to find eltwise p_src_mem in device context."); "Fail to find eltwise dst_memory in device context.");
p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data)); src_memory->set_data_handle(platform::to_void_cast(x_data));
p_dst_mem->set_data_handle(dst_data); dst_memory->set_data_handle(y_data);
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())}; std::vector<primitive> pipeline;
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); pipeline.push_back(*p_fwd);
stream(stream::kind::eager).submit(pipeline).wait();
y->set_layout(DataLayout::kMKLDNN);
y->set_format(GetMKLDNNFormat(*dst_memory));
} }
template <typename T, typename ExecContext> template <typename T>
void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, void eltwise_grad(const framework::ExecutionContext &ctx,
const T alpha = 0, const T beta = 0) { mkldnn::algorithm algorithm, const T alpha = 0,
const T beta = 0) {
auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto &mkldnn_engine = dev_ctx.GetEngine(); const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers const auto *diff_y = ctx.Input<Tensor>(framework::GradVarName("Out"));
const auto *out = ctx.template Input<Tensor>("Out"); auto *diff_x = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto *diff_dst = dout->template data<T>();
auto *dx = const T *diff_y_data = diff_y->data<T>();
ctx.template Output<framework::Tensor>(framework::GradVarName("X")); T *diff_x_data = diff_x->mutable_data<T>(ctx.GetPlace());
const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
// get memory dim std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
std::vector<int> src_tz = framework::vectorize2int(out->dims());
const std::string key = gethash(src_tz, algorithm); auto diff_y_format =
const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem"; diff_dst_tz.size() == 2 ? mkldnn::memory::format::nc : diff_y->format();
const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
const std::string key_grad = key + "@eltwise_grad";
const std::string key = gethash(diff_dst_tz, algorithm);
const std::string key_src_data = const std::string key_src_data =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data"; key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
const std::string key_src_layout =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_layout";
const auto p_src_layout =
std::static_pointer_cast<memory::format>(dev_ctx.GetBlob(key_src_layout));
const std::string key_src_mem =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_src_mem";
const std::string key_fwd_pd =
key + std::to_string(*p_src_layout) + "@eltwise_fwd_pd";
const std::string key_with_layouts =
key + std::to_string(*p_src_layout) + "-" + std::to_string(diff_y_format);
const std::string key_diff_src_mem =
key_with_layouts + "@eltwise_diff_src_mem";
const std::string key_diff_dst_mem =
key_with_layouts + "@eltwise_diff_dst_mem";
const std::string key_grad = key_with_layouts + "@eltwise_grad";
const auto p_src_data = const auto p_src_data =
std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data)); std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
const std::string key_src_mem = key + "@eltwise_fwd_src_mem"; auto src_memory =
auto p_src_mem =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem)); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
p_src_mem->set_data_handle(*p_src_data.get()); PADDLE_ENFORCE(src_memory != nullptr,
"Fail to find src_memory in device context");
src_memory->set_data_handle(*p_src_data.get());
std::shared_ptr<memory> diff_src_memory;
auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>( auto p_grad = std::static_pointer_cast<mkldnn::eltwise_backward>(
dev_ctx.GetBlob(key_grad)); dev_ctx.GetBlob(key_grad));
if (p_grad == nullptr) { if (p_grad == nullptr) {
// create memory description // create mkldnn memory for input diff_y
auto data_md = src_tz.size() == 2 auto diff_dst_md = platform::MKLDNNMemDesc(
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, diff_dst_tz, platform::MKLDNNGetDataType<T>(), diff_y_format);
mkldnn::memory::format::nc) auto diff_dst_memory = std::shared_ptr<memory>(
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, new memory({diff_dst_md, mkldnn_engine}, to_void_cast(diff_y_data)));
mkldnn::memory::format::nchw); dev_ctx.SetBlob(key_diff_dst_mem, diff_dst_memory);
// create memory primitives // retrieve eltwise primitive desc from device context
std::shared_ptr<void> p_diff_src_mem = auto forward_pd =
std::make_shared<mkldnn::memory>(mkldnn::memory( std::static_pointer_cast<mkldnn::eltwise_forward::primitive_desc>(
{data_md, mkldnn_engine}, platform::to_void_cast(diff_src))); dev_ctx.GetBlob(key_fwd_pd));
dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem); PADDLE_ENFORCE(forward_pd != nullptr,
std::shared_ptr<void> p_diff_dst_mem = "Fail to find eltwise_fwd_pd in device context");
std::make_shared<mkldnn::memory>(mkldnn::memory(
{data_md, mkldnn_engine}, platform::to_void_cast(diff_dst))); // ceate primitive descriptor for activation backward
dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem); auto backward_desc = mkldnn::eltwise_backward::desc(
algorithm, diff_dst_memory->get_primitive_desc().desc(),
auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, src_memory->get_primitive_desc().desc(), alpha, beta);
alpha, beta); auto backward_pd = mkldnn::eltwise_backward::primitive_desc(
backward_desc, mkldnn_engine, *forward_pd);
const std::string key_fwd_pd = key + "eltwise_fwd_pd";
auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>( // create mkldnn memory for output diff_src
dev_ctx.GetBlob(key_fwd_pd).get()); diff_src_memory = std::make_shared<memory>(
backward_pd.diff_src_primitive_desc(), diff_x_data);
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc( dev_ctx.SetBlob(key_diff_src_mem, diff_src_memory);
bwd_desc, mkldnn_engine, *p_fwd_pd);
// create activation backward primitive
p_grad = std::make_shared<mkldnn::eltwise_backward>( p_grad = std::make_shared<mkldnn::eltwise_backward>(
eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()), backward_pd, *src_memory, *diff_dst_memory, *diff_src_memory);
*(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())), dev_ctx.SetBlob(key_grad, p_grad);
*(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
} else { } else {
// primitives already exist // primitives already exist
auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>( diff_src_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_diff_src_mem)); dev_ctx.GetBlob(key_diff_src_mem));
auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>( auto diff_dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_diff_dst_mem)); dev_ctx.GetBlob(key_diff_dst_mem));
p_diff_src_mem->set_data_handle( diff_src_memory->set_data_handle(
platform::to_void_reinterpret_cast(diff_src)); platform::to_void_reinterpret_cast(diff_x_data));
p_diff_dst_mem->set_data_handle( diff_dst_memory->set_data_handle(
platform::to_void_reinterpret_cast(diff_dst)); platform::to_void_reinterpret_cast(diff_y_data));
} }
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())}; std::vector<primitive> pipeline;
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); pipeline.push_back(*p_grad);
stream(stream::kind::eager).submit(pipeline).wait();
diff_x->set_layout(DataLayout::kMKLDNN);
diff_x->set_format(GetMKLDNNFormat(*diff_src_memory));
} }
} // anonymous namespace
template <typename T, mkldnn::algorithm algorithm> template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationFunc : public BaseActivationFunctor<T> { struct MKLDNNActivationFunc : public BaseActivationFunctor<T> {
template <typename ExecContext> void operator()(const framework::ExecutionContext &ctx) const {
void operator()(const ExecContext &ctx) const {
eltwise_forward<T>(ctx, algorithm); eltwise_forward<T>(ctx, algorithm);
} }
}; };
template <typename T, mkldnn::algorithm algorithm> template <typename T, mkldnn::algorithm algorithm>
struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> { struct MKLDNNActivationGradFunc : public BaseActivationFunctor<T> {
template <typename ExecContext> void operator()(const framework::ExecutionContext &ctx) const {
void operator()(const ExecContext &ctx) const {
eltwise_grad<T>(ctx, algorithm); eltwise_grad<T>(ctx, algorithm);
} }
}; };
......
...@@ -19,18 +19,20 @@ limitations under the License. */ ...@@ -19,18 +19,20 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \ using paddle::framework::Tensor;
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \ #define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
public: \ class OP_NAME##OpMaker \
void Make() override { \ : public ::paddle::framework::OpProtoAndCheckerMaker { \
AddInput("X", "Input of " #OP_NAME " operator"); \ public: \
AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \ void Make() override { \
AddAttr<bool>("use_mkldnn", \ AddInput("X", "Input of " #OP_NAME " operator"); \
"(default false) Only used in mkldnn kernel") \ AddOutput("Out", "Output of " #OP_NAME " operator").Reuse("X"); \
.SetDefault(false); \ AddAttr<bool>("use_mkldnn", \
AddComment(OP_COMMENT); \ "(bool, default false) Only used in mkldnn kernel") \
} \ .SetDefault(false); \
AddComment(#OP_COMMENT); \
} \
} }
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \ #define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
...@@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx, ...@@ -58,7 +60,6 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper, const framework::OperatorWithKernel& oper,
const std::string& name) { const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain}; framework::LibraryType library{framework::LibraryType::kPlain};
framework::DataLayout layout = framework::DataLayout::kAnyLayout; framework::DataLayout layout = framework::DataLayout::kAnyLayout;
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn"); auto it = oper.Attrs().find("use_mkldnn");
...@@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -82,6 +83,7 @@ class ActivationOp : public framework::OperatorWithKernel {
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X"); return GetKernelType(ctx, *this, "X");
...@@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { ...@@ -96,6 +98,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
} }
protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out"); return GetKernelType(ctx, *this, "Out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册