Unverified commit c551e55d authored by HongyuJia, committed by GitHub

clean repetitious GetKernelTypeForVar (#47763)

Parent 788d9328
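All of the overrides deleted below share one body: they rebuild an OpKernelType from the expected kernel's data type plus the input tensor's own place and layout. That is what the OperatorWithKernel base class is generally understood to return by default (modulo framework-side special cases such as MKL-DNN layout handling), so the per-operator copies add nothing and can simply be dropped. A minimal sketch of the duplicated pattern; SomeOp is a placeholder name, not a real operator:

// Sketch of the boilerplate removed in this commit (SomeOp is hypothetical).
class SomeOp : public framework::OperatorWithKernel {
 protected:
  framework::OpKernelType GetKernelTypeForVar(
      const std::string& var_name,
      const phi::DenseTensor& tensor,
      const framework::OpKernelType& expected_kernel_type) const override {
    // Keep the expected data type, but take place and layout from the
    // variable itself; this is what the base-class default is assumed to
    // do as well, which makes the override a no-op.
    return framework::OpKernelType(
        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
  }
};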
@@ -90,14 +90,6 @@ class BilateralSliceOp : public framework::OperatorWithKernel {
     return framework::OpKernelType(
         OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
   }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 class BilateralSliceOpMaker : public framework::OpProtoAndCheckerMaker {
......
@@ -122,14 +122,6 @@ class CorrelationOp : public framework::OperatorWithKernel {
             "X and Y shoule have the same datatype"));
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 template <typename T>
......
@@ -196,14 +196,6 @@ framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType(
       input_data_type, ctx.GetPlace(), layout, library);
 }
-
-framework::OpKernelType FusedBatchNormActOp::GetKernelTypeForVar(
-    const std::string &var_name,
-    const Tensor &tensor,
-    const framework::OpKernelType &expected_kernel_type) const {
-  return framework::OpKernelType(
-      expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-}
 
 void FusedBatchNormActOpMaker::Make() {
   AddAttr<float>("momentum", "").SetDefault(0.9);
   AddAttr<float>("epsilon", "")
......
@@ -36,11 +36,6 @@ class FusedBatchNormActOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override;
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override;
 };
 
 class FusedBatchNormActGradOp : public framework::OperatorWithKernel {
......
@@ -108,14 +108,6 @@ class PartialConcatOp : public framework::OperatorWithKernel {
             "All Inputs of PartialSum OP are Empty!"));
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
-
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string &var_name,
-      const Tensor &tensor,
-      const framework::OpKernelType &expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 class PartialConcatGradOp : public framework::OperatorWithKernel {
......
@@ -37,13 +37,6 @@ class PutAlongAxisOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
         ctx.device_context());
   }
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -78,13 +71,6 @@ class PutAlongAxisGradOp : public framework::OperatorWithKernel {
             ctx, framework::GradVarName("Result")),
         ctx.device_context());
   }
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 template <typename T>
......
@@ -37,13 +37,6 @@ class TakeAlongAxisOp : public framework::OperatorWithKernel {
         OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
         ctx.device_context());
   }
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 class TakeAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -77,13 +70,6 @@ class TakeAlongAxisGradOp : public framework::OperatorWithKernel {
             ctx, framework::GradVarName("Result")),
         ctx.device_context());
   }
-  framework::OpKernelType GetKernelTypeForVar(
-      const std::string& var_name,
-      const phi::DenseTensor& tensor,
-      const framework::OpKernelType& expected_kernel_type) const override {
-    return framework::OpKernelType(
-        expected_kernel_type.data_type_, tensor.place(), tensor.layout());
-  }
 };
 
 template <typename T>
......
@@ -66,9 +66,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
   framework::OpKernelType GetKernelTypeForVar(
       const std::string &var_name,
       const phi::DenseTensor &tensor,
       const framework::OpKernelType &expected_kernel_type) const override {
-    return framework::OpKernelType(expected_kernel_type.data_type_,
-                                   expected_kernel_type.place_,
-                                   expected_kernel_type.data_layout_);
+    return expected_kernel_type;
   }
 };
......
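One operator above, TransferLayoutOp, keeps its GetKernelTypeForVar override; only the body is simplified. The old code copied expected_kernel_type field by field (data type, place, data layout) into a new OpKernelType; the new code returns expected_kernel_type directly, which says the same thing without the hand-written copy and, assuming OpKernelType carries further fields such as a library type, also leaves those untouched. Roughly, the override now reads as follows (sketch; enclosing class omitted):

  // Sketch of TransferLayoutOp::GetKernelTypeForVar after this commit.
  // Unlike the deleted overrides, it does not take place/layout from the
  // tensor; it reports the expected kernel type unchanged.
  framework::OpKernelType GetKernelTypeForVar(
      const std::string &var_name,
      const phi::DenseTensor &tensor,
      const framework::OpKernelType &expected_kernel_type) const override {
    return expected_kernel_type;
  }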