Commit 1c81301e authored by Krzysztof Binias

Update activations for MKL-DNN

Parent 35e55636
@@ -52,9 +52,11 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
                                      mkldnn::memory::format::nchw);
 
   // create memory primitives
-  auto src_memory =
+  auto src_memory = std::make_shared<mkldnn::memory>(
       mkldnn::memory({data_md, mkldnn_engine},
-                     static_cast<void *>(const_cast<float *>(src_data)));
+                     static_cast<void *>(const_cast<float *>(src_data))));
+  // save source memory to device context to be referred in backward path
+  dev_ctx.SetBlob("InputX@eltwise_pd", src_memory);
   auto dst_memory =
       mkldnn::memory({data_md, mkldnn_engine},
                      static_cast<void *>(const_cast<float *>(dst_data)));
@@ -69,7 +71,7 @@ void eltwise_forward(const ExecContext &ctx, mkldnn::algorithm algorithm,
       forward_desc, mkldnn_engine);
   dev_ctx.SetBlob(key_eltwise_pd, forward_pd);
 
-  auto eltwise = mkldnn::eltwise_forward(*forward_pd, src_memory, dst_memory);
+  auto eltwise = mkldnn::eltwise_forward(*forward_pd, *src_memory, dst_memory);
 
   // push primitive to stream and wait until it's executed
   std::vector<mkldnn::primitive> pipeline = {eltwise};
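
The change above keeps the source memory alive past the forward kernel by handing a shared_ptr to the device context. A minimal sketch of the type-erased blob store this relies on (a hypothetical stand-in for Paddle's MKLDNNDeviceContext, shown only to illustrate the SetBlob/GetBlob contract):

#include <map>
#include <memory>
#include <string>

// Type-erased cache: any object wrapped in a shared_ptr<void> can be stored
// under a string key and retrieved later, e.g. by the backward kernel.
class BlobStore {
 public:
  void SetBlob(const std::string &name, std::shared_ptr<void> data) {
    blobs_[name] = std::move(data);
  }
  std::shared_ptr<void> GetBlob(const std::string &name) const {
    auto it = blobs_.find(name);
    return it == blobs_.end() ? nullptr : it->second;
  }

 private:
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

A std::shared_ptr<mkldnn::memory> converts implicitly to std::shared_ptr<void> while keeping the correct deleter in its control block, which is what makes the cast back in the backward pass safe.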
@@ -83,8 +85,7 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   const auto &mkldnn_engine = dev_ctx.GetEngine();
 
   // get buffers
-  const auto *x = ctx.template Input<Tensor>("X");
-  const auto *src = x->template data<T>();
+  const auto *x = ctx.template Input<Tensor>("Out");
 
   auto *dout = ctx.template Input<Tensor>(framework::GradVarName("Out"));
   const auto *diff_dst = dout->template data<T>();
@@ -103,9 +104,11 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
                 : platform::MKLDNNMemDesc(src_tz, mkldnn::memory::f32,
                                           mkldnn::memory::format::nchw);
 
+  // retrieve source memory from device context
+  const std::shared_ptr<void> src_memory = dev_ctx.GetBlob("InputX@eltwise_pd");
+  auto *p_src_memory = static_cast<mkldnn::memory *>(src_memory.get());
+
   // create memory primitives
-  auto src_memory = mkldnn::memory(
-      {data_md, mkldnn_engine}, static_cast<void *>(const_cast<float *>(src)));
   auto diff_src_memory =
       mkldnn::memory({data_md, mkldnn_engine},
                      static_cast<void *>(const_cast<float *>(diff_src)));
@@ -128,8 +131,8 @@ void eltwise_grad(const ExecContext &ctx, mkldnn::algorithm algorithm,
   auto eltwise_bwd_prim_desc = mkldnn::eltwise_backward::primitive_desc(
       backward_desc, mkldnn_engine, *p_forward_pd);
 
-  auto eltwise_bwd = mkldnn::eltwise_backward(eltwise_bwd_prim_desc, src_memory,
-                                              diff_dst_memory, diff_src_memory);
+  auto eltwise_bwd = mkldnn::eltwise_backward(
+      eltwise_bwd_prim_desc, *p_src_memory, diff_dst_memory, diff_src_memory);
 
   // push primitive to stream and wait until it's executed
   std::vector<mkldnn::primitive> pipeline = {eltwise_bwd};
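
Taken together, the two kernels now form a save/restore pair around the device context. A usage-level sketch of the round trip (hypothetical driver code; the PADDLE_ENFORCE guard is an addition for illustration, the patch itself does not add one):

// 1. Forward: eltwise_forward() caches its source memory under the key
//    "InputX@eltwise_pd" via dev_ctx.SetBlob(...).
// 2. Backward: eltwise_grad() fetches the same blob; the shared_ptr keeps the
//    mkldnn::memory alive between the two calls, so the raw pointer obtained
//    from .get() remains valid.
std::shared_ptr<void> blob = dev_ctx.GetBlob("InputX@eltwise_pd");
PADDLE_ENFORCE(blob != nullptr, "eltwise forward must run before backward");
auto *p_src_memory = static_cast<mkldnn::memory *>(blob.get());

One consequence of the fixed key is that the blob is shared per device context, a detail worth keeping in mind when several eltwise ops run in the same context.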
@@ -458,21 +458,22 @@ namespace ops = paddle::operators;
 #define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
   __macro(Sigmoid, sigmoid);                 \
-  __macro(Relu, relu);                       \
   __macro(Exp, exp);                         \
-  __macro(Tanh, tanh);                       \
   __macro(Ceil, ceil);                       \
   __macro(Floor, floor);                     \
-  __macro(Sqrt, sqrt);                       \
   __macro(SoftRelu, soft_relu);              \
   __macro(Relu6, relu6);                     \
   __macro(Reciprocal, reciprocal);           \
   __macro(HardSigmoid, hard_sigmoid);
 
+#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
+  __macro(Relu, relu);                              \
+  __macro(Tanh, tanh);                              \
+  __macro(Sqrt, sqrt);
+
 #define FOR_EACH_OP_FUNCTOR(__macro) \
   __macro(LogSigmoid, logsigmoid);   \
   __macro(SoftShrink, softshrink);   \
-  __macro(Abs, abs);                 \
   __macro(Cos, cos);                 \
   __macro(Sin, sin);                 \
   __macro(Round, round);             \
@@ -490,18 +491,32 @@ namespace ops = paddle::operators;
   __macro(Swish, swish);                       \
   __macro(ThresholdedRelu, thresholded_relu);
 
+#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs);
+
 #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                  \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,           \
                     ::paddle::operators::OP_NAME##OpMaker,                    \
                     ::paddle::operators::OP_NAME##GradMaker);                 \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
 
+#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE)     \
+  REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp,           \
+                    ::paddle::operators::OP_NAME##OpMaker,              \
+                    ::paddle::operators::OP_NAME##GradMaker);           \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
+
 #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE)                          \
   REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp,           \
                     ::paddle::operators::OP_NAME##OpMaker,                    \
                     ::paddle::framework::DefaultGradOpDescMaker<true>);       \
   REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
 
+#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE)             \
+  REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp,           \
+                    ::paddle::operators::OP_NAME##OpMaker,              \
+                    ::paddle::framework::DefaultGradOpDescMaker<true>); \
+  REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
+
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor)       \
   REGISTER_OP_CPU_KERNEL(                                                     \
       act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext,     \
@@ -516,5 +531,7 @@ namespace ops = paddle::operators;
       ops::grad_functor<double>>);
 
 FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
+FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP);
 FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
+FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
 FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
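
Since relu, tanh, sqrt (and abs) now appear only in the MKL-DNN lists, each op name is still registered exactly once. For reference, a hand expansion of the relu entry of FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP) (illustrative expansion, not code from the patch):

// Both the forward and gradient ops are bound to the MKL-DNN-aware operator
// classes, whose GetExpectedKernelType can select LibraryType::kMKLDNN.
REGISTER_OPERATOR(relu, ops::ActivationWithMKLDNNOp,
                  ::paddle::operators::ReluOpMaker,
                  ::paddle::operators::ReluGradMaker);
REGISTER_OPERATOR(relu_grad, ops::ActivationWithMKLDNNOpGrad);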
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include <string>
+
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
@@ -61,9 +63,9 @@ class MKLDNNActivationGradKernel
 };
 
 namespace {  // NOLINT
-framework::OpKernelType GetKernelType(
-    const framework::ExecutionContext& ctx,
-    const framework::OperatorWithKernel& oper) {
+framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
+                                      const framework::OperatorWithKernel& oper,
+                                      const std::string& name) {
   framework::LibraryType library{framework::LibraryType::kPlain};
 #ifdef PADDLE_WITH_MKLDNN
   if (library == framework::LibraryType::kPlain &&
@@ -73,7 +75,7 @@ framework::OpKernelType GetKernelType(
 #endif
   framework::DataLayout layout = framework::DataLayout::kAnyLayout;
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
+      framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
       ctx.GetPlace(), layout, library);
 }
 }  // anonymous namespace
@@ -89,7 +91,7 @@ class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
+    return GetKernelType(ctx, *this, "X");
   }
 };
@@ -103,7 +105,7 @@ class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
 
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return GetKernelType(ctx, *this);
+    return GetKernelType(ctx, *this, "Out");
   }
 };
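
The extra name argument exists because the forward and backward ops probe different inputs for their data type: the forward op carries "X", while the in-place gradient ops are wired with only the forward output and its gradient. A sketch of the grad-side lookup that motivates passing "Out" (hypothetical illustration, simplified from the kernels above):

// Inputs available to the MKL-DNN activation grad kernel after this change:
const auto *out = ctx.Input<framework::Tensor>("Out");  // forward result
const auto *dout =
    ctx.Input<framework::Tensor>(framework::GradVarName("Out"));  // "Out@GRAD"
// There is no live "X" input here, so GetKernelType(ctx, *this, "Out") must
// infer the kernel's data type from "Out" instead.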