// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include #include "paddle/fluid/operators/optimizers/dgc_momentum_op.h" namespace paddle { namespace operators { class DGCMomentumOp : public MomentumOp { public: using MomentumOp::MomentumOp; protected: void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE_EQ(ctx->HasInput("current_step"), true, "current_step should be set."); return MomentumOp::InferShape(ctx); } framework::OpKernelType GetKernelTypeForVar( const std::string& var_name, const framework::Tensor& tensor, const framework::OpKernelType& expected_kernel_type) const override { if (var_name == "current_step") { VLOG(10) << "var_name:" << var_name << " need not to transform"; return expected_kernel_type; } return framework::OperatorWithKernel::GetKernelTypeForVar( var_name, tensor, expected_kernel_type); } }; class DGCMomentumOpMaker : public MomentumOpMaker { public: void Make() override { AddInput("current_step", "(Tensor) Current step."); AddAttr("rampup_begin_step", "(float, -1.0)" "The period when begin DGC.") .SetDefault(-1.0); return MomentumOpMaker::Make(); } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(dgc_momentum, ops::DGCMomentumOp, ops::DGCMomentumOpMaker); REGISTER_OP_CPU_KERNEL( dgc_momentum, ops::DGCMomentumKernel);