提交 a76d0dd4 编写于 作者: K Krzysztof Binias

MKL-DNN activations improvements

上级 1c81301e
...@@ -41,7 +41,7 @@ namespace operators { ...@@ -41,7 +41,7 @@ namespace operators {
\ \
protected: \ protected: \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \ std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto *op = new ::paddle::framework::OpDesc(); \ auto* op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \ op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \ op->SetInput("Out", Output("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \ op->SetInput(::paddle::framework::GradVarName("Out"), \
...@@ -54,23 +54,50 @@ namespace operators { ...@@ -54,23 +54,50 @@ namespace operators {
} \ } \
} }
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
auto it = oper.Attrs().find("use_mkldnn");
if (library == framework::LibraryType::kPlain && it != oper.Attrs().end() &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
ctx.GetPlace(), layout, library);
}
class ActivationOp : public framework::OperatorWithKernel { class ActivationOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
}; };
class ActivationOpGrad : public framework::OperatorWithKernel { class ActivationOpGrad : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
}
}; };
__attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC( __attribute__((unused)) constexpr char SigmoidDoc[] = R"DOC(
...@@ -458,22 +485,21 @@ namespace ops = paddle::operators; ...@@ -458,22 +485,21 @@ namespace ops = paddle::operators;
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \ #define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
__macro(Sigmoid, sigmoid); \ __macro(Sigmoid, sigmoid); \
__macro(Relu, relu); \
__macro(Exp, exp); \ __macro(Exp, exp); \
__macro(Tanh, tanh); \
__macro(Ceil, ceil); \ __macro(Ceil, ceil); \
__macro(Floor, floor); \ __macro(Floor, floor); \
__macro(Sqrt, sqrt); \
__macro(SoftRelu, soft_relu); \ __macro(SoftRelu, soft_relu); \
__macro(Relu6, relu6); \ __macro(Relu6, relu6); \
__macro(Reciprocal, reciprocal); \ __macro(Reciprocal, reciprocal); \
__macro(HardSigmoid, hard_sigmoid); __macro(HardSigmoid, hard_sigmoid);
#define FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(__macro) \
__macro(Relu, relu); \
__macro(Tanh, tanh); \
__macro(Sqrt, sqrt);
#define FOR_EACH_OP_FUNCTOR(__macro) \ #define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(LogSigmoid, logsigmoid); \ __macro(LogSigmoid, logsigmoid); \
__macro(SoftShrink, softshrink); \ __macro(SoftShrink, softshrink); \
__macro(Abs, abs); \
__macro(Cos, cos); \ __macro(Cos, cos); \
__macro(Sin, sin); \ __macro(Sin, sin); \
__macro(Round, round); \ __macro(Round, round); \
...@@ -491,32 +517,18 @@ namespace ops = paddle::operators; ...@@ -491,32 +517,18 @@ namespace ops = paddle::operators;
__macro(Swish, swish); \ __macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu); __macro(ThresholdedRelu, thresholded_relu);
#define FOR_EACH_MKLDNN_OP_FUNCTOR(__macro) __macro(Abs, abs);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ #define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \ ::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::OP_NAME##GradMaker); \ ::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_INPLACE_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::OP_NAME##GradMaker); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \ #define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \ REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \ ::paddle::operators::OP_NAME##OpMaker, \
::paddle::framework::DefaultGradOpDescMaker<true>); \ ::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad) REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_MKLDNN_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ops::ActivationWithMKLDNNOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ops::ActivationWithMKLDNNOpGrad)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \ #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
REGISTER_OP_CPU_KERNEL( \ REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \ act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
...@@ -531,7 +543,5 @@ namespace ops = paddle::operators; ...@@ -531,7 +543,5 @@ namespace ops = paddle::operators;
ops::grad_functor<double>>); ops::grad_functor<double>>);
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP); FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_MKLDNN_OP_FUNCTOR(REGISTER_ACTIVATION_MKLDNN_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP); FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_MKLDNN_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_MKLDNN_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL); FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
...@@ -62,52 +62,5 @@ class MKLDNNActivationGradKernel ...@@ -62,52 +62,5 @@ class MKLDNNActivationGradKernel
} }
}; };
namespace { // NOLINT
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
const std::string& name) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>(name)->type()),
ctx.GetPlace(), layout, library);
}
} // anonymous namespace
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "X");
}
};
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册