diff --git a/paddle/fluid/operators/adadelta_op.cc b/paddle/fluid/operators/adadelta_op.cc index c9ed221a6e662e8c213fe1d34ff85a3f77483a3c..7bdb3f274aa9bacb6b261e0d0cd00b72f1d409ae 100644 --- a/paddle/fluid/operators/adadelta_op.cc +++ b/paddle/fluid/operators/adadelta_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdadeltaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel { ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/adagrad_op.cc b/paddle/fluid/operators/adagrad_op.cc index 0153e1253b00ded21a7a14e37faf5a76d904d8d1..1227129429addb0ed412c7f1755fd39c9ca77157 100644 --- a/paddle/fluid/operators/adagrad_op.cc +++ b/paddle/fluid/operators/adagrad_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/adam_op.cc b/paddle/fluid/operators/adam_op.cc index 267dcab8104c337c8590180c8093098c756ab27d..f12f0c6663d1785b8af852244ffe32358fb1b693 100644 --- a/paddle/fluid/operators/adam_op.cc +++ b/paddle/fluid/operators/adam_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdamOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdamOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/adamax_op.cc b/paddle/fluid/operators/adamax_op.cc index 7e2f1cc66ebf8b7deebf55057a27129844129d5d..608b855d58a2f701fbb8631cb5f24768a61f3deb 100644 --- a/paddle/fluid/operators/adamax_op.cc +++ b/paddle/fluid/operators/adamax_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdamaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel { ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("InfNormOut", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/decayed_adagrad_op.cc b/paddle/fluid/operators/decayed_adagrad_op.cc index 5eeb3dee095e330174f35fa8eebd418bf764b132..5a1315fb2a80bf7f7f57388d0d6832686442c4ff 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/decayed_adagrad_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class DecayedAdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/ftrl_op.cc b/paddle/fluid/operators/ftrl_op.cc index 0a456f0981e5753a7de5a6f2ba029681beb347a5..cbdcce9beb3fafb0775d0b5fc39cb381ad128d0c 100644 --- a/paddle/fluid/operators/ftrl_op.cc +++ b/paddle/fluid/operators/ftrl_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class FTRLOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel { ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("LinearAccumOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/proximal_adagrad_op.cc b/paddle/fluid/operators/proximal_adagrad_op.cc index 38cd97c17b16a4cc64f7e6d52150fae392df6036..e057244c1e974edea1b9bbc76c0585c295495299 100644 --- a/paddle/fluid/operators/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/proximal_adagrad_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class ProximalAdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/proximal_gd_op.cc b/paddle/fluid/operators/proximal_gd_op.cc index efb4e1ac204ce79bfad7d77038f342be09e8f0e8..ed1472631870e5aee6b0e8b8f80bb5e6c84a3851 100644 --- a/paddle/fluid/operators/proximal_gd_op.cc +++ b/paddle/fluid/operators/proximal_gd_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class ProximalGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {