Unverified · Commit af0c4c45 · authored by Qiao Longfei, committed by GitHub

Impl kernel hint (#6883)

* init kernel hint

* fix typo

* rm unused code

* add include in op_kernel.h

* restore op_kernel since it will be moved to op_kernel_type

* change force_cpu to use_cpu

* fix compilation
Parent efd37269
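The heart of this change is a two-phase kernel lookup: the operator first reports the kernel it would naturally pick from its inputs (`GetActualKernelType`), and a second hook (`GetExpectedKernelType`) may rewrite that key, e.g. to honor hints such as `use_cpu`. Below is a minimal, self-contained sketch of that dispatch flow; the `DataType`/`Place`/`OpKernelType` types here are simplified stand-ins, not Paddle's real definitions.

```cpp
#include <functional>
#include <iostream>
#include <map>
#include <stdexcept>
#include <tuple>

// Simplified stand-ins for framework::proto::DataType and platform::Place.
enum class DataType { FP32, FP64 };
enum class Place { kCPU, kCUDA };

struct OpKernelType {
  DataType data_type_;
  Place place_;
  bool operator<(const OpKernelType& o) const {
    return std::tie(data_type_, place_) < std::tie(o.data_type_, o.place_);
  }
};

class OperatorWithKernel {
 public:
  virtual ~OperatorWithKernel() = default;

  void Run(Place dev_place) const {
    // Phase 1: the kernel the operator would naturally pick from its
    // inputs and the current device context.
    OpKernelType actual = GetActualKernelType(dev_place);
    // Phase 2: a hook that may rewrite the choice (kernel hints).
    OpKernelType expected = GetExpectedKernelType(actual);
    auto it = kernels_.find(expected);
    if (it == kernels_.end()) {
      throw std::runtime_error("The operator does not support this kernel");
    }
    it->second();  // Compute(ctx) in the real framework.
  }

  std::map<OpKernelType, std::function<void()>> kernels_;

 protected:
  virtual OpKernelType GetActualKernelType(Place p) const {
    return {DataType::FP32, p};
  }
  // Default behavior matches the patch: expected == actual.
  virtual OpKernelType GetExpectedKernelType(
      const OpKernelType& actual) const {
    return actual;
  }
};

// Mirrors the CRFDecodingOp override in this commit: keep the data type
// but pin the kernel to CPU regardless of the device the op runs under.
class CpuOnlyOp : public OperatorWithKernel {
 protected:
  OpKernelType GetExpectedKernelType(
      const OpKernelType& actual) const override {
    return {actual.data_type_, Place::kCPU};
  }
};

int main() {
  CpuOnlyOp op;
  op.kernels_[{DataType::FP32, Place::kCPU}] = [] {
    std::cout << "FP32 CPU kernel\n";
  };
  op.Run(Place::kCUDA);  // Still dispatches to the CPU kernel.
  return 0;
}
```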
@@ -402,19 +402,28 @@ void OperatorWithKernel::Run(const Scope& scope,
   OpKernelMap& kernels = kernels_iter->second;
 
   ExecutionContext ctx(*this, scope, *dev_ctx);
-  auto kernel_key = GetKernelType(ctx);
-  auto kernel_iter = kernels.find(kernel_key);
+  auto actual_kernel_key = GetActualKernelType(ctx);
+  auto expected_kernel_key = GetExpectedKernelType(actual_kernel_key);
+  auto kernel_iter = kernels.find(expected_kernel_key);
 
   if (kernel_iter == kernels.end()) {
-    PADDLE_THROW("The operator %s does not support %s", type_, kernel_key);
+    PADDLE_THROW("The operator %s does not support %s", type_,
+                 expected_kernel_key);
   }
 
   kernel_iter->second->Compute(ctx);
 }
 
-OpKernelType OperatorWithKernel::GetKernelType(
+OpKernelType OperatorWithKernel::GetActualKernelType(
     const ExecutionContext& ctx) const {
   return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
 }
+
+OpKernelType OperatorWithKernel::GetExpectedKernelType(
+    const OpKernelType& actual_kernel_type) const {
+  return actual_kernel_type;
+}
 
 proto::DataType OperatorWithKernel::IndicateDataType(
     const ExecutionContext& ctx) const {
   auto& scope = ctx.scope();
......
@@ -52,6 +52,11 @@ constexpr char kGradVarSuffix[] = "@GRAD";
 /// Variables with this suffix are supposed to be filled up with zeros.
 constexpr char kZeroVarSuffix[] = "@ZERO";
 
+// define some kernel hint
+const std::string kUseCPU = "use_cpu";
+const std::string kUseCUDNN = "use_cudnn";
+const std::string kUseMKLDNN = "use_mkldnn";
+
 inline std::string GradVarName(const std::string& var_name) {
   return var_name + kGradVarSuffix;
 }
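This hunk only defines the hint keys; how operators consume them is not shown in this commit. As a purely hypothetical illustration (names and wiring are assumptions, not part of this patch), an op maker could register a boolean attribute under one of these keys, which kernel-selection code could later read:

```cpp
// Hypothetical sketch, not part of this commit: register a hint attribute
// under framework::kUseCPU so users can request CPU execution per-op.
class MyOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MyOpMaker(OpProto* proto, OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<bool>(framework::kUseCPU,
                  "(bool, default false) Run this op on CPU even when the "
                  "program runs under a CUDA device context.")
        .SetDefault(false);
    // Kernel-selection code (e.g. GetActualKernelType, which still sees the
    // ExecutionContext) could then inspect this attribute.
  }
};
```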
@@ -373,7 +378,9 @@ class OperatorWithKernel : public OperatorBase {
   }
 
  protected:
-  virtual OpKernelType GetKernelType(const ExecutionContext& ctx) const;
+  virtual OpKernelType GetActualKernelType(const ExecutionContext& ctx) const;
+  virtual OpKernelType GetExpectedKernelType(
+      const OpKernelType& actual_kernel_type) const;
 
  private:
   // indicate kernel DataType by input data. Defaultly all input data must be
......
@@ -114,7 +114,7 @@ class OpWithKernelTest : public OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {}
-  OpKernelType GetKernelType(const ExecutionContext& ctx) const override {
+  OpKernelType GetActualKernelType(const ExecutionContext& ctx) const override {
     return OpKernelType(proto::DataType::FP32, ctx.GetPlace());
   }
 };
......
@@ -53,7 +53,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
@@ -39,7 +39,7 @@ class AucOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Out")->type()),
......
@@ -304,7 +304,7 @@ class BatchNormGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     const auto *var = ctx.InputVar(framework::GradVarName("Y"));
     if (var == nullptr) {
......
@@ -55,7 +55,7 @@ class ChunkEvalOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(framework::proto::DataType::FP32,
                                    ctx.device_context());
......
@@ -66,9 +66,9 @@ class CompareOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
+    framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
     // CompareOp kernel's device type is decided by input tensor place
     kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
     return kt;
......
@@ -120,12 +120,18 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
         ctx.device_context());
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::OpKernelType& actual_kernel_type) const override {
+    return framework::OpKernelType(actual_kernel_type.data_type_,
+                                   platform::CPUPlace());
+  }
 };
 
 }  // namespace operators
 }  // namespace paddle
......
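CRFDecodingOp is the first operator to override the new `GetExpectedKernelType` hook: whatever place the actual kernel key carries, the expected key keeps the data type but pins the kernel to `platform::CPUPlace()`, presumably because only a CPU kernel is registered for CRF decoding. Every other operator in this commit merely renames its `GetKernelType` override to `GetActualKernelType`.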
@@ -51,7 +51,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
  protected:
   // Explicitly set that the data type of computation kernel of cross_entropy
   // is determined by its input "X".
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
@@ -101,7 +101,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
  protected:
   // Explicitly set that the data type of computation kernel of cross_entropy
   // is determined by its input "X".
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
@@ -49,7 +49,7 @@ class FillConstantBatchSizeLikeOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
@@ -40,7 +40,7 @@ class GatherOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
@@ -57,7 +57,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
@@ -57,7 +57,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
@@ -183,7 +183,7 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
  protected:
   // Explicitly set that the data type of computation kernel of linear_chain_crf
   // is determined by its input "Emission".
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<LoDTensor>("Emission")->type()),
......
@@ -242,7 +242,7 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
  protected:
   // Explicitly set that the data type of output of the linear_chain_crf_grad
   // operator is determined by its input: gradients of LogLikelihood.
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(
......
@@ -38,7 +38,7 @@ class LoDResetOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
@@ -97,7 +97,7 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
@@ -99,9 +99,9 @@ class LogicalOp : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
-    framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx);
+    framework::OpKernelType kt = OperatorWithKernel::GetActualKernelType(ctx);
     // LogicalOp kernel's device type is decided by input tensor place
     kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
     return kt;
......
@@ -41,7 +41,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......
@@ -98,7 +98,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<LoDTensor>("W")->type()),
......
@@ -92,7 +92,7 @@ class LSTMOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......
@@ -260,7 +260,7 @@ class LSTMGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::LoDTensor>("Input")->type()),
......
@@ -51,7 +51,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......
@@ -102,7 +102,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.MultiInput<Tensor>("X")[0]->type()),
......
@@ -63,7 +63,7 @@ class NCEOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......
@@ -166,7 +166,7 @@ class NCEOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Input")->type()),
......
@@ -69,7 +69,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
@@ -90,7 +90,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
@@ -85,7 +85,7 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Score")->type()),
......
@@ -80,7 +80,7 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext &ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("MaxProbs")->type()),
......
@@ -68,7 +68,7 @@ class ROIPoolOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
@@ -89,7 +89,7 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
@@ -49,7 +49,7 @@ class ScatterOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......
@@ -68,7 +68,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Ref")->type()),
......
@@ -107,7 +107,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("X")->type()),
......
@@ -48,7 +48,7 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
@@ -69,7 +69,7 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
......
@@ -118,7 +118,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<Tensor>("Logits")->type()),
......
@@ -159,7 +159,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(
......
@@ -53,7 +53,7 @@ class SumOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     auto x_vars = ctx.MultiInputVar("X");
     if (x_vars[0]->IsType<framework::LoDTensor>()) {
......
@@ -63,7 +63,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
   }
 
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         static_cast<framework::proto::DataType>(ctx.Attr<int>("dtype")),
......
@@ -71,7 +71,7 @@ int OutputSize(int input_size, int ksize, int padding, int stride) {
 class UnpoolOp : public framework::OperatorWithKernel {
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
@@ -110,7 +110,7 @@ class UnpoolOp : public framework::OperatorWithKernel {
 class UnpoolOpGrad : public framework::OperatorWithKernel {
  protected:
-  framework::OpKernelType GetKernelType(
+  framework::OpKernelType GetActualKernelType(
       const framework::ExecutionContext& ctx) const override {
     return framework::OpKernelType(
         framework::ToDataType(ctx.Input<framework::Tensor>("X")->type()),
......
@@ -23,6 +23,11 @@ void BindConstValue(pybind11::module& m) {
   m.def("kTempVarName", [] { return framework::kTempVarName; });
   m.def("kGradVarSuffix", [] { return framework::kGradVarSuffix; });
   m.def("kZeroVarSuffix", [] { return framework::kZeroVarSuffix; });
+
+  // for kernel_hint key
+  m.def("kUseCPU", [] { return framework::kUseCPU; });
+  m.def("kUseCUDNN", [] { return framework::kUseCUDNN; });
+  m.def("kUseMKLDNN", [] { return framework::kUseMKLDNN; });
 }
 
 }  // namespace pybind
......
@@ -17,6 +17,10 @@ TEMP_VAR_NAME = core.kTempVarName()
 GRAD_VAR_SUFFIX = core.kGradVarSuffix()
 ZERO_VAR_SUFFIX = core.kZeroVarSuffix()
+USE_CPU = core.kUseCPU()
+USE_CUDNN = core.kUseCUDNN()
+USE_MKLDNN = core.kUseMKLDNN()
+
 
 def grad_var_name(var_name):
     """
......