Unverified commit d29e9aa4, authored by Chen Weihang, committed by GitHub

[Cherry-pick to 1.6] Block part of "tensor should not be null" error message (#20845)

* Add IndicateVarDataType interface to block the "tensor is not initialized" problem in OP GetExpectedKernelType (#20044)

* add indicate_var_data_type interface, test=develop

* add unittests & polish error message, test=develop

* remove needless include, test=develop

* extract public function & polish message, test=develop

* delete empty var check, test=develop

* change data_type to pointer parameter, test=develop

* polish details, test=develop

* Replace risky GetInputType method with secure IndicateVarDataType interface (#20668)

* replace part of the old implementation, test=develop

* restore concat op, test=develop

* update all ops' implementations & delete GetDataTypeOfVar func, test=develop

test=release/1.6
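
For op authors, the migration this commit applies is mechanical. A minimal before/after sketch of GetExpectedKernelType for a hypothetical single-input op (illustrative only, not code taken from this diff):

// Before: dereferences the input tensor directly, so an uninitialized
// tensor fails with the vague "tensor should not be null" message.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
                                 ctx.GetPlace());
}

// After: IndicateVarDataType checks that the variable holds an initialized
// LoDTensor or SelectedRows before reading its dtype, and otherwise raises
// an error naming the op, the input slot, and the variable.
framework::OpKernelType GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const override {
  return framework::OpKernelType(
      OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}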
Parent e1e58450
......@@ -48,16 +48,6 @@ std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
if (var->IsType<framework::LoDTensor>()) {
return var->Get<framework::LoDTensor>().type();
} else if (var->IsType<framework::SelectedRows>()) {
return var->Get<framework::SelectedRows>().value().type();
} else {
PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}
}
static DDim GetDimsDebug(const Scope& scope, const std::string& name,
bool get_actual_dim = false) {
Variable* var = scope.FindVar(name);
......@@ -1152,13 +1142,12 @@ Scope* OperatorWithKernel::PrepareData(
return new_scope;
}
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
void OperatorWithKernel::ParseInputDataType(
const ExecutionContext& ctx, const std::string& name,
proto::VarType::Type* data_type) const {
proto::VarType::Type default_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = default_data_type;
for (auto& input : this->inputs_) {
const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
const std::vector<const Variable*> vars = ctx.MultiInputVar(name);
for (size_t i = 0; i < vars.size(); ++i) {
const Variable* var = vars[i];
if (var != nullptr) {
......@@ -1171,21 +1160,47 @@ proto::VarType::Type OperatorWithKernel::IndicateDataType(
t = &(var->Get<SelectedRows>().value());
}
if (t != nullptr) {
PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized",
input.first, i);
PADDLE_ENFORCE_EQ(t->IsInitialized(), true,
"The Tensor in the %s Op's Input Variable %s(%s) is "
"not initialized.",
Type(), name, ctx.Inputs(name).at(i));
proto::VarType::Type tmp = t->type();
PADDLE_ENFORCE(
tmp == data_type || data_type == default_data_type,
"DataType of Paddle Op %s %s must be the same. Get (%s) != (%s)",
Type(), input.first, DataTypeToString(data_type),
DataTypeToString(tmp));
data_type = tmp;
PADDLE_ENFORCE(tmp == *data_type || *data_type == default_data_type,
"The DataType of %s Op's duplicable Variable %s must be "
"consistent. The current variable type is (%s), but the "
"previous variable type is (%s).",
Type(), name, DataTypeToString(tmp),
DataTypeToString(*data_type));
*data_type = tmp;
}
}
}
}
proto::VarType::Type OperatorWithKernel::IndicateDataType(
const ExecutionContext& ctx) const {
proto::VarType::Type default_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = default_data_type;
for (auto& input : this->inputs_) {
ParseInputDataType(ctx, input.first, &data_type);
}
PADDLE_ENFORCE(data_type != default_data_type,
"DataType should be indicated by input");
PADDLE_ENFORCE_NE(data_type, default_data_type,
"DataType should be indicated by input Variable.");
return data_type;
}
proto::VarType::Type OperatorWithKernel::IndicateVarDataType(
const ExecutionContext& ctx, const std::string& name) const {
proto::VarType::Type default_data_type =
static_cast<proto::VarType::Type>(-1);
proto::VarType::Type data_type = default_data_type;
ParseInputDataType(ctx, name, &data_type);
PADDLE_ENFORCE_NE(
data_type, default_data_type,
"The Input Variable(%s) of %s Op used to determine kernel data type "
"is empty or not LoDTensor or SelectedRows.",
name, Type());
return data_type;
}
......
......@@ -102,7 +102,6 @@ inline std::string GradOriginalVarName(const std::string& grad_var_name) {
}
}
proto::VarType::Type GetDataTypeOfVar(const Variable* var);
const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
......@@ -459,6 +458,9 @@ class OperatorWithKernel : public OperatorBase {
void RuntimeInferShape(const Scope& scope, const platform::Place& place,
const RuntimeContext& ctx) const override;
proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
const std::string& name) const;
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
std::vector<KernelConfig>* GetKernelConfig(const OpKernelType& key) const;
......@@ -470,6 +472,8 @@ class OperatorWithKernel : public OperatorBase {
const OpKernelType& expected_kernel_type) const;
private:
void ParseInputDataType(const ExecutionContext& ctx, const std::string& name,
proto::VarType::Type* type) const;
// indicate kernel DataType by input data. By default all input data must be
// the same.
proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
......
......@@ -315,3 +315,182 @@ TEST(VarNameTest, all) {
original_var_name = paddle::framework::GradOriginalVarName(original_var_name);
ASSERT_EQ(original_var_name, "");
}
namespace paddle {
namespace framework {
class IndicateLoDTensorDataTypeTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "LoDTensor");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndicateLoDTensorDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("LoDTensor", "Input of Tensor type Variable.");
AddComment("This Op is only for IndicateVarDataType inferface test.");
}
};
class IndicateSelectedRowsDataTypeTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "SelectedRows");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndicateSelectedRowsDataTypeTestProtoMaker
: public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("SelectedRows", "Input of SelectedRows type Variable.");
AddComment("This Op is only for IndicateVarDataType inferface test.");
}
};
class IndicateOtherDataTypeTest : public OperatorWithKernel {
public:
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {}
OpKernelType GetExpectedKernelType(
const ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Other");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
class IndicateOtherDataTypeTestProtoMaker : public OpProtoAndCheckerMaker {
public:
void Make() {
AddInput("Other", "Input of Other type Variable");
AddComment("This Op is only for IndicateVarDataType inferface test.");
}
};
template <typename DeviceContext, typename T>
class IndicateVarDataTypeKernelTest : public OpKernel<T> {
public:
void Compute(const ExecutionContext& ctx) const {}
};
} // namespace framework
} // namespace paddle
REGISTER_OP_WITHOUT_GRADIENT(
indicate_lod_tensor_data_type_test,
paddle::framework::IndicateLoDTensorDataTypeTest,
paddle::framework::IndicateLoDTensorDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
indicate_selected_rows_data_type_test,
paddle::framework::IndicateSelectedRowsDataTypeTest,
paddle::framework::IndicateSelectedRowsDataTypeTestProtoMaker);
REGISTER_OP_WITHOUT_GRADIENT(
indicate_other_data_type_test, paddle::framework::IndicateOtherDataTypeTest,
paddle::framework::IndicateOtherDataTypeTestProtoMaker);
REGISTER_OP_CPU_KERNEL(indicate_lod_tensor_data_type_test,
paddle::framework::IndicateVarDataTypeKernelTest<
paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_selected_rows_data_type_test,
paddle::framework::IndicateVarDataTypeKernelTest<
paddle::platform::CPUDeviceContext, int>);
REGISTER_OP_CPU_KERNEL(indicate_other_data_type_test,
paddle::framework::IndicateVarDataTypeKernelTest<
paddle::platform::CPUDeviceContext, int>);
TEST(IndicateVarDataTypeTest, lodtensor) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_lod_tensor_data_type_test");
BuildVar("LoDTensor", {"lodtensor_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lodtensor_1");
var->GetMutable<paddle::framework::LoDTensor>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (const paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find(
"The Tensor in the indicate_lod_tensor_data_type_test Op's "
"Input Variable LoDTensor(lodtensor_1) is not initialized") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, selectedrows) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_selected_rows_data_type_test");
BuildVar("SelectedRows", {"selected_rows_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("selected_rows_1");
var->GetMutable<paddle::framework::SelectedRows>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (const paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(
ex_msg.find("The Tensor in the indicate_selected_rows_data_type_test "
"Op's Input Variable SelectedRows(selected_rows_1) is not "
"initialized") != std::string::npos);
}
ASSERT_TRUE(caught);
}
TEST(IndicateVarDataTypeTest, other) {
paddle::framework::InitDevices(true);
paddle::framework::proto::OpDesc op_desc;
op_desc.set_type("indicate_other_data_type_test");
BuildVar("Other", {"lod_tensor_array_1"}, op_desc.add_inputs());
paddle::platform::CPUPlace cpu_place;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto* var = scope.Var("lod_tensor_array_1");
var->GetMutable<paddle::framework::LoDTensorArray>();
bool caught = false;
try {
op->Run(scope, cpu_place);
} catch (const paddle::platform::EnforceNotMet& err) {
caught = true;
std::string ex_msg = err.what();
EXPECT_TRUE(ex_msg.find("The Input Variable(Other) of "
"indicate_other_data_type_test Op used to "
"determine kernel data type "
"is empty or not LoDTensor or SelectedRows") !=
std::string::npos);
}
ASSERT_TRUE(caught);
}
......@@ -30,9 +30,9 @@ class Variable {
static_assert(
IsRegisteredVarType<T>(),
"Not registered type. Please register T inside var_type_traits.h");
PADDLE_ENFORCE(holder_ != nullptr, "Variable must hold some thing");
PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
"The Variable type must be %s, but the type it holds is %s.",
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
return *static_cast<const T*>(holder_->Ptr());
......@@ -45,10 +45,10 @@ class Variable {
if (!holder_) {
holder_.reset(new PlaceholderImpl<T>());
} else {
PADDLE_ENFORCE(holder_->Type() == VarTypeTrait<T>::kId,
"Variable must be type %s, the holding type is %s",
ToTypeName(VarTypeTrait<T>::kId),
ToTypeName(holder_->Type()));
PADDLE_ENFORCE(
holder_->Type() == VarTypeTrait<T>::kId,
"The Variable type must be %s, but the type it holds is %s.",
ToTypeName(VarTypeTrait<T>::kId), ToTypeName(holder_->Type()));
}
return static_cast<T*>(holder_->Ptr());
}
......@@ -61,7 +61,7 @@ class Variable {
void Clear() { holder_.reset(); }
int Type() const {
PADDLE_ENFORCE(holder_ != nullptr, "Must hold memory");
PADDLE_ENFORCE(holder_ != nullptr, "Variable is not initialized.");
return holder_->Type();
}
......
......@@ -114,9 +114,8 @@ framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
layout = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar(name)), ctx.GetPlace(), layout,
library);
return framework::OpKernelType(oper.IndicateVarDataType(ctx, name),
ctx.GetPlace(), layout, library);
}
class ActivationOp : public framework::OperatorWithKernel {
......
......@@ -37,7 +37,8 @@ class AddPositionEncodingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -56,8 +57,8 @@ class AddPositionEncodingOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
platform::CPUPlace());
}
};
......
......@@ -121,8 +121,8 @@ class AffineChannelOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -80,7 +80,7 @@ class AffineGridOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kCUDNN;
}
#endif
auto data_type = ctx.Input<Tensor>("Theta")->type();
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Theta");
return framework::OpKernelType(data_type, ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library);
}
......@@ -191,8 +191,8 @@ class AffineGridOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("Theta")->type(),
ctx.GetPlace(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Theta"), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -89,7 +89,8 @@ class AssignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -129,8 +129,8 @@ void AttentionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AttentionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}
void AttentionLSTMOpMaker::Make() {
......
......@@ -103,8 +103,8 @@ class AverageAccumulatesOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "param"), ctx.GetPlace());
}
};
......
......@@ -115,7 +115,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType BatchNormOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should all be float (for float or float16 input tensor)
// or double (for double input tensor).
......@@ -432,8 +432,9 @@ framework::OpKernelType BatchNormGradOp::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(), layout,
library);
}
template <typename T>
......
......@@ -109,10 +109,11 @@ class BeamSearchOp : public framework::OperatorWithKernel {
// Compute on CPU for cases with batch_size > 4.
if (batch_size <= 4) {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("pre_ids")->type(), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
ctx.GetPlace());
} else {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("pre_ids")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "pre_ids"),
platform::CPUPlace());
}
}
......
......@@ -52,7 +52,8 @@ class BprLossOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -98,7 +99,8 @@ class BprLossGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......
......@@ -61,7 +61,8 @@ class CenterLossOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -117,7 +118,8 @@ class CenterLossGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>("SampleCenterDiff")->type(), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "SampleCenterDiff"),
ctx.device_context());
}
};
......
......@@ -41,8 +41,8 @@ class CAllReduceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -28,8 +28,8 @@ class CBroadcastOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -102,7 +102,6 @@ class ConcatOp : public framework::OperatorWithKernel {
if (flag == 0) {
PADDLE_THROW("All Inputs of Concat OP are Empty!");
}
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
......@@ -175,8 +174,8 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -135,7 +135,7 @@ framework::OpKernelType ConvOp::GetExpectedKernelType(
framework::OpKernelType::kDefaultCustomizedTypeValue;
framework::LibraryType library{framework::LibraryType::kPlain};
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
auto input_data_type = ctx.Input<Tensor>("Input")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Input");
std::string data_format =
"AnyLayout"; // todo enable data layout when it's ready
framework::DataLayout layout = framework::StringToDataLayout(data_format);
......@@ -527,9 +527,9 @@ framework::OpKernelType ConvOpGrad::GetExpectedKernelType(
}
#endif
auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_,
customized_type_value);
auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
......@@ -704,9 +704,9 @@ framework::OpKernelType ConvOpDoubleGrad::GetExpectedKernelType(
customized_type_value = kConvMKLDNNFP32;
}
#endif
auto type = framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_,
customized_type_value);
auto type = framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_, customized_type_value);
#ifdef PADDLE_WITH_CUDA
if (library_ == framework::LibraryType::kCUDNN) {
std::vector<framework::KernelConfig>& configs = kernel_configs_map_[type];
......
......@@ -132,8 +132,9 @@ framework::OpKernelType ConvTransposeOp::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
}
void Conv2DTransposeOpMaker::Make() {
......@@ -384,8 +385,9 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
}
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
}
class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker {
......
......@@ -160,7 +160,8 @@ class CRFDecodingOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
platform::CPUPlace());
}
};
......
......@@ -53,7 +53,8 @@ class CropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -174,8 +175,8 @@ class CropOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -87,7 +87,8 @@ class CropTensorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......@@ -243,8 +244,8 @@ class CropTensorOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
......
......@@ -107,7 +107,8 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......@@ -157,8 +158,8 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Y"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Y")),
ctx.device_context());
}
......
......@@ -39,7 +39,8 @@ class CTCAlignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......
......@@ -52,7 +52,8 @@ class CVMOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -93,7 +94,8 @@ class CVMGradientOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......
......@@ -81,7 +81,7 @@ class DataNormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should all be float (for float or float16 input tensor)
// or double (for double input tensor).
......@@ -89,12 +89,14 @@ class DataNormOp : public framework::OperatorWithKernel {
if (input_data_type == framework::proto::VarType::FP64) {
dn_param_type = framework::proto::VarType::FP64;
}
PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input<Tensor>("BatchSize")->type(),
PADDLE_ENFORCE_EQ(dn_param_type,
OperatorWithKernel::IndicateVarDataType(ctx, "BatchSize"),
"BatchSize input should be of float type");
PADDLE_ENFORCE_EQ(dn_param_type, ctx.Input<Tensor>("BatchSum")->type(),
"BatchSum input should be of float type");
PADDLE_ENFORCE_EQ(dn_param_type,
ctx.Input<Tensor>("BatchSquareSum")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "BatchSum"),
"BatchSum input should be of float type");
PADDLE_ENFORCE_EQ(dn_param_type, OperatorWithKernel::IndicateVarDataType(
ctx, "BatchSquareSum"),
"BatchSquareSum input should be of float type");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
......@@ -276,8 +278,9 @@ class DataNormGradOp : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(), layout, library);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
}
};
......
......@@ -216,7 +216,8 @@ class DeformableConvOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......@@ -275,7 +276,8 @@ class DeformableConvGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......
......@@ -199,7 +199,8 @@ class DeformableConvV1Op : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......@@ -253,7 +254,8 @@ class DeformableConvV1GradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......
......@@ -199,7 +199,8 @@ class DeformablePSROIPoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......@@ -247,7 +248,8 @@ class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Trans")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Trans"),
ctx.device_context());
}
};
......
......@@ -25,8 +25,9 @@ framework::OpKernelType DeQuantOp::GetExpectedKernelType(
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
}
void DeQuantOpMaker::Make() {
......
......@@ -53,7 +53,8 @@ class AnchorGeneratorOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>("Input")->type(), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......
......@@ -45,7 +45,8 @@ class BipartiteMatchOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<LoDTensor>("DistMat")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "DistMat"),
platform::CPUPlace());
}
};
......
......@@ -68,7 +68,7 @@ class CollectFpnProposalsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type =
framework::GetDataTypeOfVar(ctx.MultiInputVar("MultiLevelRois")[0]);
OperatorWithKernel::IndicateVarDataType(ctx, "MultiLevelRois");
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
......
......@@ -66,7 +66,7 @@ class DensityPriorBoxOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>("Input")->type(), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace());
}
};
......
......@@ -46,7 +46,7 @@ class DistributeFpnProposalsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("FpnRois"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "FpnRois");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -80,7 +80,7 @@ class GenerateMaskLabelsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Rois"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Rois");
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
......
......@@ -87,7 +87,7 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("RpnRois"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "RpnRois");
return framework::OpKernelType(data_type, platform::CPUPlace());
}
};
......
......@@ -60,7 +60,8 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Anchors")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Anchors"),
ctx.device_context());
}
};
......
......@@ -255,7 +255,8 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>("ClsLoss")->type(), platform::CPUPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "ClsLoss"),
platform::CPUPlace());
}
};
......
......@@ -80,7 +80,7 @@ class MultiClassNMSOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Scores")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "Scores"),
platform::CPUPlace());
}
};
......
......@@ -69,7 +69,8 @@ class PriorBoxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_input_type = ctx.Input<framework::Tensor>("Input")->type();
auto input_input_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Input");
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
......
......@@ -94,8 +94,7 @@ class RetinanetDetectionOutputOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::GetDataTypeOfVar(ctx.MultiInputVar("Scores")[0]);
OperatorWithKernel::IndicateVarDataType(ctx, "Scores");
return framework::OpKernelType(input_data_type,
platform::CPUPlace()); // ctx.GetPlace());
}
......
......@@ -525,7 +525,8 @@ class ROIPerspectiveTransformOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -545,7 +546,8 @@ class ROIPerspectiveTransformGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -77,7 +77,7 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Anchor")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"),
platform::CPUPlace());
}
};
......@@ -726,7 +726,7 @@ class RetinanetTargetAssignOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Anchor")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "Anchor"),
platform::CPUPlace());
}
};
......
......@@ -63,7 +63,8 @@ class SigmoidFocalLossOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -116,7 +117,8 @@ class SigmoidFocalLossGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -57,7 +57,8 @@ class TargetAssignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -65,8 +65,8 @@ class YoloBoxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -98,7 +98,8 @@ class Yolov3LossOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -255,7 +256,8 @@ class Yolov3LossOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......
......@@ -73,7 +73,7 @@ class DetectionMAPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>("DetectRes")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "DetectRes"),
platform::CPUPlace());
}
};
......
......@@ -29,8 +29,8 @@ class AllReduceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -72,7 +72,7 @@ class DistributedLookupTableOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -108,7 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.MultiInput<framework::Tensor>("X").front()->type(), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -42,7 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.MultiInput<framework::Tensor>("X")[0]->type(), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -66,8 +66,7 @@ class SplitIdsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("Ids").front()),
ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "Ids"), ctx.GetPlace());
}
};
......
......@@ -121,8 +121,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -153,7 +153,7 @@ class ElementwiseDivOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = ctx.Input<Tensor>("DDX")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
......
......@@ -82,7 +82,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
......@@ -236,8 +236,8 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
......@@ -274,7 +274,7 @@ class ElementwiseOpDoubleGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("DOut")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DOut");
#ifdef PADDLE_WITH_MKLDNN
if (platform::CanMKLDNNBeUsed(ctx)) {
......@@ -306,13 +306,13 @@ class ElementwiseOpDoubleGradWithoutDXDY
if (ctx.HasInput("DDX") == false) {
PADDLE_ENFORCE_EQ(ctx.HasInput("DDY"), true,
"Input(DDY) should not be null");
input_data_type = ctx.Input<Tensor>("DDY")->type();
input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDY");
} else if (ctx.HasInput("DDY") == false) {
PADDLE_ENFORCE_EQ(ctx.HasInput("DDX"), true,
"Input(DDX) should not be null");
input_data_type = ctx.Input<Tensor>("DDX")->type();
input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
} else {
input_data_type = ctx.Input<Tensor>("DDX")->type();
input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "DDX");
}
#ifdef PADDLE_WITH_MKLDNN
......
......@@ -65,7 +65,8 @@ class ExpandOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......@@ -180,8 +181,8 @@ class ExpandGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
......
......@@ -190,7 +190,8 @@ class FakeQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -241,8 +242,8 @@ class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -303,7 +304,8 @@ class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -375,7 +377,8 @@ class FakeQuantOrWithDequantMovingAverageAbsMaxOp
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -450,8 +453,8 @@ class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -80,8 +80,9 @@ class FCOp : public framework::OperatorWithKernel {
library = framework::LibraryType::kMKLDNN;
layout = framework::DataLayout::kMKLDNN;
}
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout, library);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout, library);
}
};
......
......@@ -48,7 +48,7 @@ class FilterByInstagOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Ins"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Ins");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......@@ -101,8 +101,8 @@ class FilterByInstagOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -69,7 +69,8 @@ class FlattenOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -130,7 +131,8 @@ class FlattenGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -221,8 +223,8 @@ class Flatten2GradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -49,8 +49,9 @@ class FSPOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context(), layout_, library_);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(),
layout_, library_);
}
};
......@@ -107,8 +108,8 @@ class FSPOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -140,8 +140,8 @@ class FusedElemwiseActivationOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(ctx.Input<framework::Tensor>("X")->type(),
ctx.Input<framework::Tensor>("Y")->type(),
"The element's type of input should be the same.");
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -328,8 +328,8 @@ class FusedElemwiseActivationOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("Y")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Y"), ctx.GetPlace());
}
};
} // namespace operators
......
......@@ -114,7 +114,7 @@ void FusedEmbeddingFCLSTMOp::InferShape(
framework::OpKernelType FusedEmbeddingFCLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Embeddings")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "Embeddings"),
ctx.device_context());
}
......
......@@ -56,7 +56,7 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......@@ -125,7 +125,7 @@ class FusedEmbeddingSeqPoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -58,7 +58,8 @@ class ConvInceptionFusionOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......
......@@ -93,8 +93,8 @@ void FusionGRUOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionGRUOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}
void FusionGRUOpMaker::Make() {
......
......@@ -117,8 +117,8 @@ void FusionLSTMOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType FusionLSTMOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}
void FusionLSTMOpMaker::Make() {
......
......@@ -60,8 +60,8 @@ void FusionRepeatedFCReluOp::InferShape(
framework::OpKernelType FusionRepeatedFCReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
void FusionRepeatedFCReluOpMaker::Make() {
......
......@@ -61,8 +61,8 @@ void FusionSeqConvEltAddReluOp::InferShape(
framework::OpKernelType FusionSeqConvEltAddReluOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}
void FusionSeqConvEltAddReluOpMaker::Make() {
......
......@@ -67,8 +67,8 @@ void FusionSeqExpandConcatFCOp::InferShape(
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(ctx.MultiInput<LoDTensor>("X")[0]->type(),
ctx.device_context());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context());
}
void FusionSeqExpandConcatFCOpMaker::Make() {
......
......@@ -47,7 +47,7 @@ void FusionSeqPoolConcatOp::InferShape(
framework::OpKernelType FusionSeqPoolConcatOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
void FusionSeqPoolConcatOpMaker::Make() {
......
......@@ -52,7 +52,7 @@ void FusionSeqPoolCVMConcatOp::InferShape(
framework::OpKernelType FusionSeqPoolCVMConcatOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.MultiInputVar("X")[0]), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
void FusionSeqPoolCVMConcatOpMaker::Make() {
......
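The two fusion ops just above take a duplicable input: the old code read the dtype of element 0 only (ctx.MultiInput<LoDTensor>("X")[0]->type() or GetDataTypeOfVar(ctx.MultiInputVar("X")[0])), while the new interface (backed by ParseInputDataType) scans every element of the input and requires a single consistent dtype. A mock sketch of that scan, under the assumption that an UNDEFINED sentinel marks "no dtype seen yet":

// Mock sketch: every element of a duplicable input must carry the same
// dtype; the first dtype seen becomes the reference value.
#include <stdexcept>
#include <vector>

enum class DataType { FP32, FP64, UNDEFINED };

DataType ScanDuplicableInput(const std::vector<DataType>& elems) {
  DataType result = DataType::UNDEFINED;
  for (DataType t : elems) {
    if (result != DataType::UNDEFINED && t != result) {
      throw std::runtime_error(
          "The DataType of a duplicable Variable must be consistent.");
    }
    result = t;
  }
  return result;
}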
......@@ -53,8 +53,8 @@ void FusionSquaredMatSubOp::InferShape(
framework::OpKernelType FusionSquaredMatSubOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
return framework::OpKernelType(framework::GetDataTypeOfVar(ctx.InputVar("X")),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
void FusionSquaredMatSubOpMaker::Make() {
......
......@@ -61,7 +61,7 @@ class GatherNdOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
const auto& x_type = x->type();
const auto& x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(
x_type,
x_type == framework::proto::VarType::BOOL
......@@ -82,8 +82,8 @@ class GatherNdGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -45,7 +45,8 @@ class GatherOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -62,8 +63,8 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
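Grad ops follow the same pattern but key the kernel type off the output gradient, passing framework::GradVarName("Out") instead of a literal input name. A sketch of GradVarName, assuming it simply appends Paddle's gradient suffix (the "Y@GRAD" message elsewhere in this diff suggests "@GRAD"):

// Sketch (assumption: gradient variables are named "<name>@GRAD").
#include <string>

std::string GradVarName(const std::string& name) {
  return name + "@GRAD";  // "Out" -> "Out@GRAD"
}

So IndicateVarDataType(ctx, framework::GradVarName("Out")) validates and reads the dtype of the Out@GRAD variable before any kernel is picked.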
......@@ -40,7 +40,8 @@ class GatherTreeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Ids")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Ids"),
ctx.device_context());
}
};
......
......@@ -45,7 +45,8 @@ class GetTensorFromSelectedRowsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -68,8 +68,8 @@ class GridSampleOp : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......@@ -164,8 +164,8 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
library_ = framework::LibraryType::kCUDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::DataLayout::kAnyLayout, library_);
}
};
......
......@@ -81,8 +81,8 @@ class HierarchicalSigmoidOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -224,8 +224,8 @@ class HierarchicalSigmoidGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -70,7 +70,7 @@ void InstanceNormOp::InferShape(framework::InferShapeContext *ctx) const {
framework::OpKernelType InstanceNormOp::GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should all be float (for float or float16 input tensors)
// or double (for double input tensors).
......@@ -236,8 +236,8 @@ framework::OpKernelType InstanceNormGradOp::GetExpectedKernelType(
if (t == nullptr) {
PADDLE_THROW("cannot find Y@GRAD");
}
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
template <typename T>
......@@ -396,8 +396,8 @@ framework::OpKernelType InstanceNormDoubleGradOp::GetExpectedKernelType(
if (t == nullptr) {
PADDLE_THROW("cannot find Y@GRAD");
}
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
std::unique_ptr<framework::OpDesc> InstanceNormDoubleGradMaker::Apply() const {
......
......@@ -204,8 +204,8 @@ class InterpolateOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......@@ -407,8 +407,8 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
......
......@@ -35,7 +35,8 @@ class IsEmptyOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *x = ctx.Input<framework::LoDTensor>("X");
return framework::OpKernelType(x->type(), x->place());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), x->place());
}
};
......
......@@ -58,8 +58,8 @@ class KLDivLossOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -136,8 +136,8 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......
......@@ -224,7 +224,8 @@ class LinearChainCRFOp : public framework::OperatorWithKernel {
// is determined by its input "Emission".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<LoDTensor>("Emission")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Emission"),
platform::CPUPlace());
}
};
......@@ -263,7 +264,8 @@ class LinearChainCRFGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<LoDTensor>(framework::GradVarName("LogLikelihood"))->type(),
OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("LogLikelihood")),
platform::CPUPlace());
}
};
......
......@@ -52,8 +52,8 @@ class LinspaceOp : public framework::OperatorWithKernel {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
ctx.Input<framework::Tensor>("Start")->type(), ctx.device_context(),
layout_, library_);
OperatorWithKernel::IndicateVarDataType(ctx, "Start"),
ctx.device_context(), layout_, library_);
}
};
......
......@@ -46,7 +46,8 @@ class LoDResetOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -172,8 +173,8 @@ class LoDResetGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -64,7 +64,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......@@ -166,8 +166,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -58,7 +58,7 @@ class LookupTableV2Op : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("W"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "W");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......@@ -154,8 +154,8 @@ class LookupTableV2OpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -130,26 +130,6 @@ struct LRNGradFunctor<platform::CPUDeviceContext, T> {
template struct LRNGradFunctor<platform::CPUDeviceContext, float>;
template struct LRNGradFunctor<platform::CPUDeviceContext, double>;
namespace {
framework::OpKernelType GetExpectedLRNKernel(
const framework::ExecutionContext& ctx) {
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
layout_, library_);
}
} // namespace
class LRNOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -175,7 +155,20 @@ class LRNOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetExpectedLRNKernel(ctx);
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
}
};
......@@ -281,7 +274,20 @@ class LRNOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetExpectedLRNKernel(ctx);
framework::LibraryType library_{framework::LibraryType::kPlain};
std::string data_format = ctx.Attr<std::string>("data_format");
// TODO(pzelazko-intel): enable MKLDNN layout when it's ready
framework::DataLayout layout_ = framework::StringToDataLayout(data_format);
#ifdef PADDLE_WITH_MKLDNN
if (library_ == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
library_ = framework::LibraryType::kMKLDNN;
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
}
};
} // namespace operators
......
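The LRN hunk is the one place that does more than a line-for-line swap: the anonymous-namespace helper GetExpectedLRNKernel is deleted and its body inlined into both LRNOp and LRNOpGrad. A plausible reason, inferred from the diff rather than stated in the commit, is that IndicateVarDataType is a member of OperatorWithKernel, so it is available inside each op's GetExpectedKernelType override but not inside a free function that only receives the ExecutionContext. A stand-in illustration of that access constraint:

// Sketch with stand-in classes, not Paddle's real hierarchy.
class OperatorWithKernelLike {
 public:
  int IndicateVarDataType() const { return 0; }  // member of the op
};

class LRNOpLike : public OperatorWithKernelLike {
 public:
  int GetExpectedKernelType() const {
    return IndicateVarDataType();  // fine: called on *this
  }
};

// A free helper receives no op object to call the member on:
// int GetExpectedLRNKernel(/* ExecutionContext only */);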
......@@ -97,7 +97,8 @@ class LSTMOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......@@ -261,7 +262,8 @@ class LSTMGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......
......@@ -109,7 +109,8 @@ class LSTMPOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("Input")->type(), ctx.device_context());
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
};
......@@ -347,7 +348,7 @@ class LSTMPGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("BatchGate")->type(),
OperatorWithKernel::IndicateVarDataType(ctx, "BatchGate"),
ctx.device_context());
}
};
......
......@@ -44,7 +44,8 @@ class MeanIoUOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Predictions")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Predictions"),
ctx.GetPlace());
}
};
......
......@@ -64,8 +64,8 @@ class MeanGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -68,8 +68,8 @@ class AccuracyOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Out")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
}
};
......
......@@ -53,7 +53,8 @@ class AucOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Predict")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Predict"),
platform::CPUPlace());
}
};
......
......@@ -92,7 +92,8 @@ class PrecisionRecallOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("MaxProbs")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "MaxProbs"),
ctx.device_context());
}
};
......
......@@ -90,7 +90,7 @@ class MulOp : public framework::OperatorWithKernel {
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
int customized_type_value =
framework::OpKernelType::kDefaultCustomizedTypeValue;
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
#ifdef PADDLE_WITH_MKLDNN
if (library == framework::LibraryType::kPlain &&
platform::CanMKLDNNBeUsed(ctx)) {
......
......@@ -55,7 +55,8 @@ class MultiplexOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.MultiInput<Tensor>("X")[0]->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -125,8 +126,8 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -92,7 +92,8 @@ class NCEOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
platform::CPUPlace());
}
};
......@@ -246,7 +247,8 @@ class NCEOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
platform::CPUPlace());
}
};
......
......@@ -51,7 +51,8 @@ class OneHotOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......
......@@ -48,7 +48,8 @@ class OneHotV2Op : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......
......@@ -75,8 +75,8 @@ class AdadeltaOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -64,8 +64,8 @@ class AdagradOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -78,7 +78,7 @@ void AdamOp::InferShape(framework::InferShapeContext* ctx) const {
framework::OpKernelType AdamOp::GetExpectedKernelType(
const framework::ExecutionContext& ctx) const {
auto input_data_type = ctx.Input<framework::Tensor>("Param")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
......
......@@ -81,8 +81,8 @@ class AdamaxOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -69,8 +69,8 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -55,8 +55,8 @@ class DpsgdOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -71,7 +71,8 @@ class FTRLOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type = ctx.Input<Tensor>("Param")->type();
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -85,7 +85,8 @@ class MomentumOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
......
......@@ -58,8 +58,8 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -46,8 +46,8 @@ class ProximalGDOp : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Param")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Param"), ctx.GetPlace());
}
};
......
......@@ -48,7 +48,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(data_type, ctx.device_context());
}
......
......@@ -520,8 +520,8 @@ class Pad2dOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -621,8 +621,8 @@ class Pad2dOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -56,7 +56,8 @@ class PadConstantLikeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Y"),
ctx.device_context());
}
};
......@@ -186,7 +187,8 @@ class PadConstantLikeOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Y")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Y"),
ctx.device_context());
}
};
......
......@@ -134,7 +134,8 @@ framework::OpKernelType PoolOp::GetExpectedKernelType(
}
#endif
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout_, library_);
}
......@@ -164,7 +165,7 @@ framework::OpKernelType PoolOpGrad::GetExpectedKernelType(
}
#endif
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(library_, framework::LibraryType::kCUDNN,
"float16 can only be used when CUDNN is used");
......
......@@ -76,7 +76,8 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -96,7 +97,8 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -95,7 +95,8 @@ class PositiveNegativePairOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Score")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Score"),
ctx.device_context());
}
};
......
......@@ -56,7 +56,8 @@ class PReluOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -112,7 +113,8 @@ class PReluGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -114,7 +114,8 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -135,7 +136,8 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -131,7 +131,8 @@ class PSROIPoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -151,7 +152,8 @@ class PSROIPoolGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -104,9 +104,8 @@ class PushBoxSparseOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.MultiInput<framework::Tensor>(framework::GradVarName("Out"))[0]
->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -25,8 +25,9 @@ framework::OpKernelType QuantOp::GetExpectedKernelType(
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
}
void QuantOpMaker::Make() {
......
......@@ -22,7 +22,8 @@ class RandomCropOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -267,8 +267,8 @@ class ReduceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -25,8 +25,9 @@ framework::OpKernelType ReQuantOp::GetExpectedKernelType(
framework::LibraryType library_ = framework::LibraryType::kMKLDNN;
framework::DataLayout layout_ = framework::DataLayout::kMKLDNN;
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.GetPlace(), layout_, library_);
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), ctx.GetPlace(),
layout_, library_);
}
void ReQuantOpMaker::Make() {
......
......@@ -200,7 +200,8 @@ class ReshapeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......@@ -302,7 +303,8 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -472,8 +474,8 @@ class Reshape2GradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
......@@ -508,7 +510,8 @@ class Reshape2DoubleGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("DDX")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "DDX"),
ctx.device_context());
}
......
......@@ -65,7 +65,8 @@ class ROIAlignOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -85,7 +86,8 @@ class ROIAlignGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("ROIs")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "ROIs"),
ctx.device_context());
}
};
......
......@@ -70,7 +70,8 @@ class ROIPoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -90,7 +91,8 @@ class ROIPoolGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -162,7 +162,7 @@ class SampleLogitsOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Logits"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "Logits");
framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context());
return kt;
......@@ -201,8 +201,8 @@ class SampleLogitsOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("SampledLogits")));
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("SampledLogits"));
framework::OpKernelType kt =
framework::OpKernelType(data_type, ctx.device_context());
return kt;
......
......@@ -31,7 +31,7 @@ class SaveOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -69,8 +69,8 @@ class ScatterNdAddOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->type(),
ctx.Input<Tensor>("Updates")->type(),
PADDLE_ENFORCE_EQ(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
OperatorWithKernel::IndicateVarDataType(ctx, "Updates"),
"Ref and Updates must have same type");
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
......@@ -95,8 +95,8 @@ class ScatterNdAddGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
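scatter_nd_add also routes its dtype-consistency check through the new interface: PADDLE_ENFORCE_EQ now compares IndicateVarDataType(ctx, "X") with IndicateVarDataType(ctx, "Updates"), so a missing X or Updates tensor fails with a named-variable error before the equality test, instead of crashing inside ->type(). Roughly, in mock form (not the real macro):

// Mock sketch: both operands are validated first, then their dtypes
// are compared.
#include <stdexcept>

enum class DataType { FP32, FP64 };

void EnforceSameType(DataType x, DataType updates) {
  if (x != updates) {
    throw std::runtime_error("Ref and Updates must have the same type.");
  }
}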
......@@ -48,7 +48,8 @@ class ScatterOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -71,8 +72,8 @@ class ScatterGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -13,7 +13,10 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/selu_op.h"
#include <memory>
#include <string>
#include <unordered_map>
namespace paddle {
namespace operators {
......@@ -39,7 +42,7 @@ class SeluOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("X")), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -115,7 +118,7 @@ class SeluGradOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
framework::GetDataTypeOfVar(ctx.InputVar("Out")), ctx.GetPlace());
OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
}
};
......
......@@ -102,8 +102,8 @@ class SeqConcatGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -75,8 +75,8 @@ class SequenceExpandAsOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -153,8 +153,8 @@ class SequenceExpandAsOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -100,8 +100,8 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -208,8 +208,8 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -40,7 +40,8 @@ class SequenceMaskOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -93,7 +93,7 @@ class SequencePadOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......@@ -199,8 +199,8 @@ class SequencePadGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -122,8 +122,8 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -113,7 +113,8 @@ class SequenceScatterOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......@@ -132,8 +133,8 @@ class SequenceScatterGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
platform::CPUPlace());
}
};
......
......@@ -51,7 +51,8 @@ class SequenceSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -71,8 +72,8 @@ class SequenceSliceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -51,7 +51,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
......@@ -146,7 +146,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
}
std::string data_format = ctx.Attr<std::string>("data_format");
return framework::OpKernelType(
ctx.Input<Tensor>("X")->type(), ctx.GetPlace(),
OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace(),
framework::StringToDataLayout(data_format), library_);
}
};
......
......@@ -90,7 +90,7 @@ class SequenceTopkAvgPoolingGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -67,7 +67,7 @@ class SequenceUnpadOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("X"));
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......@@ -132,8 +132,8 @@ class SequenceUnpadGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = framework::GetDataTypeOfVar(
ctx.InputVar(framework::GradVarName("Out")));
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
......
......@@ -41,7 +41,8 @@ class ShardIndexOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -35,7 +35,8 @@ class ShuffleChannelOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -83,8 +84,8 @@ class ShuffleChannelGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -70,7 +70,8 @@ class SimilarityFocusOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
platform::CPUPlace());
}
};
......
......@@ -128,7 +128,8 @@ class SliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
......@@ -243,8 +244,8 @@ class SliceOpGrad : public framework::OperatorWithKernel {
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -76,7 +76,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
}
#endif
auto input_data_type = ctx.Input<Tensor>("X")->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"float16 can only be used on GPU place");
......@@ -187,8 +187,8 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
auto input_data_type =
ctx.Input<Tensor>(framework::GradVarName("Out"))->type();
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"float16 can only be used on GPU place");
......
......@@ -171,7 +171,8 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
}
};
......@@ -232,8 +233,8 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Loss"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Loss")),
ctx.device_context());
}
};
......
......@@ -167,8 +167,8 @@ class SpaceToDepthGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -77,8 +77,8 @@ class SpectralNormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
}
};
......@@ -209,8 +209,8 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Weight")->type(),
ctx.GetPlace());
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Weight"), ctx.GetPlace());
}
};
......
......@@ -152,7 +152,8 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("sub_result")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "sub_result"),
ctx.GetPlace());
}
};
......
......@@ -104,7 +104,8 @@ class SqueezeOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -122,7 +123,8 @@ class SqueezeGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -230,8 +232,8 @@ class Squeeze2GradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -124,7 +124,8 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.Input<Tensor>("Input")->place());
}
framework::OpKernelType GetKernelTypeForVar(
......@@ -230,8 +231,8 @@ class StridedSliceOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<framework::Tensor>(framework::GradVarName("Out"))->type(),
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
framework::OpKernelType GetKernelTypeForVar(
......
......@@ -55,7 +55,8 @@ class TeacherStudentSigmoidLossOp : public framework::OperatorWithKernel {
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -125,7 +126,8 @@ class TeacherStudentSigmoidLossGradientOp
// is determined by its input "X".
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......
......@@ -56,8 +56,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.GetPlace());
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
......@@ -139,8 +139,8 @@ class TemporalShiftOpGrad : public framework::OperatorWithKernel {
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
ctx.GetPlace());
}
};
......
......@@ -53,8 +53,9 @@ class TopkOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.device_context(),
+        layout_, library_);
}
};
......
......@@ -78,8 +78,9 @@ class TransposeOp : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout_, library_);
}
};
......@@ -164,8 +165,8 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
-    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
ctx.GetPlace(), layout_, library_);
}
};
......@@ -210,8 +211,9 @@ class Transpose2Op : public TransposeOp {
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.GetPlace(), layout_, library_);
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout_, library_);
}
};
......@@ -268,8 +270,8 @@ class Transpose2OpGrad : public framework::OperatorWithKernel {
layout_ = framework::DataLayout::kMKLDNN;
}
#endif
-    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
ctx.GetPlace(), layout_, library_);
}
};
......
......@@ -104,7 +104,8 @@ class TreeConvOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("NodesVector")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
ctx.device_context());
}
};
......@@ -153,7 +154,8 @@ class TreeConvGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("NodesVector")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "NodesVector"),
ctx.device_context());
}
};
......
......@@ -120,7 +120,8 @@ class UnfoldOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
......@@ -141,8 +142,8 @@ class UnfoldGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::Tensor>(framework::GradVarName("Y"))->type(),
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Y")),
ctx.device_context());
}
};
......
......@@ -74,7 +74,8 @@ class UnpoolOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......@@ -117,7 +118,8 @@ class UnpoolOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<framework::Tensor>("X")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
......
......@@ -215,8 +215,8 @@ class Unsqueeze2GradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
-    return framework::OpKernelType(
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"))->type(),
+    return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
+                                       ctx, framework::GradVarName("Out")),
ctx.device_context());
}
};
......
......@@ -60,7 +60,8 @@ class WarpCTCOp : public framework::OperatorWithKernel {
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
-    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context(), layout_, library_);
}
};
......@@ -173,7 +174,8 @@ class WarpCTCGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("Logits")->type(),
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "Logits"),
ctx.device_context());
}
};
......