提交 d8bd436f 编写于 作者: K Krzysztof Binias

Fixed tests

上级 a64b312e
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/mkldnn_activation_op.h"
namespace paddle {
namespace operators {
......@@ -25,11 +26,6 @@ class ActivationOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return ActivationHelper().GetKernelType(ctx, *this);
}
};
class ActivationOpGrad : public framework::OperatorWithKernel {
......@@ -39,11 +35,6 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return ActivationHelper().GetKernelType(ctx, *this);
}
};
class SigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
......@@ -546,11 +537,11 @@ REGISTER_OP(logsigmoid, ops::ActivationOp, ops::LogSigmoidOpMaker,
REGISTER_OP(exp, ops::ActivationOp, ops::ExpOpMaker, exp_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu, ops::ActivationOp, ops::ReluOpMaker, relu_grad,
ops::ActivationOpGrad);
REGISTER_OP(relu, ops::ActivationWithMKLDNNOp, ops::ReluOpMaker, relu_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationOpGrad);
REGISTER_OP(tanh, ops::ActivationWithMKLDNNOp, ops::TanhOpMaker, tanh_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
tanh_shrink_grad, ops::ActivationOpGrad);
......@@ -558,11 +549,11 @@ REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
softshrink_grad, ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationOpGrad);
REGISTER_OP(sqrt, ops::ActivationWithMKLDNNOp, ops::SqrtOpMaker, sqrt_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(abs, ops::ActivationOp, ops::AbsOpMaker, abs_grad,
ops::ActivationOpGrad);
REGISTER_OP(abs, ops::ActivationWithMKLDNNOp, ops::AbsOpMaker, abs_grad,
ops::ActivationWithMKLDNNOpGrad);
REGISTER_OP(ceil, ops::ActivationOp, ops::CeilOpMaker, ceil_grad,
ops::ActivationOpGrad);
......
......@@ -24,25 +24,6 @@ limitations under the License. */
namespace paddle {
namespace operators {
class ActivationHelper {
public:
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) const {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
};
template <typename DeviceContext, typename Functor>
class ActivationKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
......
......@@ -60,5 +60,52 @@ class MKLDNNActivationGradKernel
}
};
namespace {
framework::OpKernelType GetKernelType(
const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper) {
framework::LibraryType library{framework::LibraryType::kPlain};
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library = framework::LibraryType::kMKLDNN;
}
#endif
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
ctx.GetPlace(), layout, library);
}
} // anonymous namespace
class ActivationWithMKLDNNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Out");
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
class ActivationWithMKLDNNOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Out"));
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this);
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册