#include "oneflow/core/kernel/momentum_model_update_kernel.h"

// NOTE(review): this file had been collapsed onto a single line and its
// angle-bracket template argument lists were missing. The line structure and
// the template arguments below were reconstructed from how the names are used
// in this file; confirm them against momentum_model_update_kernel.h.

namespace oneflow {

// Applies one momentum update step to the "model" blob.
//
// Reads the "model", "momentum" and "model_diffs" blobs through BnInOp2Blob,
// takes learning_rate and beta from the op's momentum_mdupdt_conf, and then
// performs, in order:
//   momentum = beta * momentum
//   momentum = momentum - alpha * model_diffs
//   model    = model + momentum
// where alpha = learning_rate / batch size.
//
// ctx          supplies the device context handed to the BLAS helpers.
// BnInOp2Blob  maps a blob name in the op to its Blob.
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::Forward(
    const KernelCtx& ctx,
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  Blob* model_blob = BnInOp2Blob("model");
  Blob* momentum_blob = BnInOp2Blob("momentum");
  const Blob* model_diffs_blob = BnInOp2Blob("model_diffs");

  const auto& conf = op()->op_conf().momentum_mdupdt_conf();
  float learning_rate = conf.learning_rate();
  float beta = conf.beta();
  // The step size is the learning rate divided by the job's batch size.
  float alpha = learning_rate / JobDesc::Singleton()->BatchSize();
  // Abort rather than write inf/NaN into the model (e.g. batch size of 0).
  CHECK(std::isfinite(alpha));

  // momentum = beta * momentum
  KernelUtil<device_type, T>::BlasScal(
      ctx.device_ctx, momentum_blob->shape().elem_cnt(), static_cast<T>(beta),
      momentum_blob->mut_dptr<T>(), 1);

  // momentum = momentum - alpha * model_diff
  KernelUtil<device_type, T>::BlasAxpy(
      ctx.device_ctx, momentum_blob->shape().elem_cnt(),
      static_cast<T>(-alpha), model_diffs_blob->dptr<T>(), 1,
      momentum_blob->mut_dptr<T>(), 1);

  // model = model + momentum
  KernelUtil<device_type, T>::BlasAxpy(
      ctx.device_ctx, model_blob->shape().elem_cnt(), static_cast<T>(1),
      momentum_blob->dptr<T>(), 1, model_blob->mut_dptr<T>(), 1);
}

// Initializes the "momentum" blob by filling it with the constant 0.
//
// ctx          supplies the device context handed to the fill helper.
// BnInOp2Blob  maps a blob name in the op to its Blob.
template<DeviceType device_type, typename T>
void MomentumMdUpdateKernel<device_type, T>::InitDataTmpBlobs(
    const KernelCtx& ctx,
    std::function<Blob*(const std::string&)> BnInOp2Blob) const {
  FillConf momentum_fill_conf;
  momentum_fill_conf.mutable_constant_conf()->set_value(0.0f);
  // NOTE(review): the meaning of the literal 0 argument is defined by
  // KernelUtil::Fill, which is not visible here - confirm against its
  // declaration.
  KernelUtil<device_type, T>::Fill(ctx.device_ctx, momentum_fill_conf, 0,
                                   BnInOp2Blob("momentum"));
}

namespace {

// Creates the MomentumMdUpdateKernel instantiation that matches op_ctx.
//
// A function-local static table maps GetHashKey(device type, data type) to a
// factory lambda, with one entry for every combination of DEVICE_TYPE_SEQ and
// FLOATING_DATA_TYPE_SEQ. The entry is selected using the op context's device
// type and the data type recorded for the "model_diffs" blob.
//
// Returns a kernel allocated with `new`; presumably the caller takes ownership
// (verify against AddKernelCreator's contract).
Kernel* CreateMomentumMdUpdateKernel(const OpContext& op_ctx) {
  // MODEL_UPDATE_KERNEL_ENTRY expands to one {key, factory} initializer.
  // NOTE(review): the key type (std::string) and the use of OF_PP_PAIR_FIRST
  // as the element type are reconstructed - confirm.
  static const HashMap<std::string, std::function<Kernel*()>> creators = {
#define MODEL_UPDATE_KERNEL_ENTRY(device_type, data_type_pair)               \
  {GetHashKey(device_type, OF_PP_PAIR_SECOND(data_type_pair)), []() {        \
     return new MomentumMdUpdateKernel<device_type,                          \
                                       OF_PP_PAIR_FIRST(data_type_pair)>();  \
   }},
      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
          MODEL_UPDATE_KERNEL_ENTRY, DEVICE_TYPE_SEQ, FLOATING_DATA_TYPE_SEQ)};
  return creators.at(GetHashKey(
      op_ctx.device_type(), op_ctx.bn_in_op2data_type().at("model_diffs")))();
}

}  // namespace

// Registers the factory above for operators configured with
// kMomentumMdupdtConf.
COMMAND(AddKernelCreator(OperatorConf::kMomentumMdupdtConf,
                         CreateMomentumMdUpdateKernel))

}  // namespace oneflow