Unverified commit 26cc1fe5, authored by Chen Weihang, committed by GitHub

Replace risky GetInputType method with secure IndicateVarDataType interface (#20668)

* replace part of the old implementation, test=develop

* restore concat op, test=develop

* update all ops implementation & delete GetDataTypeOfVar func, test=develop
Parent commit fd5321b3
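Every hunk below applies the same one-line substitution: the direct `ctx.Input<...>(...)->type()` / `framework::GetDataTypeOfVar(...)` read is replaced by a call through the operator itself. A minimal before/after sketch of that pattern, distilled from the hunks in this commit (the input name "X" and the surrounding method are placeholders, and the reasons given in the comments are inferred from the commit title, not stated in the diff):

// Before: reads the dtype straight off the input tensor pointer; if the
// input variable is absent this is a null-pointer dereference with no
// diagnostic, which is presumably the "risky" behavior the title refers to.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
                                 ctx.device_context());
}

// After: the dtype is resolved by OperatorWithKernel::IndicateVarDataType,
// which goes through the operator and so can check the variable before
// reporting its data type.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"),
      ctx.device_context());
}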
@@ -48,16 +48,6 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
    std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};
-proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
-  if (var->IsType<framework::LoDTensor>()) {
-    return var->Get<framework::LoDTensor>().type();
-  } else if (var->IsType<framework::SelectedRows>()) {
-    return var->Get<framework::SelectedRows>().value().type();
-  } else {
-    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
-  }
-}
static DDim GetDimsDebug(const Scope& scope, const std::string& name,
                         bool get_actual_dim = false) {
  Variable* var = scope.FindVar(name);
...
@@ -102,7 +102,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
  }
}
-proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
...
@@ -114,9 +114,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
    layout = framework::DataLayout::kMKLDNN;
  }
#endif
-  return framework::OpKernelType(
-      framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
-      library);
+  return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
+                                 ctx.GetPlace(), layout, library);
}
class ActivationOp : public framework::OperatorWithKernel {
...
@@ -37,8 +37,9 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
@@ -56,9 +57,9 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
-        platform::CPUPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   platform::CPUPlace());
  }
};
...
@@ -121,9 +121,9 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
  }
};
...
@@ -80,7 +80,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
      library = framework::LibraryType::kCUDNN;
    }
#endif
-    auto data_type = ctx.Input<Tensor>("Theta")->type();
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
    return framework::OpKernelType(data_type, ctx.GetPlace(),
                                   framework::DataLayout::kAnyLayout, library);
  }
@@ -191,9 +191,9 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
      library_ = framework::LibraryType::kCUDNN;
    }
#endif
-    return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
-                                   ctx.GetPlace(),
-                                   framework::DataLayout::kAnyLayout, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(),
+        framework::DataLayout::kAnyLayout, library_);
  }
};
...
@@ -89,8 +89,9 @@ class AssignOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
...
@@ -129,8 +129,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
-  return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                 ctx.device_context());
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}
void AttentionLSTMOpMaker::Make() {
...
@@ -103,8 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace());
  }
};
...
@@ -115,7 +115,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType BatchNormOp::GetExpectedKernelType(
    const framework::ExecutionContext &ctx) const {
-  auto input_data_type = ctx.Input<Tensor>("X")->type();
+  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
  // By default, the type of the scale, bias, mean,
  // and var tensors should both be float. (For float or float16 input tensor)
  // or double (For double input tensor).
@@ -432,8 +432,9 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
  }
#endif
-  return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
-                                 layout, library);
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
+      library);
}
template <typename T>
...
@@ -109,10 +109,11 @@ class BeamSearchOp : public framework::OperatorWithKernel {
    // Compute on CPU for cases with batch_size > 4.
    if (batch_size <= 4) {
      return framework::OpKernelType(
-          ctx.Input<framework::LoDTensor>("pre_ids")->type(), ctx.GetPlace());
+          OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
+          ctx.GetPlace());
    } else {
      return framework::OpKernelType(
-          ctx.Input<framework::LoDTensor>("pre_ids")->type(),
+          OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
          platform::CPUPlace());
    }
  }
...
@@ -52,8 +52,9 @@ class BprLossOp : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
@@ -98,8 +99,9 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
...
@@ -61,8 +61,9 @@ class CenterLossOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -117,7 +118,8 @@ class CenterLossGradOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<Tensor>("SampleCenterDiff")->type(), ctx.device_context());
+        OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"),
+        ctx.device_context());
  }
};
...
@@ -41,8 +41,8 @@ class CAllReduceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
@@ -28,8 +28,8 @@ class CBroadcastOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
@@ -102,7 +102,6 @@ class ConcatOp : public framework::OperatorWithKernel {
    if (flag == 0) {
      PADDLE_THROW("All Inputs of Concat OP are Empty!");
    }
#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
      return framework::OpKernelType(input_data_type, ctx.GetPlace(),
@@ -175,9 +174,9 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
  }
};
...
@@ -135,7 +135,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
      framework::OpKernelType::kDefaultCustomizedTypeValue;
  framework::LibraryType library{framework::LibraryType::kPlain};
  // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
-  auto input_data_type = ctx.Input<Tensor>("Input")->type();
+  auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
  std::string data_format =
      "AnyLayout";  // todo enable data layout when it's ready
  framework::DataLayout layout = framework::StringToDataLayout(data_format);
@@ -527,9 +527,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
  }
#endif
-  auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                      ctx.GetPlace(), layout_, library_,
-                                      customized_type_value);
+  auto type = framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
  if (library_ == framework::LibraryType::kCUDNN) {
    std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
@@ -704,9 +704,9 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
    customized_type_value = kConvMKLDNNFP32;
  }
#endif
-  auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                      ctx.GetPlace(), layout_, library_,
-                                      customized_type_value);
+  auto type = framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
  if (library_ == framework::LibraryType::kCUDNN) {
    std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
...
@@ -132,8 +132,9 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
  }
#endif
-  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                 ctx.GetPlace(), layout_, library_);
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_);
}
void Conv2DTransposeOpMaker::Make() {
@@ -384,8 +385,9 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
  }
  framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                 ctx.GetPlace(), layout_, library_);
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_);
}
class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker {
...
@@ -160,8 +160,9 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
+        platform::CPUPlace());
  }
};
}  // namespace operators
...
@@ -53,8 +53,9 @@ class CropOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -174,9 +175,9 @@ class CropOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
  }
};
...
@@ -98,8 +98,9 @@ class CropTensorOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
  framework::OpKernelType GetKernelTypeForVar(
@@ -254,9 +255,9 @@ class CropTensorOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
  }
  framework::OpKernelType GetKernelTypeForVar(
...
@@ -107,8 +107,9 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
  virtual bool IsSoftLabel(framework::InferShapeContext* ctx) const {
@@ -157,9 +158,9 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Y")),
+                                   ctx.device_context());
  }
  virtual framework::DDim GetXDim(framework::InferShapeContext* ctx) const {
...
@@ -39,8 +39,9 @@ class CTCAlignOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
...
@@ -52,8 +52,9 @@ class CVMOp : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
@@ -93,8 +94,9 @@ class CVMGradientOp : public framework::OperatorWithKernel {
  // is determined by its input "X".
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
...
@@ -81,7 +81,7 @@ class DataNormOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type = ctx.Input<Tensor>("X")->type();
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    // By default, the type of the scale, bias, mean,
    // and var tensors should both be float. (For float or float16 input tensor)
    // or double (For double input tensor).
@@ -89,12 +89,14 @@ class DataNormOp : public framework::OperatorWithKernel {
    if (input_data_type == framework::proto::VarType::FP64) {
      dn_param_type = framework::proto::VarType::FP64;
    }
-    PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input<Tensor>("BatchSize")->type(),
-                      "BatchSize input should be of float type");
-    PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input<Tensor>("BatchSum")->type(),
-                      "BatchSum input should be of float type");
-    PADDLE_ENFORCE_EQ(dn_param_type,
-                      ctx.Input<Tensor>("BatchSquareSum")->type(),
-                      "BatchSquareSum input should be of float type");
+    PADDLE_ENFORCE_EQ(dn_param_type,
+                      OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"),
+                      "BatchSize input should be of float type");
+    PADDLE_ENFORCE_EQ(dn_param_type,
+                      OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"),
+                      "BatchSum input should be of float type");
+    PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType(
+                                         ctx, "BatchSquareSum"),
+                      "BatchSquareSum input should be of float type");
    // TODO(pzelazko-intel): enable MKLDNN layout when it's ready
@@ -276,8 +278,9 @@ class DataNormGradOp : public framework::OperatorWithKernel {
    }
#endif
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.GetPlace(), layout, library);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout, library);
  }
};
...
@@ -216,8 +216,9 @@ class DeformableConvOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
@@ -275,8 +276,9 @@ class DeformableConvGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
}  // namespace operators
...
@@ -199,8 +199,9 @@ class DeformableConvV1Op : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
@@ -253,8 +254,9 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
}  // namespace operators
...
@@ -199,8 +199,9 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
@@ -247,8 +248,9 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Trans")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Trans"),
+        ctx.device_context());
  }
};
...
@@ -25,8 +25,9 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(
  framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
  framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
-  return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
-                                 ctx.GetPlace(), layout_, library_);
+  return framework::OpKernelType(
+      OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
+      layout_, library_);
}
void DeQuantOpMaker::Make() {
...
@@ -53,7 +53,8 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
+        ctx.device_context());
  }
};
...
@@ -45,8 +45,9 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"),
+        platform::CPUPlace());
  }
};
...
@@ -68,7 +68,7 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    auto data_type =
-        framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]);
+        OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois");
    return framework::OpKernelType(data_type, ctx.GetPlace());
  }
};
...
@@ -66,7 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
  }
};
...
@@ -46,7 +46,7 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
...
@@ -80,7 +80,7 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Rois"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois");
    return framework::OpKernelType(data_type, platform::CPUPlace());
  }
};
...
@@ -87,7 +87,7 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois");
    return framework::OpKernelType(data_type, platform::CPUPlace());
  }
};
...
@@ -60,8 +60,9 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"),
+        ctx.device_context());
  }
};
...
@@ -255,7 +255,8 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"),
+        platform::CPUPlace());
  }
};
...
@@ -80,7 +80,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>("Scores")->type(),
+        OperatorWithKernel::IndicateVarDataType(ctx, "Scores"),
        platform::CPUPlace());
  }
};
...
@@ -69,7 +69,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto input_input_type = ctx.Input<framework::Tensor>("Input")->type();
+    auto input_input_type =
+        OperatorWithKernel::IndicateVarDataType(ctx, "Input");
    framework::LibraryType library_{framework::LibraryType::kPlain};
    framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
...
@@ -94,8 +94,7 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto input_data_type =
-        framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]);
+        OperatorWithKernel::IndicateVarDataType(ctx, "Scores");
    return framework::OpKernelType(input_data_type,
                                   platform::CPUPlace());  // ctx.GetPlace());
  }
...
@@ -525,8 +525,9 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -545,8 +546,9 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
...
@@ -77,7 +77,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>("Anchor")->type(),
+        OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"),
        platform::CPUPlace());
  }
};
@@ -726,7 +726,7 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>("Anchor")->type(),
+        OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"),
        platform::CPUPlace());
  }
};
...
@@ -63,8 +63,9 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -116,8 +117,9 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
...
@@ -57,8 +57,9 @@ class TargetAssignOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
...
@@ -65,8 +65,8 @@ class YoloBoxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
@@ -98,8 +98,9 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
@@ -255,8 +256,9 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   platform::CPUPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        platform::CPUPlace());
  }
};
...
@@ -73,7 +73,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>("DetectRes")->type(),
+        OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"),
        platform::CPUPlace());
  }
};
...
@@ -29,8 +29,8 @@ class AllReduceOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
@@ -72,7 +72,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
+    auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
    return framework::OpKernelType(data_type, ctx.device_context());
  }
};
...
@@ -108,7 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
-        ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
@@ -42,7 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
-        ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
@@ -66,8 +66,7 @@ class SplitIdsOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    return framework::OpKernelType(
-        framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()),
-        ctx.GetPlace());
+        OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), ctx.GetPlace());
  }
};
...
@@ -121,9 +121,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
-        ctx.GetPlace());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
+                                   ctx.GetPlace());
  }
};
...
@@ -153,7 +153,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto input_data_type = ctx.Input<Tensor>("DDX")->type();
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
...
@@ -82,7 +82,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -236,8 +236,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type =
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(
+        ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -274,7 +274,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
-    auto input_data_type = ctx.Input<Tensor>("DOut")->type();
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
#ifdef PADDLE_WITH_MKLDNN
    if (platform::CanMKLDNNBeUsed(ctx)) {
@@ -306,13 +306,13 @@ class ElementwiseOpDoubleGradWithoutDXDY
    if (ctx.HasInput("DDX") == false) {
      PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true,
                        "Input(DDY) should not be null");
-      input_data_type = ctx.Input<Tensor>("DDY")->type();
+      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY");
    } else if (ctx.HasInput("DDY") == false) {
      PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true,
                        "Input(DDX) should not be null");
-      input_data_type = ctx.Input<Tensor>("DDX")->type();
+      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    } else {
-      input_data_type = ctx.Input<Tensor>("DDX")->type();
+      input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
    }
#ifdef PADDLE_WITH_MKLDNN
...
@@ -65,8 +65,9 @@ class ExpandOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
  framework::OpKernelType GetKernelTypeForVar(
@@ -180,9 +181,9 @@ class ExpandGradOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
-        ctx.device_context());
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
+                                   ctx.device_context());
  }
  framework::OpKernelType GetKernelTypeForVar(
...
@@ -190,8 +190,9 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -241,8 +242,8 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
@@ -303,8 +304,9 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -375,8 +377,9 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
+        ctx.device_context());
  }
};
@@ -450,8 +453,8 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};
...
...@@ -80,8 +80,9 @@ class FCOp : public framework::OperatorWithKernel { ...@@ -80,8 +80,9 @@ class FCOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kMKLDNN; library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN; layout = framework::DataLayout::kMKLDNN;
} }
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
ctx.GetPlace(), layout, library); OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout, library);
} }
}; };
......
...@@ -48,7 +48,7 @@ class FilterByInstagOp : public framework::OperatorWithKernel { ...@@ -48,7 +48,7 @@ class FilterByInstagOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Ins")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -101,8 +101,8 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel { ...@@ -101,8 +101,8 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx.InputVar(framework::GradVarName("Out"))); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
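For gradient ops, as in FilterByInstagOpGrad above, the kernel type is keyed on the output gradient, so the same call takes the mangled name produced by framework::GradVarName. A sketch under the same assumptions (FooGradOp is hypothetical):

// Illustrative sketch only -- FooGradOp is hypothetical, not part of this patch.
class FooGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Key the kernel on Out@GRAD rather than dereferencing the tensor.
    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
                                       ctx, framework::GradVarName("Out")),
                                   ctx.device_context());
  }
};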
...@@ -69,8 +69,9 @@ class FlattenOp : public framework::OperatorWithKernel { ...@@ -69,8 +69,9 @@ class FlattenOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -130,8 +131,9 @@ class FlattenGradOp : public framework::OperatorWithKernel { ...@@ -130,8 +131,9 @@ class FlattenGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -221,9 +223,9 @@ class Flatten2GradOp : public framework::OperatorWithKernel { ...@@ -221,9 +223,9 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -49,8 +49,9 @@ class FSPOp : public framework::OperatorWithKernel { ...@@ -49,8 +49,9 @@ class FSPOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(),
layout_, library_);
} }
}; };
...@@ -107,9 +108,9 @@ class FSPOpGrad : public framework::OperatorWithKernel { ...@@ -107,9 +108,9 @@ class FSPOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -140,8 +140,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel { ...@@ -140,8 +140,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(), PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(), ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same."); "The element's type of input should be the same.");
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -328,8 +328,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel { ...@@ -328,8 +328,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -114,7 +114,7 @@ void FusedEmbeddingFCLSTMOp::InferShape( ...@@ -114,7 +114,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType( framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Embeddings")->type(), OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"),
ctx.device_context()); ctx.device_context());
} }
......
...@@ -57,7 +57,7 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel { ...@@ -57,7 +57,7 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -126,7 +126,7 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel { ...@@ -126,7 +126,7 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -58,7 +58,8 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel { ...@@ -58,7 +58,8 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
} }
}; };
......
...@@ -93,8 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -93,8 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType( framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
} }
void FusionGRUOpMaker::Make() { void FusionGRUOpMaker::Make() {
......
...@@ -117,8 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -117,8 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType( framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
} }
void FusionLSTMOpMaker::Make() { void FusionLSTMOpMaker::Make() {
......
...@@ -60,8 +60,8 @@ void FusionRepeatedFCReluOp::InferShape( ...@@ -60,8 +60,8 @@ void FusionRepeatedFCReluOp::InferShape(
framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType( framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
void FusionRepeatedFCReluOpMaker::Make() { void FusionRepeatedFCReluOpMaker::Make() {
......
...@@ -61,8 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape( ...@@ -61,8 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType( framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
} }
void FusionSeqConvEltAddReluOpMaker::Make() { void FusionSeqConvEltAddReluOpMaker::Make() {
......
...@@ -67,8 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -67,8 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType( framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
} }
void FusionSeqExpandConcatFCOpMaker::Make() { void FusionSeqExpandConcatFCOpMaker::Make() {
......
...@@ -47,7 +47,7 @@ void FusionSeqPoolConcatOp::InferShape( ...@@ -47,7 +47,7 @@ void FusionSeqPoolConcatOp::InferShape(
framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType( framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
void FusionSeqPoolConcatOpMaker::Make() { void FusionSeqPoolConcatOpMaker::Make() {
......
...@@ -52,7 +52,7 @@ void FusionSeqPoolCVMConcatOp::InferShape( ...@@ -52,7 +52,7 @@ void FusionSeqPoolCVMConcatOp::InferShape(
framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType( framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType( return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
void FusionSeqPoolCVMConcatOpMaker::Make() { void FusionSeqPoolCVMConcatOpMaker::Make() {
......
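The fused sequence ops above previously read the element type from the first variable of a duplicable input slot, e.g. ctx.MultiInput<LoDTensor>("X")[0]->type() or framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]). The replacement call takes only the slot name, so it covers list-valued inputs as well. A sketch of that case (SeqPoolConcatLikeOp is hypothetical; same assumptions as above):

// Illustrative sketch only -- SeqPoolConcatLikeOp is hypothetical.
class SeqPoolConcatLikeOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    // Pre-patch form dereferenced the first variable of the duplicable
    // slot "X": framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]).
    // The replacement needs only the slot name:
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
  }
};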
...@@ -53,8 +53,8 @@ void FusionSquaredMatSubOp::InferShape( ...@@ -53,8 +53,8 @@ void FusionSquaredMatSubOp::InferShape(
framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType( framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
void FusionSquaredMatSubOpMaker::Make() { void FusionSquaredMatSubOpMaker::Make() {
......
...@@ -61,7 +61,7 @@ class GatherNdOp : public framework::OperatorWithKernel { ...@@ -61,7 +61,7 @@ class GatherNdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X"); auto* x = ctx.Input<Tensor>("X");
const auto& x_type = x->type(); const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType( return framework::OpKernelType(
x_type, x_type,
x_type == framework::proto::VarType::BOOL x_type == framework::proto::VarType::BOOL
...@@ -82,9 +82,9 @@ class GatherNdGradOp : public framework::OperatorWithKernel { ...@@ -82,9 +82,9 @@ class GatherNdGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -45,8 +45,9 @@ class GatherOp : public framework::OperatorWithKernel { ...@@ -45,8 +45,9 @@ class GatherOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -62,9 +63,9 @@ class GatherGradOp : public framework::OperatorWithKernel { ...@@ -62,9 +63,9 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -40,8 +40,9 @@ class GatherTreeOp : public framework::OperatorWithKernel { ...@@ -40,8 +40,9 @@ class GatherTreeOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Ids")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Ids"),
ctx.device_context());
} }
}; };
......
...@@ -45,7 +45,8 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel { ...@@ -45,7 +45,8 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -68,9 +68,9 @@ class GridSampleOp : public framework::OperatorWithKernel { ...@@ -68,9 +68,9 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
...@@ -164,9 +164,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel { ...@@ -164,9 +164,9 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN; library_ = framework::LibraryType::kCUDNN;
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_); framework::DataLayout::kAnyLayout, library_);
} }
}; };
......
...@@ -81,8 +81,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel { ...@@ -81,8 +81,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -224,8 +224,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel { ...@@ -224,8 +224,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
......
...@@ -70,7 +70,7 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -70,7 +70,7 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType InstanceNormOp::GetExpectedKernelType( framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const { const framework::ExecutionContext &ctx) const {
auto input_data_type = ctx.Input<Tensor>("X")->type(); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean, // By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor) // and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor). // or double (For double input tensor).
...@@ -236,8 +236,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType( ...@@ -236,8 +236,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("cannot find Y@GRAD"); PADDLE_THROW("cannot find Y@GRAD");
} }
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
template <typename T> template <typename T>
...@@ -396,8 +396,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType( ...@@ -396,8 +396,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(
if (t == nullptr) { if (t == nullptr) {
PADDLE_THROW("cannot find Y@GRAD"); PADDLE_THROW("cannot find Y@GRAD");
} }
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
std::unique_ptr<framework::OpDesc> InstanceNormDoubleGradMaker::Apply() const { std::unique_ptr<framework::OpDesc> InstanceNormDoubleGradMaker::Apply() const {
......
...@@ -204,8 +204,8 @@ class InterpolateOp : public framework::OperatorWithKernel { ...@@ -204,8 +204,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
...@@ -407,9 +407,9 @@ class InterpolateOpGrad : public framework::OperatorWithKernel { ...@@ -407,9 +407,9 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -35,7 +35,8 @@ class IsEmptyOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,8 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X"); auto *x = ctx.Input<framework::LoDTensor>("X");
return framework::OpKernelType(x->type(), x->place()); return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), x->place());
} }
}; };
......
...@@ -58,8 +58,8 @@ class KLDivLossOp : public framework::OperatorWithKernel { ...@@ -58,8 +58,8 @@ class KLDivLossOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -136,8 +136,8 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { ...@@ -136,8 +136,8 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
......
...@@ -224,8 +224,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel { ...@@ -224,8 +224,9 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission". // is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(), return framework::OpKernelType(
platform::CPUPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
platform::CPUPlace());
} }
}; };
...@@ -263,7 +264,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel { ...@@ -263,7 +264,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(), OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("LogLikelihood")),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -52,8 +52,8 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -52,8 +52,8 @@ class LinspaceOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::Tensor>("Start")->type(), ctx.device_context(), OperatorWithKernel::IndicateVarDataType(ctx, "Start"),
layout_, library_); ctx.device_context(), layout_, library_);
} }
}; };
......
...@@ -46,8 +46,9 @@ class LoDResetOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,9 @@ class LoDResetOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -172,9 +173,9 @@ class LoDResetGradOp : public framework::OperatorWithKernel { ...@@ -172,9 +173,9 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -64,7 +64,7 @@ class LookupTableOp : public framework::OperatorWithKernel { ...@@ -64,7 +64,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -166,8 +166,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { ...@@ -166,8 +166,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx.InputVar(framework::GradVarName("Out"))); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -58,7 +58,7 @@ class LookupTableV2Op : public framework::OperatorWithKernel { ...@@ -58,7 +58,7 @@ class LookupTableV2Op : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -154,8 +154,8 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel { ...@@ -154,8 +154,8 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx.InputVar(framework::GradVarName("Out"))); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -130,26 +130,6 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> { ...@@ -130,26 +130,6 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
template struct LRNGradFunctor<platform::CPUDeviceContext, float>; template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
template struct LRNGradFunctor<platform::CPUDeviceContext, double>; template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
namespace {
framework::OpKernelType GetExpectedLRNKernel(
const framework::ExecutionContext& ctx) {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
layout_, library_);
}
} // namespace
class LRNOp : public framework::OperatorWithKernel { class LRNOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -175,7 +155,20 @@ class LRNOp : public framework::OperatorWithKernel { ...@@ -175,7 +155,20 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetExpectedLRNKernel(ctx); framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
} }
}; };
...@@ -281,7 +274,20 @@ class LRNOpGrad : public framework::OperatorWithKernel { ...@@ -281,7 +274,20 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return GetExpectedLRNKernel(ctx); framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -97,7 +97,8 @@ class LSTMOp : public framework::OperatorWithKernel { ...@@ -97,7 +97,8 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
} }
}; };
...@@ -261,7 +262,8 @@ class LSTMGradOp : public framework::OperatorWithKernel { ...@@ -261,7 +262,8 @@ class LSTMGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
} }
}; };
......
...@@ -109,7 +109,8 @@ class LSTMPOp : public framework::OperatorWithKernel { ...@@ -109,7 +109,8 @@ class LSTMPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
} }
}; };
...@@ -347,7 +348,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel { ...@@ -347,7 +348,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("BatchGate")->type(), OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -44,8 +44,9 @@ class MeanIoUOp : public framework::OperatorWithKernel { ...@@ -44,8 +44,9 @@ class MeanIoUOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Predictions")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"),
ctx.GetPlace());
} }
}; };
......
...@@ -64,8 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -68,8 +68,8 @@ class AccuracyOp : public framework::OperatorWithKernel { ...@@ -68,8 +68,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Out")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
} }
}; };
......
...@@ -53,8 +53,9 @@ class AucOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,9 @@ class AucOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Predict")->type(), return framework::OpKernelType(
platform::CPUPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Predict"),
platform::CPUPlace());
} }
}; };
......
...@@ -92,8 +92,9 @@ class PrecisionRecallOp : public framework::OperatorWithKernel { ...@@ -92,8 +92,9 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("MaxProbs")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"),
ctx.device_context());
} }
}; };
......
...@@ -55,8 +55,9 @@ class MultiplexOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,9 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -125,9 +126,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel { ...@@ -125,9 +126,9 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -92,8 +92,9 @@ class NCEOp : public framework::OperatorWithKernel { ...@@ -92,8 +92,9 @@ class NCEOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
platform::CPUPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
platform::CPUPlace());
} }
}; };
...@@ -246,8 +247,9 @@ class NCEOpGrad : public framework::OperatorWithKernel { ...@@ -246,8 +247,9 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
platform::CPUPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
platform::CPUPlace());
} }
}; };
......
...@@ -51,8 +51,9 @@ class OneHotOp : public framework::OperatorWithKernel { ...@@ -51,8 +51,9 @@ class OneHotOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -48,8 +48,9 @@ class OneHotV2Op : public framework::OperatorWithKernel { ...@@ -48,8 +48,9 @@ class OneHotV2Op : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -75,8 +75,8 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -75,8 +75,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -64,8 +64,8 @@ class AdagradOp : public framework::OperatorWithKernel { ...@@ -64,8 +64,8 @@ class AdagradOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -78,7 +78,7 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -78,7 +78,7 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AdamOp::GetExpectedKernelType( framework::OpKernelType AdamOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const { const framework::ExecutionContext& ctx) const {
auto input_data_type = ctx.Input<framework::Tensor>("Param")->type(); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
......
...@@ -81,8 +81,8 @@ class AdamaxOp : public framework::OperatorWithKernel { ...@@ -81,8 +81,8 @@ class AdamaxOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -69,8 +69,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ...@@ -69,8 +69,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -55,8 +55,8 @@ class DpsgdOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,8 @@ class DpsgdOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -71,7 +71,8 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -71,7 +71,8 @@ class FTRLOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("Param")->type(); auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -85,7 +85,8 @@ class MomentumOp : public framework::OperatorWithKernel { ...@@ -85,7 +85,8 @@ class MomentumOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
}; };
......
...@@ -58,8 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ...@@ -58,8 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -46,8 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel { ...@@ -46,8 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
} }
}; };
......
...@@ -48,7 +48,7 @@ class SGDOp : public framework::OperatorWithKernel { ...@@ -48,7 +48,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
......
...@@ -520,8 +520,8 @@ class Pad2dOp : public framework::OperatorWithKernel { ...@@ -520,8 +520,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -621,9 +621,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel { ...@@ -621,9 +621,9 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -56,8 +56,9 @@ class PadConstantLikeOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,9 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Y"),
ctx.device_context());
} }
}; };
...@@ -186,8 +187,9 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel { ...@@ -186,8 +187,9 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Y"),
ctx.device_context());
} }
}; };
......
...@@ -134,8 +134,9 @@ framework::OpKernelType PoolOp::GetExpectedKernelType( ...@@ -134,8 +134,9 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(), return framework::OpKernelType(
layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
} }
void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const { void PoolOpGrad::InferShape(framework::InferShapeContext* ctx) const {
...@@ -164,7 +165,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType( ...@@ -164,7 +165,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
} }
#endif #endif
auto input_data_type = ctx.Input<Tensor>("X")->type(); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (input_data_type == framework::proto::VarType::FP16) { if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN, PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used"); "float16 can only be used when CUDNN is used");
......
...@@ -76,8 +76,9 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -76,8 +76,9 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -96,8 +97,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { ...@@ -96,8 +97,9 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -95,8 +95,9 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel { ...@@ -95,8 +95,9 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Score")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Score"),
ctx.device_context());
} }
}; };
......
...@@ -56,8 +56,9 @@ class PReluOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,9 @@ class PReluOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -112,8 +113,9 @@ class PReluGradOp : public framework::OperatorWithKernel { ...@@ -112,8 +113,9 @@ class PReluGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -114,8 +114,9 @@ class PRROIPoolOp : public framework::OperatorWithKernel { ...@@ -114,8 +114,9 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -135,8 +136,9 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel { ...@@ -135,8 +136,9 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -131,8 +131,9 @@ class PSROIPoolOp : public framework::OperatorWithKernel { ...@@ -131,8 +131,9 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -151,8 +152,9 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel { ...@@ -151,8 +152,9 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -104,10 +104,9 @@ class PushBoxSparseOp : public framework::OperatorWithKernel { ...@@ -104,10 +104,9 @@ class PushBoxSparseOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"))[0] ctx, framework::GradVarName("Out")),
->type(), ctx.device_context());
ctx.device_context());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -25,8 +25,9 @@ framework::OpKernelType QuantOp::GetExpectedKernelType( ...@@ -25,8 +25,9 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(
framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
ctx.GetPlace(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
} }
void QuantOpMaker::Make() { void QuantOpMaker::Make() {
......
...@@ -22,8 +22,9 @@ class RandomCropOp : public framework::OperatorWithKernel { ...@@ -22,8 +22,9 @@ class RandomCropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -267,9 +267,9 @@ class ReduceGradOp : public framework::OperatorWithKernel { ...@@ -267,9 +267,9 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -25,8 +25,9 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType( ...@@ -25,8 +25,9 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(
framework::LibraryType library_ = framework::LibraryType::kMKLDNN; framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN; framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
ctx.GetPlace(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
} }
void ReQuantOpMaker::Make() { void ReQuantOpMaker::Make() {
......
...@@ -203,8 +203,9 @@ class ReshapeOp : public framework::OperatorWithKernel { ...@@ -203,8 +203,9 @@ class ReshapeOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
...@@ -305,8 +306,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel { ...@@ -305,8 +306,9 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -475,9 +477,9 @@ class Reshape2GradOp : public framework::OperatorWithKernel { ...@@ -475,9 +477,9 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
...@@ -511,8 +513,9 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel { ...@@ -511,8 +513,9 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("DDX")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "DDX"),
ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
......
...@@ -65,8 +65,9 @@ class ROIAlignOp : public framework::OperatorWithKernel { ...@@ -65,8 +65,9 @@ class ROIAlignOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -85,8 +86,9 @@ class ROIAlignGradOp : public framework::OperatorWithKernel { ...@@ -85,8 +86,9 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("ROIs")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"),
ctx.device_context());
} }
}; };
......
...@@ -70,8 +70,9 @@ class ROIPoolOp : public framework::OperatorWithKernel { ...@@ -70,8 +70,9 @@ class ROIPoolOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -90,8 +91,9 @@ class ROIPoolGradOp : public framework::OperatorWithKernel { ...@@ -90,8 +91,9 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -162,7 +162,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel { ...@@ -162,7 +162,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits");
framework::OpKernelType kt = framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context()); framework::OpKernelType(data_type, ctx.device_context());
return kt; return kt;
...@@ -201,8 +201,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel { ...@@ -201,8 +201,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx.InputVar(framework::GradVarName("SampledLogits"))); ctx, framework::GradVarName("SampledLogits"));
framework::OpKernelType kt = framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context()); framework::OpKernelType(data_type, ctx.device_context());
return kt; return kt;
......
...@@ -69,8 +69,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel { ...@@ -69,8 +69,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->type(), PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.Input<Tensor>("Updates")->type(), OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
"Ref and Updates must have same type"); "Ref and Updates must have same type");
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context()); ctx.device_context());
...@@ -95,9 +95,9 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel { ...@@ -95,9 +95,9 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -48,8 +48,9 @@ class ScatterOp : public framework::OperatorWithKernel { ...@@ -48,8 +48,9 @@ class ScatterOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -71,9 +72,9 @@ class ScatterGradOp : public framework::OperatorWithKernel { ...@@ -71,9 +72,9 @@ class ScatterGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and ...@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/selu_op.h" #include "paddle/fluid/operators/selu_op.h"
#include <memory>
#include <string> #include <string>
#include <unordered_map>
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -39,7 +42,7 @@ class SeluOp : public framework::OperatorWithKernel { ...@@ -39,7 +42,7 @@ class SeluOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -115,7 +118,7 @@ class SeluGradOp : public framework::OperatorWithKernel { ...@@ -115,7 +118,7 @@ class SeluGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("Out")), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
} }
}; };
......
...@@ -102,9 +102,9 @@ class SeqConcatGradOp : public framework::OperatorWithKernel { ...@@ -102,9 +102,9 @@ class SeqConcatGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -75,8 +75,8 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel { ...@@ -75,8 +75,8 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -153,9 +153,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel { ...@@ -153,9 +153,9 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -100,8 +100,8 @@ class SequenceExpandOp : public framework::OperatorWithKernel { ...@@ -100,8 +100,8 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -208,9 +208,9 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel { ...@@ -208,9 +208,9 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -40,8 +40,9 @@ class SequenceMaskOp : public framework::OperatorWithKernel { ...@@ -40,8 +40,9 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name, const Tensor& tensor, const std::string& var_name, const Tensor& tensor,
......
...@@ -93,7 +93,7 @@ class SequencePadOp : public framework::OperatorWithKernel { ...@@ -93,7 +93,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -199,8 +199,8 @@ class SequencePadGradOp : public framework::OperatorWithKernel { ...@@ -199,8 +199,8 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx.InputVar(framework::GradVarName("Out"))); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -122,9 +122,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { ...@@ -122,9 +122,9 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -113,8 +113,9 @@ class SequenceScatterOp : public framework::OperatorWithKernel { ...@@ -113,8 +113,9 @@ class SequenceScatterOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
platform::CPUPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
} }
}; };
...@@ -132,9 +133,9 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel { ...@@ -132,9 +133,9 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
platform::CPUPlace()); platform::CPUPlace());
} }
}; };
......
...@@ -51,8 +51,9 @@ class SequenceSliceOp : public framework::OperatorWithKernel { ...@@ -51,8 +51,9 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -71,9 +72,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel { ...@@ -71,9 +72,9 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { ...@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
} }
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<Tensor>("X")->type(), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_); framework::StringToDataLayout(data_format), library_);
} }
}; };
...@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { ...@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
} }
std::string data_format = ctx.Attr<std::string>("data_format"); std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType( return framework::OpKernelType(
ctx.Input<Tensor>("Out")->type(), ctx.GetPlace(), OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_); framework::StringToDataLayout(data_format), library_);
} }
}; };
......
...@@ -90,7 +90,7 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel { ...@@ -90,7 +90,7 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -67,7 +67,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel { ...@@ -67,7 +67,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X")); auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
...@@ -132,8 +132,8 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel { ...@@ -132,8 +132,8 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar( auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx.InputVar(framework::GradVarName("Out"))); ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context()); return framework::OpKernelType(data_type, ctx.device_context());
} }
}; };
......
...@@ -41,8 +41,9 @@ class ShardIndexOp : public framework::OperatorWithKernel { ...@@ -41,8 +41,9 @@ class ShardIndexOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -35,8 +35,9 @@ class ShuffleChannelOp : public framework::OperatorWithKernel { ...@@ -35,8 +35,9 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -83,9 +84,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel { ...@@ -83,9 +84,9 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -70,8 +70,9 @@ class SimilarityFocusOp : public framework::OperatorWithKernel { ...@@ -70,8 +70,9 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
platform::CPUPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
} }
}; };
......
...@@ -128,8 +128,9 @@ class SliceOp : public framework::OperatorWithKernel { ...@@ -128,8 +128,9 @@ class SliceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
...@@ -243,9 +244,9 @@ class SliceOpGrad : public framework::OperatorWithKernel { ...@@ -243,9 +244,9 @@ class SliceOpGrad : public framework::OperatorWithKernel {
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
......
...@@ -76,7 +76,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { ...@@ -76,7 +76,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
} }
#endif #endif
auto input_data_type = ctx.Input<Tensor>("X")->type(); auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (input_data_type == framework::proto::VarType::FP16) { if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"float16 can only be used on GPU place"); "float16 can only be used on GPU place");
...@@ -187,8 +187,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { ...@@ -187,8 +187,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
auto input_data_type = auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(); ctx, framework::GradVarName("Out"));
if (input_data_type == framework::proto::VarType::FP16) { if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()), PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"float16 can only be used on GPU place"); "float16 can only be used on GPU place");
......
...@@ -171,8 +171,9 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -171,8 +171,9 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
} }
}; };
...@@ -232,9 +233,9 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { ...@@ -232,9 +233,9 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type(), ctx, framework::GradVarName("Loss")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -167,9 +167,9 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel { ...@@ -167,9 +167,9 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -77,8 +77,8 @@ class SpectralNormOp : public framework::OperatorWithKernel { ...@@ -77,8 +77,8 @@ class SpectralNormOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
} }
}; };
...@@ -209,8 +209,8 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { ...@@ -209,8 +209,8 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
} }
}; };
......
...@@ -152,8 +152,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { ...@@ -152,8 +152,9 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("sub_result")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"),
ctx.GetPlace());
} }
}; };
......
...@@ -104,8 +104,9 @@ class SqueezeOp : public framework::OperatorWithKernel { ...@@ -104,8 +104,9 @@ class SqueezeOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -122,8 +123,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel { ...@@ -122,8 +123,9 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -230,9 +232,9 @@ class Squeeze2GradOp : public framework::OperatorWithKernel { ...@@ -230,9 +232,9 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -124,8 +124,9 @@ class StridedSliceOp : public framework::OperatorWithKernel { ...@@ -124,8 +124,9 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(), return framework::OpKernelType(
ctx.Input<Tensor>("Input")->place()); OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.Input<Tensor>("Input")->place());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
...@@ -230,9 +231,9 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel { ...@@ -230,9 +231,9 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar( framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const Tensor &tensor, const std::string &var_name, const Tensor &tensor,
......
...@@ -55,8 +55,9 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel { ...@@ -55,8 +55,9 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -125,8 +126,9 @@ class TeacherStudentSigmoidLossGradientOp ...@@ -125,8 +126,9 @@ class TeacherStudentSigmoidLossGradientOp
// is determined by its input "X". // is determined by its input "X".
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
......
...@@ -56,8 +56,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel { ...@@ -56,8 +56,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
}; };
...@@ -139,9 +139,9 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel { ...@@ -139,9 +139,9 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace()); ctx.GetPlace());
} }
}; };
......
...@@ -53,8 +53,9 @@ class TopkOp : public framework::OperatorWithKernel { ...@@ -53,8 +53,9 @@ class TopkOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(),
layout_, library_);
} }
}; };
......
...@@ -78,8 +78,9 @@ class TransposeOp : public framework::OperatorWithKernel { ...@@ -78,8 +78,9 @@ class TransposeOp : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
} }
}; };
...@@ -164,9 +165,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel { ...@@ -164,9 +165,9 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace(), layout_, library_); ctx.GetPlace(), layout_, library_);
} }
}; };
...@@ -210,8 +211,9 @@ class Transpose2Op : public TransposeOp { ...@@ -210,8 +211,9 @@ class Transpose2Op : public TransposeOp {
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), return framework::OpKernelType(
ctx.GetPlace(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
} }
}; };
...@@ -268,9 +270,9 @@ class Transpose2OpGrad : public framework::OperatorWithKernel { ...@@ -268,9 +270,9 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN; layout_ = framework::DataLayout::kMKLDNN;
} }
#endif #endif
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.GetPlace(), layout_, library_); ctx.GetPlace(), layout_, library_);
} }
}; };
......
...@@ -104,8 +104,9 @@ class TreeConvOp : public framework::OperatorWithKernel { ...@@ -104,8 +104,9 @@ class TreeConvOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("NodesVector")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
ctx.device_context());
} }
}; };
...@@ -153,8 +154,9 @@ class TreeConvGradOp : public framework::OperatorWithKernel { ...@@ -153,8 +154,9 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("NodesVector")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
ctx.device_context());
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -120,8 +120,9 @@ class UnfoldOp : public framework::OperatorWithKernel { ...@@ -120,8 +120,9 @@ class UnfoldOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
}; };
...@@ -141,9 +142,9 @@ class UnfoldGradOp : public framework::OperatorWithKernel { ...@@ -141,9 +142,9 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::Tensor>(framework::GradVarName("Y"))->type(), ctx, framework::GradVarName("Y")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -74,8 +74,9 @@ class UnpoolOp : public framework::OperatorWithKernel { ...@@ -74,8 +74,9 @@ class UnpoolOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
public: public:
...@@ -117,8 +118,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel { ...@@ -117,8 +118,9 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
} }
public: public:
......
...@@ -215,9 +215,9 @@ class Unsqueeze2GradOp : public framework::OperatorWithKernel { ...@@ -215,9 +215,9 @@ class Unsqueeze2GradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType( return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(), ctx, framework::GradVarName("Out")),
ctx.device_context()); ctx.device_context());
} }
}; };
......
...@@ -60,8 +60,9 @@ class WarpCTCOp : public framework::OperatorWithKernel { ...@@ -60,8 +60,9 @@ class WarpCTCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain}; framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(), return framework::OpKernelType(
ctx.device_context(), layout_, library_); OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context(), layout_, library_);
} }
}; };
...@@ -173,8 +174,9 @@ class WarpCTCGradOp : public framework::OperatorWithKernel { ...@@ -173,8 +174,9 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(), return framework::OpKernelType(
ctx.device_context()); OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册