Unverified commit 8e2fdc54, authored by chengduo, committed by GitHub

Add check for opt op (#13840)

* add check for opt op

* fix opt op
test=develop

* fix test fail
test=develop

* fix optimization doc
test=develop

* test=develop
Parent e37c9e67
...@@ -18,6 +18,7 @@ namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
class AdadeltaOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -31,6 +32,16 @@ class AdadeltaOp : public framework::OperatorWithKernel {
"Input(AvgSquaredGrad) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("AvgSquaredUpdate"),
"Input(AvgSquaredUpdate) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdadeltaOp should not be null."); "Output(ParamOut) of AdadeltaOp should not be null.");
...@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { ...@@ -56,6 +67,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("AvgSquaredGradOut", param_dim); ctx->SetOutputDim("AvgSquaredGradOut", param_dim);
ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim); ctx->SetOutputDim("AvgSquaredUpdateOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
auto input_data_type = auto input_data_type =
......
...@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdadeltaOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto avg_squared_grad_out_tensor =
ctx.Output<framework::Tensor>("AvgSquaredGradOut");
......
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
...@@ -21,25 +22,31 @@ namespace operators {
template <typename DeviceContext, typename T>
struct SparseAdagradFunctor {
void operator()(const DeviceContext &context,
const framework::SelectedRows &grad,
const framework::Tensor &learning_rate, T epsilon,
framework::Tensor *moment, framework::Tensor *param);
};
template <typename DeviceContext, typename T>
class AdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
const auto *param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto *param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto *moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
param_out_tensor->mutable_data<T>(ctx.GetPlace());
moment_out_tensor->mutable_data<T>(ctx.GetPlace());
T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
auto *grad_var = ctx.InputVar("Grad");
if (grad_var->IsType<framework::LoDTensor>()) {
auto param = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Param"));
...@@ -47,16 +54,16 @@ class AdagradOpKernel : public framework::OpKernel<T> {
*ctx.Input<framework::Tensor>("Grad"));
auto moment = framework::EigenVector<T>::Flatten(
*ctx.Input<framework::Tensor>("Moment"));
auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
auto param_out = framework::EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = framework::EigenVector<T>::Flatten(*moment_out_tensor);
auto *place = ctx.template device_context<DeviceContext>().eigen_device();
moment_out.device(*place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
if (platform::is_cpu_place(ctx.GetPlace())) {
auto *lr = learning_rate->data<T>();
param_out.device(*place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
...@@ -66,10 +73,10 @@ class AdagradOpKernel : public framework::OpKernel<T> {
lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
} else if (grad_var->IsType<framework::SelectedRows>()) {
auto *param_tensor = ctx.Input<framework::Tensor>("Param");
PADDLE_ENFORCE_EQ(param_tensor, param_out_tensor);
auto *moment_tensor = ctx.Input<framework::Tensor>("Moment");
PADDLE_ENFORCE_EQ(moment_tensor, moment_out_tensor);
SparseAdagradFunctor<DeviceContext, T> functor;
......
...@@ -244,6 +244,12 @@ template <typename DeviceContext, typename T>
class AdamOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
using paddle::framework::LoDTensor;
using paddle::operators::detail::Ref;
......
...@@ -35,6 +35,16 @@ class AdamaxOp : public framework::OperatorWithKernel {
"Input(LearningRate) of AdamaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Beta1Pow"),
"Input(Beta1Pow) of AdamaxOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of AdamaxOp should not be null."); "Output(ParamOut) of AdamaxOp should not be null.");
......
...@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class AdamaxOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
auto inf_norm_out_tensor = ctx.Output<framework::Tensor>("InfNormOut");
......
...@@ -32,6 +32,16 @@ class DecayedAdagradOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(
ctx->HasInput("LearningRate"),
"Input(LearningRate) of DecayedAdagradOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of DecayedAdagradOp should not be null."); "Output(ParamOut) of DecayedAdagradOp should not be null.");
......
...@@ -23,6 +23,17 @@ template <typename DeviceContext, typename T>
class DecayedAdagradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out_tensor = ctx.Output<framework::Tensor>("ParamOut");
auto moment_out_tensor = ctx.Output<framework::Tensor>("MomentOut");
......
...@@ -34,6 +34,16 @@ class FTRLOp : public framework::OperatorWithKernel {
"Input(Grad) of FTRL should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of FTRL should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(
ctx->GetInputsVarType("Grad").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Grad").front(), ctx->GetInputsVarType("Grad").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of FTRL should not be null."); "Output(ParamOut) of FTRL should not be null.");
......
...@@ -28,6 +28,17 @@ template <typename DeviceContext, typename T>
class FTRLOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* sq_accum_out = ctx.Output<Tensor>("SquaredAccumOut");
auto* lin_accum_out = ctx.Output<Tensor>("LinearAccumOut");
......
...@@ -33,6 +33,11 @@ class MomentumOp : public framework::OperatorWithKernel {
"Input(velocity) of Momentum should not be null.");
PADDLE_ENFORCE(ctx->HasInput("LearningRate"),
"Input(LearningRate) of Momentum should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(ParamOut) of Momentum should not be null."); "Output(ParamOut) of Momentum should not be null.");
......
...@@ -46,6 +46,17 @@ template <typename T>
class MomentumOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
const auto* grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(grad_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Grad").front(), grad_var->Type().name());
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
......
...@@ -23,6 +23,12 @@ template <typename T>
class MomentumOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto param_out = ctx.Output<framework::Tensor>("ParamOut");
auto velocity_out = ctx.Output<framework::Tensor>("VelocityOut");
auto param = ctx.Input<framework::Tensor>("Param");
......
...@@ -32,6 +32,11 @@ class RmspropOp : public framework::OperatorWithKernel {
"Input(Grad) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Moment"),
"Input(Moment) of RmspropOp should not be null.");
PADDLE_ENFORCE(
ctx->GetInputsVarType("Param").front() ==
framework::proto::VarType::LOD_TENSOR,
"The input var's type should be LoDTensor, but the received is %s",
ctx->Inputs("Param").front(), ctx->GetInputsVarType("Param").front());
PADDLE_ENFORCE(ctx->HasOutput("ParamOut"), PADDLE_ENFORCE(ctx->HasOutput("ParamOut"),
"Output(param_out) of RmspropOp should not be null."); "Output(param_out) of RmspropOp should not be null.");
......
...@@ -28,6 +28,12 @@ template <typename DeviceContext, typename T>
class RmspropOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto* param_out = ctx.Output<Tensor>("ParamOut");
auto* moment_out = ctx.Output<Tensor>("MomentOut");
auto* mean_square_out = ctx.Output<Tensor>("MeanSquareOut");
......
...@@ -21,7 +21,7 @@ class SGDOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
...@@ -42,7 +42,7 @@ class SGDOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type = framework::GetDataTypeOfVar(ctx.InputVar("Param"));
return framework::OpKernelType(data_type, ctx.device_context());
}
...@@ -50,17 +50,20 @@ class SGDOp : public framework::OperatorWithKernel {
class SGDOpInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto input_var_n = op_desc.Input("Param")[0];
auto in_var_type = block->FindRecursiveOrCreateVar(input_var_n).GetType();
PADDLE_ENFORCE(in_var_type == framework::proto::VarType::SELECTED_ROWS ||
in_var_type == framework::proto::VarType::LOD_TENSOR,
"The input Var's type should be LoDtensor or SelectedRows,"
" but the received var(%s)'s type is %s",
input_var_n, in_var_type);
for (auto &out_var_n : op_desc.Output("ParamOut")) {
auto &out_var = block->FindRecursiveOrCreateVar(out_var_n);
if (out_var.GetType() != in_var_type) {
out_var.SetType(in_var_type);
}
}
}
......
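Note that SGD is handled differently from the optimizers above: instead of forcing LoDTensor inputs, the rewritten SGDOpInferVarType accepts either LoDTensor or SelectedRows for Param and propagates the input type to ParamOut. As a rough Python-side sketch of that distinction (a hedged example assuming the Fluid 1.x layers API; the layer sizes are made up and the snippet is not part of this commit):

import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
# is_sparse=True makes the gradient of the embedding table a SelectedRows variable.
emb = fluid.layers.embedding(input=ids, size=[10000, 64], is_sparse=True)
cost = fluid.layers.reduce_mean(emb)

# SGD continues to work with sparse gradients; the checks added above only guard
# the optimizers (Adadelta, Adamax, DecayedAdagrad, FTRL) that require dense
# LoDTensor inputs.
sgd = fluid.optimizer.SGD(learning_rate=0.01)
sgd.minimize(cost)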
...@@ -56,6 +56,12 @@ template <typename T>
class SGDOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const auto* param_var = ctx.InputVar("Param");
PADDLE_ENFORCE(param_var->IsType<framework::LoDTensor>(),
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx.Inputs("Param").front(), param_var->Type().name());
auto* param = ctx.Input<framework::Tensor>("Param");
auto* param_out = ctx.Output<framework::Tensor>("ParamOut");
auto* learning_rate = ctx.Input<framework::Tensor>("LearningRate");
......
...@@ -659,6 +659,9 @@ class AdamaxOptimizer(Optimizer):
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)
Notes:
Currently, AdamaxOptimizer doesn't support sparse parameter optimization.
""" """
_moment_acc_str = "moment" _moment_acc_str = "moment"
_inf_norm_acc_str = "inf_norm" _inf_norm_acc_str = "inf_norm"
...@@ -778,6 +781,9 @@ class DecayedAdagradOptimizer(Optimizer): ...@@ -778,6 +781,9 @@ class DecayedAdagradOptimizer(Optimizer):
optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2) optimizer = fluid.optimizer.DecayedAdagrad(learning_rate=0.2)
optimizer.minimize(cost) optimizer.minimize(cost)
Notes:
Currently, DecayedAdagradOptimizer doesn't support sparse parameter optimization.
""" """
_moment_acc_str = "moment" _moment_acc_str = "moment"
...@@ -858,6 +864,9 @@ class AdadeltaOptimizer(Optimizer): ...@@ -858,6 +864,9 @@ class AdadeltaOptimizer(Optimizer):
optimizer = fluid.optimizer.Adadelta( optimizer = fluid.optimizer.Adadelta(
learning_rate=0.0003, epsilon=1.0e-6, rho=0.95) learning_rate=0.0003, epsilon=1.0e-6, rho=0.95)
_, params_grads = optimizer.minimize(cost) _, params_grads = optimizer.minimize(cost)
Notes:
Currently, AdadeltaOptimizer doesn't support sparse parameter optimization.
""" """
_avg_squared_grad_acc_str = "_avg_squared_grad" _avg_squared_grad_acc_str = "_avg_squared_grad"
...@@ -1126,6 +1135,9 @@ class FtrlOptimizer(Optimizer): ...@@ -1126,6 +1135,9 @@ class FtrlOptimizer(Optimizer):
optimizer = fluid.optimizer.Ftrl(0.0001) optimizer = fluid.optimizer.Ftrl(0.0001)
_, params_grads = optimizer.minimize(cost) _, params_grads = optimizer.minimize(cost)
Notes:
Currently, FtrlOptimizer doesn't support sparse parameter optimization.
""" """
_squared_acc_str = "squared" _squared_acc_str = "squared"
......
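The docstring notes added above match the new runtime checks: if Adamax, DecayedAdagrad, Adadelta, or Ftrl receives a SelectedRows gradient (for example, from a sparse embedding lookup), the operator now fails fast with the "should be LoDTensor" message instead of computing an incorrect update. A hedged sketch of the case the check rejects (Fluid 1.x API assumed; layer sizes are illustrative only):

import paddle.fluid as fluid

ids = fluid.layers.data(name='ids', shape=[1], dtype='int64')
emb = fluid.layers.embedding(input=ids, size=[10000, 64], is_sparse=True)
cost = fluid.layers.reduce_mean(emb)

# Adamax does not support sparse parameter optimization, so with the checks
# from this commit the optimizer op rejects the SelectedRows gradient of the
# embedding table ("The input var's type should be LoDTensor ...").
optimizer = fluid.optimizer.Adamax(learning_rate=0.2)
optimizer.minimize(cost)  # raises the LoDTensor check error (at op creation or at run time)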