未验证 提交 a28dffbb 编写于 作者: D dzhwinter 提交者: GitHub

Fix/adam float64 (#10407)

* "optimizer op support float64"

* "fix ci"

* "fix ftrl op"
上级 6418c421
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -55,6 +56,12 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class AdadeltaOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -23,6 +23,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -56,6 +57,12 @@ class AdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdamOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -69,6 +70,12 @@ class AdamOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("Moment1Out", param_dims);
ctx->SetOutputDim("Moment2Out", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class AdamOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdamaxOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -63,6 +64,12 @@ class AdamaxOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("MomentOut", param_dims);
ctx->SetOutputDim("InfNormOut", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class AdamaxOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class DecayedAdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -51,6 +52,12 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dims);
ctx->SetOutputDim("MomentOut", param_dims);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class DecayedAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class FTRLOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -53,6 +54,12 @@ class FTRLOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("SquaredAccumOut", param_dim);
ctx->SetOutputDim("LinearAccumOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class FTRLOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class ProximalAdagradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -55,6 +56,12 @@ class ProximalAdagradOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("MomentOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class ProximalAdagradOpMaker : public framework::OpProtoAndCheckerMaker {
......
......@@ -17,6 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class ProximalGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -43,6 +44,12 @@ class ProximalGDOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim);
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};
class ProximalGDOpMaker : public framework::OpProtoAndCheckerMaker {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册