From a28dffbb0b1fc19a3260beee72071ae99e35c0a9 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Sun, 6 May 2018 21:32:23 +0800 Subject: [PATCH] Fix/adam float64 (#10407) * "optimizer op support float64" * "fix ci" * "fix ftrl op" --- paddle/fluid/operators/adadelta_op.cc | 7 +++++++ paddle/fluid/operators/adagrad_op.cc | 7 +++++++ paddle/fluid/operators/adam_op.cc | 7 +++++++ paddle/fluid/operators/adamax_op.cc | 7 +++++++ paddle/fluid/operators/decayed_adagrad_op.cc | 7 +++++++ paddle/fluid/operators/ftrl_op.cc | 7 +++++++ paddle/fluid/operators/proximal_adagrad_op.cc | 7 +++++++ paddle/fluid/operators/proximal_gd_op.cc | 7 +++++++ 8 files changed, 56 insertions(+) diff --git a/paddle/fluid/operators/adadelta_op.cc b/paddle/fluid/operators/adadelta_op.cc index c9ed221a6e6..7bdb3f274aa 100644 --- a/paddle/fluid/operators/adadelta_op.cc +++ b/paddle/fluid/operators/adadelta_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdadeltaOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel { ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/adagrad_op.cc b/paddle/fluid/operators/adagrad_op.cc index 0153e1253b0..1227129429a 100644 --- a/paddle/fluid/operators/adagrad_op.cc +++ b/paddle/fluid/operators/adagrad_op.cc @@ -23,6 +23,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/adam_op.cc b/paddle/fluid/operators/adam_op.cc index 267dcab8104..f12f0c6663d 100644 --- a/paddle/fluid/operators/adam_op.cc +++ b/paddle/fluid/operators/adam_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdamOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel { ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdamOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/adamax_op.cc b/paddle/fluid/operators/adamax_op.cc index 7e2f1cc66eb..608b855d58a 100644 --- a/paddle/fluid/operators/adamax_op.cc +++ b/paddle/fluid/operators/adamax_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class AdamaxOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel { ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("InfNormOut", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/decayed_adagrad_op.cc b/paddle/fluid/operators/decayed_adagrad_op.cc index 5eeb3dee095..5a1315fb2a8 100644 --- a/paddle/fluid/operators/decayed_adagrad_op.cc +++ b/paddle/fluid/operators/decayed_adagrad_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class DecayedAdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/ftrl_op.cc b/paddle/fluid/operators/ftrl_op.cc index 0a456f0981e..cbdcce9beb3 100644 --- a/paddle/fluid/operators/ftrl_op.cc +++ b/paddle/fluid/operators/ftrl_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class FTRLOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel { ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("LinearAccumOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/proximal_adagrad_op.cc b/paddle/fluid/operators/proximal_adagrad_op.cc index 38cd97c17b1..e057244c1e9 100644 --- a/paddle/fluid/operators/proximal_adagrad_op.cc +++ b/paddle/fluid/operators/proximal_adagrad_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class ProximalAdagradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/proximal_gd_op.cc b/paddle/fluid/operators/proximal_gd_op.cc index efb4e1ac204..ed147263187 100644 --- a/paddle/fluid/operators/proximal_gd_op.cc +++ b/paddle/fluid/operators/proximal_gd_op.cc @@ -17,6 +17,7 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; class ProximalGDOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { -- GitLab