提交 d8bd436f 编写于 作者: K Krzysztof Binias

Fixed tests

上级 a64b312e
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -25,11 +26,6 @@ class ActivationOp : public framework::OperatorWithKernel { ...@@ -25,11 +26,6 @@ class ActivationOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", ctx->GetInputDim("X")); ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return ActivationHelper().GetKernelType(ctx, *this);
}
}; };
class ActivationOpGrad : public framework::OperatorWithKernel { class ActivationOpGrad : public framework::OperatorWithKernel {
...@@ -39,11 +35,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel { ...@@ -39,11 +35,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return ActivationHelper().GetKernelType(ctx, *this);
}
}; };
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker { class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -546,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker, ...@@ -546,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad, REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad, REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad); ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad, REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad); ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
tanh_shrink_grad, ops::ActivationOpGrad); tanh_shrink_grad, ops::ActivationOpGrad);
...@@ -558,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker, ...@@ -558,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker, REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
softshrink_grad, ops::ActivationOpGrad); softshrink_grad, ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad, REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad); ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad, REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad); ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad, REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
ops::ActivationOpGrad); ops::ActivationOpGrad);
......
...@@ -24,25 +24,6 @@ limitations under the License. */ ...@@ -24,25 +24,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
class ActivationHelper {
public:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) const {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
};
template <typename DeviceContext, typename Functor> template <typename DeviceContext, typename Functor>
class ActivationKernel class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> { : public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......
...@@ -60,5 +60,52 @@ class MKLDNNActivationGradKernel ...@@ -60,5 +60,52 @@ class MKLDNNActivationGradKernel
} }
}; };
namespace {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
} // anonymous namespace
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册