Commit 1c81301e authored by Krzysztof Binias

Update activations for MKL-DNN

Parent 35e55636
@@ -52,9 +52,11 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                                          mkldnn::memory::format::nchw);
   // create memory primitives
-  auto src_memory =
-      mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(src_data)));
+  auto src_memory = std::make_shared<mkldnn::memory>(
+      mkldnn::memory({data_md, mkldnn_engine},
+                     static_cast<void *>(const_cast<float *>(src_data))));
+  // save source memory to device context to be referred in backward path
+  dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);
   auto dst_memory =
       mkldnn::memory({data_md, mkldnn_engine},
                      static_cast<void *>(const_cast<float *>(dst_data)));
@@ -69,7 +71,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
       forward_desc, mkldnn_engine);
   dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
-  auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
+  auto eltwise = mkldnn::eltwise_forward(*forward_pd, *src_memory, dst_memory);
   // push primitive to stream and wait until it's executed
   std::vector<mkldnn::primitive> pipeline = {eltwise};
@@ -83,8 +85,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
   // get buffers
-  const auto *x = ctx.template Input<Tensor>("X");
-  const auto *src = x->template data<T>();
+  const auto *x = ctx.template Input<Tensor>("Out");
   auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
   const auto *diff_dst = dout->template data<T>();
@@ -103,9 +104,11 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
                 : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
                                           mkldnn::memory::format::nchw);
+  // retrieve source memory from device context
+  const std::shared_ptr<void> src_memory = dev_ctx.GetBlob("InputX@eltwise_pd");
+  auto *p_src_memory = static_cast<mkldnn::memory *>(src_memory.get());
   // create memory primitives
-  auto src_memory = mkldnn::memory(
-      {data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
   auto diff_src_memory =
       mkldnn::memory({data_md, mkldnn_engine},
                      static_cast<void *>(const_cast<float *>(diff_src)));
@@ -128,8 +131,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
       backward_desc, mkldnn_engine, *p_forward_pd);
-  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
-                                              diff_dst_memory, diff_src_memory);
+  auto eltwise_bwd = mkldnn::eltwise_backward(
+      eltwise_bwd_prim_desc, *p_src_memory, diff_dst_memory, diff_src_memory);
   // push primitive to stream and wait until it's executed
   std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
......
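The change above removes the re-creation of the source memory in eltwise_grad and replaces it with a cache handoff through the device context: the forward pass wraps its src memory in a std::shared_ptr, stores it under the key "InputX@eltwise_pd" with SetBlob, and the backward pass fetches the same object with GetBlob and casts the type-erased pointer back to mkldnn::memory. Below is a minimal stand-alone sketch of that pattern; ToyDeviceContext and the float-vector payload are simplified stand-ins, and only the SetBlob/GetBlob shape mirrors the real device context.

// Sketch of the forward -> backward blob-caching pattern, without MKL-DNN.
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

class ToyDeviceContext {
 public:
  // Store a type-erased object under a string key.
  void SetBlob(const std::string& name, std::shared_ptr<void> blob) {
    blobs_[name] = std::move(blob);
  }
  // Retrieve the object later (e.g. in the backward pass).
  std::shared_ptr<void> GetBlob(const std::string& name) const {
    auto it = blobs_.find(name);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

void Forward(ToyDeviceContext* dev_ctx, const std::vector<float>& x) {
  // Wrap the forward-pass input and save it so backward can reuse it
  // instead of re-reading and re-wrapping "X".
  auto src = std::make_shared<std::vector<float>>(x);
  dev_ctx->SetBlob("InputX@eltwise_pd", src);
}

void Backward(const ToyDeviceContext& dev_ctx) {
  // Fetch the type-erased blob and cast it back to its concrete type.
  auto blob = dev_ctx.GetBlob("InputX@eltwise_pd");
  auto* src = static_cast<std::vector<float>*>(blob.get());
  std::cout << "cached forward input has " << src->size() << " elements\n";
}

int main() {
  ToyDeviceContext ctx;
  Forward(&ctx, {1.f, 2.f, 3.f});
  Backward(ctx);
  return 0;
}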
@@ -458,21 +458,22 @@ namespace ops = paddle::operators;
 #define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
   __macro(Sigmoid, sigmoid);                 \
-  __macro(Relu, relu);                       \
   __macro(Exp, exp);                         \
-  __macro(Tanh, tanh);                       \
   __macro(Ceil, ceil);                       \
   __macro(Floor, floor);                     \
-  __macro(Sqrt, sqrt);                       \
   __macro(SoftRelu, soft_relu);              \
   __macro(Relu6, relu6);                     \
   __macro(Reciprocal, reciprocal);           \
   __macro(HardSigmoid, hard_sigmoid);
+#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
+  __macro(Relu, relu);                              \
+  __macro(Tanh, tanh);                              \
+  __macro(Sqrt, sqrt);
 #define FOR_EACH_OP_FUNCTOR(__macro) \
   __macro(LogSigmoid, logsigmoid);   \
   __macro(SoftShrink, softshrink);   \
-  __macro(Abs, abs);                 \
   __macro(Cos, cos);                 \
   __macro(Sin, sin);                 \
   __macro(Round, round);             \
@@ -490,18 +491,32 @@ namespace ops = paddle::operators;
   __macro(Swish, swish);             \
   __macro(ThresholdedRelu, thresholded_relu);
+#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs);
 #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)              \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,       \
                     ::paddle::operators::OP_NAME##OpMaker,                \
                     ::paddle::operators::OP_NAME##GradMaker);             \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
+#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE)       \
+  REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp,             \
+                    ::paddle::operators::OP_NAME##OpMaker,                \
+                    ::paddle::operators::OP_NAME##GradMaker);             \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
 #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                      \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,       \
                     ::paddle::operators::OP_NAME##OpMaker,                \
                     ::paddle::framework::DefaultGradOpDescMaker<true>);   \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
+#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE)               \
+  REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp,             \
+                    ::paddle::operators::OP_NAME##OpMaker,                \
+                    ::paddle::framework::DefaultGradOpDescMaker<true>);   \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)   \
   REGISTER_OP_CPU_KERNEL(                                                 \
       act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
@@ -516,5 +531,7 @@ namespace ops = paddle::operators;
                                 ops::grad_functor<double>>);
 FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
+FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP);
 FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
+FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
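The registration changes above follow the file's X-macro convention: each FOR_EACH_* list applies whatever macro it is handed to every (OpName, kernel_type) pair, so FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP) expands into Relu/Tanh/Sqrt registrations against ActivationWithMKLDNNOp, and FOR_EACH_MKLDNN_OP_FUNCTOR does the same for Abs. The toy sketch below shows only the expansion mechanism; TOY_REGISTER just prints and is a hypothetical stand-in for REGISTER_OPERATOR.

// Toy illustration of the X-macro expansion used by the FOR_EACH_* lists.
#include <cstdio>

#define FOR_EACH_MKLDNN_INPLACE_OP(__macro) \
  __macro(Relu, relu);                      \
  __macro(Tanh, tanh);                      \
  __macro(Sqrt, sqrt);

// Stand-in for REGISTER_INPLACE_ACTIVATION_MKLDNN_OP: just report the pair.
#define TOY_REGISTER(OP_NAME, KERNEL_TYPE) \
  std::printf("registering %s as '%s' (MKL-DNN in-place)\n", #OP_NAME, #KERNEL_TYPE)

int main() {
  // Expands to one TOY_REGISTER(...) statement per listed activation,
  // mirroring FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP).
  FOR_EACH_MKLDNN_INPLACE_OP(TOY_REGISTER);
  return 0;
}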
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
+#include <string>
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -61,9 +63,9 @@ class MKLDNNActivationGradKernel
 };
 namespace {  // NOLINT
-framework::OpKernelType GetKernelType(
-    const framework::ExecutionContext& ctx,
-    const framework::OperatorWithKernel& oper) {
+framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
+                                      const framework::OperatorWithKernel& oper,
+                                      const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_MKLDNN
   if (library == framework::LibraryType::kPlain &&
@@ -73,7 +75,7 @@ framework::OpKernelType GetKernelType(
 #endif
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
       ctx.GetPlace(), layout, library);
 }
 }  // anonymous namespace
@@ -89,7 +91,7 @@ class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
+    return GetKernelType(ctx, *this, "X");
   }
 };
@@ -103,7 +105,7 @@ class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
+    return GetKernelType(ctx, *this, "Out");
   }
 };
......
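The header change above parameterizes kernel-type selection by tensor name: GetKernelType now takes the name of the input whose data type drives the choice, so ActivationWithMKLDNNOp passes "X" while ActivationWithMKLDNNOpGrad passes "Out", matching the MKL-DNN grad kernel that now reads "Out". A minimal sketch of that parameterization, with ToyContext and ToyKernelType as hypothetical stand-ins for the framework types:

// Kernel-type selection keyed by the name of an input tensor.
#include <iostream>
#include <map>
#include <string>

struct ToyKernelType {
  std::string dtype;  // data type that would select the kernel
};

struct ToyContext {
  std::map<std::string, std::string> input_dtypes;  // tensor name -> dtype
};

ToyKernelType GetKernelType(const ToyContext& ctx, const std::string& name) {
  // Pick the kernel based on the dtype of the named input tensor.
  return ToyKernelType{ctx.input_dtypes.at(name)};
}

int main() {
  ToyContext ctx{{{"X", "float32"}, {"Out", "float32"}}};
  std::cout << "forward kernel dtype: " << GetKernelType(ctx, "X").dtype << "\n";
  std::cout << "grad kernel dtype:    " << GetKernelType(ctx, "Out").dtype << "\n";
  return 0;
}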