未验证 提交 7205d331 编写于 作者: T tensor-tang 提交者: GitHub

Merge pull request #10597 from kbinias/mkldnn-activations-improvments

Update activations for MKL-DNN
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "mkldnn.hpp" #include "mkldnn.hpp"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h" #include "paddle/fluid/operators/mkldnn_activation_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -23,6 +24,18 @@ using paddle::framework::Tensor; ...@@ -23,6 +24,18 @@ using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext; using paddle::platform::MKLDNNDeviceContext;
namespace { namespace {
std::string gethash(const mkldnn::memory::dims &operand_dims,
const mkldnn::algorithm algorithm) {
auto dim2str = [](const mkldnn::memory::dims &operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
};
return dim2str(operand_dims) + std::to_string(algorithm);
}
template <typename T, typename ExecContext> template <typename T, typename ExecContext>
void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T alpha = 0, const T beta = 0) { const T alpha = 0, const T beta = 0) {
...@@ -37,42 +50,70 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -37,42 +50,70 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
const auto *src_data = src->template data<T>(); const auto *src_data = src->template data<T>();
auto *dst = ctx.template Output<Tensor>("Out"); auto *dst = ctx.template Output<Tensor>("Out");
const T *dst_data = dst->template mutable_data<T>(ctx.GetPlace()); T *dst_data = dst->template mutable_data<T>(ctx.GetPlace());
// get memory dim // get memory dim
PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4, PADDLE_ENFORCE(src->dims().size() == 2 || src->dims().size() == 4,
"Input dim must be with 2 or 4"); "Input dim must be with 2 or 4");
std::vector<int> src_tz = framework::vectorize2int(src->dims()); std::vector<int> src_tz = framework::vectorize2int(src->dims());
// create memory description const std::string key = gethash(src_tz, algorithm);
auto data_md = src_tz.size() == 2 const std::string key_src_data =
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, key + ctx.op().Output("Out") + "@eltwise_fwd_src_data";
mkldnn::memory::format::nc) const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, const std::string key_dst_mem = key + "@eltwise_fwd_dst_mem";
mkldnn::memory::format::nchw); const std::string key_fwd = key + "@eltwise_fwd";
// create memory primitives auto p_fwd = std::static_pointer_cast<mkldnn::eltwise_forward>(
auto src_memory = dev_ctx.GetBlob(key_fwd));
mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(src_data))); // save input data to be referred in backward path
auto dst_memory = auto p_src_data = std::make_shared<const T *>(src_data);
mkldnn::memory({data_md, mkldnn_engine}, dev_ctx.SetBlob(key_src_data, p_src_data);
static_cast<void *>(const_cast<float *>(dst_data)));
if (p_fwd == nullptr) {
auto forward_desc = mkldnn::eltwise_forward::desc( // create memory description
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta); auto data_md = src_tz.size() == 2
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
// save prim desc into global device context to be referred in backward path mkldnn::memory::format::nc)
const std::string key = ctx.op().Output("Out"); : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
const std::string key_eltwise_pd = key + "@eltwise_pd"; mkldnn::memory::format::nchw);
auto forward_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
forward_desc, mkldnn_engine); // create memory primitives
dev_ctx.SetBlob(key_eltwise_pd, forward_pd); auto p_src_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
{data_md, mkldnn_engine}, platform::to_void_cast(src_data)));
auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory); dev_ctx.SetBlob(key_src_mem, p_src_mem);
auto p_dst_mem = std::make_shared<mkldnn::memory>(mkldnn::memory(
{data_md, mkldnn_engine}, platform::to_void_cast(dst_data)));
dev_ctx.SetBlob(key_dst_mem, p_dst_mem);
auto fwd_desc = mkldnn::eltwise_forward::desc(
mkldnn::prop_kind::forward_training, algorithm, data_md, alpha, beta);
auto p_fwd_pd = std::make_shared<mkldnn::eltwise_forward::primitive_desc>(
fwd_desc, mkldnn_engine);
const std::string key_fwd_pd = key + "eltwise_fwd_pd";
dev_ctx.SetBlob(key_fwd_pd, p_fwd_pd);
p_fwd = std::make_shared<mkldnn::eltwise_forward>(
*p_fwd_pd, *(p_src_mem.get()), *(p_dst_mem.get()));
dev_ctx.SetBlob(key_fwd, p_fwd);
} else {
// primitives already exist
auto p_src_mem =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
PADDLE_ENFORCE(p_src_mem != nullptr,
"Fail to find eltwise p_src_mem in device context.");
auto p_dst_mem =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_dst_mem));
PADDLE_ENFORCE(p_dst_mem != nullptr,
"Fail to find eltwise p_src_mem in device context.");
p_src_mem->set_data_handle(platform::to_void_reinterpret_cast(src_data));
p_dst_mem->set_data_handle(dst_data);
}
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise}; std::vector<mkldnn::primitive> pipeline = {*(p_fwd.get())};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
...@@ -83,8 +124,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -83,8 +124,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const auto &mkldnn_engine = dev_ctx.GetEngine(); const auto &mkldnn_engine = dev_ctx.GetEngine();
// get buffers // get buffers
const auto *x = ctx.template Input<Tensor>("X"); const auto *out = ctx.template Input<Tensor>("Out");
const auto *src = x->template data<T>();
auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out")); auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
const auto *diff_dst = dout->template data<T>(); const auto *diff_dst = dout->template data<T>();
...@@ -94,45 +134,73 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm, ...@@ -94,45 +134,73 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace()); const T *diff_src = dx->template mutable_data<T>(ctx.GetPlace());
// get memory dim // get memory dim
std::vector<int> src_tz = framework::vectorize2int(x->dims()); std::vector<int> src_tz = framework::vectorize2int(out->dims());
// create memory description const std::string key = gethash(src_tz, algorithm);
auto data_md = src_tz.size() == 2 const std::string key_diff_src_mem = key + "@eltwise_diff_src_mem";
? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32, const std::string key_diff_dst_mem = key + "@eltwise_diff_dst_mem";
mkldnn::memory::format::nc) const std::string key_grad = key + "@eltwise_grad";
: platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
mkldnn::memory::format::nchw); const std::string key_src_data =
key + ctx.op().Input("Out") + "@eltwise_fwd_src_data";
// create memory primitives const auto p_src_data =
auto src_memory = mkldnn::memory( std::static_pointer_cast<T *>(dev_ctx.GetBlob(key_src_data));
{data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
auto diff_src_memory = const std::string key_src_mem = key + "@eltwise_fwd_src_mem";
mkldnn::memory({data_md, mkldnn_engine}, auto p_src_mem =
static_cast<void *>(const_cast<float *>(diff_src))); std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(key_src_mem));
auto diff_dst_memory = p_src_mem->set_data_handle(*p_src_data.get());
mkldnn::memory({data_md, mkldnn_engine},
static_cast<void *>(const_cast<float *>(diff_dst))); auto p_grad = std::static_pointer_cast<mkldnn::eltwise_forward::primitive>(
dev_ctx.GetBlob(key_grad));
auto backward_desc =
mkldnn::eltwise_backward::desc(algorithm, data_md, data_md, alpha, beta); if (p_grad == nullptr) {
// create memory description
// retrieve eltwise primitive desc from device context auto data_md = src_tz.size() == 2
const std::string key = ctx.op().Input("Out"); ? platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
const std::string key_eltwise_pd = key + "@eltwise_pd"; mkldnn::memory::format::nc)
const std::shared_ptr<void> forward_pd = dev_ctx.GetBlob(key_eltwise_pd); : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
PADDLE_ENFORCE(forward_pd != nullptr, mkldnn::memory::format::nchw);
"Fail to find eltwise_pd in device context");
auto *p_forward_pd = // create memory primitives
static_cast<mkldnn::eltwise_forward::primitive_desc *>(forward_pd.get()); std::shared_ptr<void> p_diff_src_mem =
std::make_shared<mkldnn::memory>(mkldnn::memory(
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc( {data_md, mkldnn_engine}, platform::to_void_cast(diff_src)));
backward_desc, mkldnn_engine, *p_forward_pd); dev_ctx.SetBlob(key_diff_src_mem, p_diff_src_mem);
std::shared_ptr<void> p_diff_dst_mem =
auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory, std::make_shared<mkldnn::memory>(mkldnn::memory(
diff_dst_memory, diff_src_memory); {data_md, mkldnn_engine}, platform::to_void_cast(diff_dst)));
dev_ctx.SetBlob(key_diff_dst_mem, p_diff_dst_mem);
auto bwd_desc = mkldnn::eltwise_backward::desc(algorithm, data_md, data_md,
alpha, beta);
const std::string key_fwd_pd = key + "eltwise_fwd_pd";
auto *p_fwd_pd = static_cast<mkldnn::eltwise_forward::primitive_desc *>(
dev_ctx.GetBlob(key_fwd_pd).get());
auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
bwd_desc, mkldnn_engine, *p_fwd_pd);
p_grad = std::make_shared<mkldnn::eltwise_backward>(
eltwise_bwd_prim_desc, *static_cast<mkldnn::memory *>(p_src_mem.get()),
*(static_cast<mkldnn::memory *>(p_diff_dst_mem.get())),
*(static_cast<mkldnn::memory *>(p_diff_src_mem.get())));
} else {
// primitives already exist
auto p_diff_src_mem = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_diff_src_mem));
auto p_diff_dst_mem = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_diff_dst_mem));
p_diff_src_mem->set_data_handle(
platform::to_void_reinterpret_cast(diff_src));
p_diff_dst_mem->set_data_handle(
platform::to_void_reinterpret_cast(diff_dst));
}
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline = {eltwise_bwd}; std::vector<mkldnn::primitive> pipeline = {*(p_grad.get())};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
} }
} // anonymous namespace } // anonymous namespace
......
...@@ -41,7 +41,7 @@ namespace operators { ...@@ -41,7 +41,7 @@ namespace operators {
\ \
protected: \ protected: \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \ std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto *op = new ::paddle::framework::OpDesc(); \ auto* op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \ op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \ op->SetInput("Out", Output("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \ op->SetInput(::paddle::framework::GradVarName("Out"), \
...@@ -54,23 +54,50 @@ namespace operators { ...@@ -54,23 +54,50 @@ namespace operators {
} \ } \
} }
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
ctx.GetPlace(), layout, library);
}
class ActivationOp : public framework::OperatorWithKernel { class ActivationOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
}; };
class ActivationOpGrad : public framework::OperatorWithKernel { class ActivationOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
}
}; };
__attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC( __attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC(
......
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/detail/safe_ref.h" #include "paddle/fluid/operators/detail/safe_ref.h"
...@@ -60,52 +62,5 @@ class MKLDNNActivationGradKernel ...@@ -60,52 +62,5 @@ class MKLDNNActivationGradKernel
} }
}; };
namespace { // NOLINT
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
} // anonymous namespace
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) { ...@@ -38,6 +38,11 @@ void* to_void_cast(const Type* t) {
return static_cast<void*>(const_cast<Type*>(t)); return static_cast<void*>(const_cast<Type*>(t));
} }
template <typename Type>
void* to_void_reinterpret_cast(const Type* t) {
return reinterpret_cast<void*>(const_cast<Type*>(t));
}
template <class Type> template <class Type>
using tf_desc = typename Type::desc; using tf_desc = typename Type::desc;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册