未验证 提交 a28dffbb 编写于 作者: D dzhwinter 提交者: GitHub

Fix/adam float64 (#10407)

* "optimizer op support float64"

* "fix ci"

* "fix ftrl op"
上级 6418c421
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel { class AdadeltaOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker { class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdagradOp : public framework::OperatorWithKernel { class AdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel { ...@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker { class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdamOp : public framework::OperatorWithKernel { class AdamOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel { ...@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Moment1Out", param_dims); ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims); ctx->SetOutputDim("Moment2Out", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdamOpMaker : public framework::OpProtoAndCheckerMaker { class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class AdamaxOp : public framework::OperatorWithKernel { class AdamaxOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel { ...@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
ctx->SetOutputDim("InfNormOut", param_dims); ctx->SetOutputDim("InfNormOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker { class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class DecayedAdagradOp : public framework::OperatorWithKernel { class DecayedAdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel { ...@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims); ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims); ctx->SetOutputDim("MomentOut", param_dims);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class FTRLOp : public framework::OperatorWithKernel { class FTRLOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel { ...@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SquaredAccumOut", param_dim); ctx->SetOutputDim("SquaredAccumOut", param_dim);
ctx->SetOutputDim("LinearAccumOut", param_dim); ctx->SetOutputDim("LinearAccumOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker { class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class ProximalAdagradOp : public framework::OperatorWithKernel { class ProximalAdagradOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel { ...@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim); ctx->SetOutputDim("MomentOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class ProximalGDOp : public framework::OperatorWithKernel { class ProximalGDOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel { ...@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker { class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册