未验证 提交 c551e55d 编写于 作者: H HongyuJia 提交者: GitHub

clean repetitious GetKernelTypeForVar (#47763)

上级 788d9328
...@@ -90,14 +90,6 @@ class BilateralSliceOp : public framework::OperatorWithKernel { ...@@ -90,14 +90,6 @@ class BilateralSliceOp : public framework::OperatorWithKernel {
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace()); OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
class BilateralSliceOpMaker : public framework::OpProtoAndCheckerMaker { class BilateralSliceOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -122,14 +122,6 @@ class CorrelationOp : public framework::OperatorWithKernel { ...@@ -122,14 +122,6 @@ class CorrelationOp : public framework::OperatorWithKernel {
"X and Y shoule have the same datatype")); "X and Y shoule have the same datatype"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
template <typename T> template <typename T>
......
...@@ -196,14 +196,6 @@ framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType( ...@@ -196,14 +196,6 @@ framework::OpKernelType FusedBatchNormActOp::GetExpectedKernelType(
input_data_type, ctx.GetPlace(), layout, library); input_data_type, ctx.GetPlace(), layout, library);
} }
framework::OpKernelType FusedBatchNormActOp::GetKernelTypeForVar(
const std::string &var_name,
const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
void FusedBatchNormActOpMaker::Make() { void FusedBatchNormActOpMaker::Make() {
AddAttr<float>("momentum", "").SetDefault(0.9); AddAttr<float>("momentum", "").SetDefault(0.9);
AddAttr<float>("epsilon", "") AddAttr<float>("epsilon", "")
......
...@@ -36,11 +36,6 @@ class FusedBatchNormActOp : public framework::OperatorWithKernel { ...@@ -36,11 +36,6 @@ class FusedBatchNormActOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override; const framework::ExecutionContext& ctx) const override;
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override;
}; };
class FusedBatchNormActGradOp : public framework::OperatorWithKernel { class FusedBatchNormActGradOp : public framework::OperatorWithKernel {
......
...@@ -108,14 +108,6 @@ class PartialConcatOp : public framework::OperatorWithKernel { ...@@ -108,14 +108,6 @@ class PartialConcatOp : public framework::OperatorWithKernel {
"All Inputs of PartialSum OP are Empty!")); "All Inputs of PartialSum OP are Empty!"));
return framework::OpKernelType(input_data_type, ctx.GetPlace()); return framework::OpKernelType(input_data_type, ctx.GetPlace());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name,
const Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
class PartialConcatGradOp : public framework::OperatorWithKernel { class PartialConcatGradOp : public framework::OperatorWithKernel {
......
...@@ -37,13 +37,6 @@ class PutAlongAxisOp : public framework::OperatorWithKernel { ...@@ -37,13 +37,6 @@ class PutAlongAxisOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker { class PutAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -78,13 +71,6 @@ class PutAlongAxisGradOp : public framework::OperatorWithKernel { ...@@ -78,13 +71,6 @@ class PutAlongAxisGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Result")), ctx, framework::GradVarName("Result")),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
template <typename T> template <typename T>
......
...@@ -37,13 +37,6 @@ class TakeAlongAxisOp : public framework::OperatorWithKernel { ...@@ -37,13 +37,6 @@ class TakeAlongAxisOp : public framework::OperatorWithKernel {
OperatorWithKernel::IndicateVarDataType(ctx, "Input"), OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
class TakeAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker { class TakeAlongAxisOpMaker : public framework::OpProtoAndCheckerMaker {
...@@ -77,13 +70,6 @@ class TakeAlongAxisGradOp : public framework::OperatorWithKernel { ...@@ -77,13 +70,6 @@ class TakeAlongAxisGradOp : public framework::OperatorWithKernel {
ctx, framework::GradVarName("Result")), ctx, framework::GradVarName("Result")),
ctx.device_context()); ctx.device_context());
} }
framework::OpKernelType GetKernelTypeForVar(
const std::string& var_name,
const phi::DenseTensor& tensor,
const framework::OpKernelType& expected_kernel_type) const override {
return framework::OpKernelType(
expected_kernel_type.data_type_, tensor.place(), tensor.layout());
}
}; };
template <typename T> template <typename T>
......
...@@ -66,9 +66,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel { ...@@ -66,9 +66,7 @@ class TransferLayoutOp : public framework::OperatorWithKernel {
const std::string &var_name, const std::string &var_name,
const phi::DenseTensor &tensor, const phi::DenseTensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override { const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_, return expected_kernel_type;
expected_kernel_type.place_,
expected_kernel_type.data_layout_);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册