From 16be8e814bad5ad552caa3014cff6b80578cc9c4 Mon Sep 17 00:00:00 2001 From: duduscript Date: Sat, 15 Jul 2017 18:16:46 +0800 Subject: [PATCH] Fix a bug in Momentum model update and add test (#205) * basic model update * spell mistake * fix bug * remove a log line * add model_update_kernel_test * fix some problem in model_update * function problem * judge isfinite and change function * fix batch problem * remove model update test * model_update_kernel_test * change a name * add two function in kernel_util * add a function in kernel_util * remove functions in kernel_util * add op in conf * conflict * RMSProp model update kernel * remove conflict * fix comment error * add const * conflict * conflict * add operator * fix some problem in rmsprop model update kernel * format change * momentum model update kernel test and fix a bug --- .../kernel/momentum_model_update_kernel.cpp | 2 +- .../momentum_model_update_kernel_test.cpp | 88 +++++++++++++++++++ 2 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 oneflow/core/kernel/momentum_model_update_kernel_test.cpp diff --git a/oneflow/core/kernel/momentum_model_update_kernel.cpp b/oneflow/core/kernel/momentum_model_update_kernel.cpp index cd03b6a26a..df37230f16 100644 --- a/oneflow/core/kernel/momentum_model_update_kernel.cpp +++ b/oneflow/core/kernel/momentum_model_update_kernel.cpp @@ -20,7 +20,7 @@ void MomentumMdUpdateKernel::Forward( // momentum = beta * momentum KernelUtil::BlasScal( ctx, momentum_blob->shape().elem_cnt(), - static_cast(-beta), + static_cast(beta), static_cast(momentum_blob->mut_dptr()), 1); // momentum = momentum - alpha * model_diff diff --git a/oneflow/core/kernel/momentum_model_update_kernel_test.cpp b/oneflow/core/kernel/momentum_model_update_kernel_test.cpp new file mode 100644 index 0000000000..777150d3a8 --- /dev/null +++ b/oneflow/core/kernel/momentum_model_update_kernel_test.cpp @@ -0,0 +1,88 @@ +#include "oneflow/core/kernel/momentum_model_update_kernel.h" +#include "oneflow/core/kernel/kernel_test_common.h" + +namespace oneflow { + +namespace test { + +namespace { + +template +Kernel* BuildMomentumMdUpdateKernel(float learning_rate, float beta) { + OperatorConf op_conf; + op_conf.set_name("momentum_model_update_test"); + MomentumModelUpdateOpConf* momentum_md_update_conf = + op_conf.mutable_momentum_model_update_conf(); + momentum_md_update_conf->set_learning_rate(learning_rate); + momentum_md_update_conf->set_beta(beta); + auto momentum_md_update_op = OpMgr::Singleton()->ConstructOp(op_conf); + OperatorProto op_proto; + momentum_md_update_op->ToProto(&op_proto); + auto momentum_md_update_kernel = + new MomentumMdUpdateKernel(); + momentum_md_update_kernel->InitFromOpProto(op_proto); + return momentum_md_update_kernel; +} + +void InitJobDesc(int32_t piece_size, int32_t num_of_pieces_in_batch) { + JobConf job_conf; + job_conf.set_piece_size(piece_size); + job_conf.set_num_of_pieces_in_batch(num_of_pieces_in_batch); + JobDesc::Singleton()->InitFromJobConf(job_conf); +} + +template +std::function BuildBnInOp2BlobPtr() { + using KTCommon = KernelTestCommon; + + std::vector dim_vec = {1, 3, 2}; + + auto bn2blob_ptr = new HashMap; + (*bn2blob_ptr)["model"] = KTCommon::CreateBlobWithSameValue(dim_vec, 2); + (*bn2blob_ptr)["momentum"] = KTCommon::CreateBlobWithSameValue(dim_vec, 4); + (*bn2blob_ptr)["model_diffs"] = KTCommon::CreateBlobWithSameValue(dim_vec, 4); + (*bn2blob_ptr)["model_expected"] = + KTCommon::CreateBlobWithSameValue(dim_vec, 1); + (*bn2blob_ptr)["momentum_expected"] = + KTCommon::CreateBlobWithSameValue(dim_vec, 1); + return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); }; +} + +template +void TestMomentumMdUpdateKernel() { + using KTCommon = KernelTestCommon; + KernelCtx ctx; + KTCommon::BuildKernelCtx(&ctx); + + const float learning_rate = {0.5f}; + const float beta = {0.5f}; + auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr(); + auto momentum_md_update_kernel = + BuildMomentumMdUpdateKernel(learning_rate, + beta); + int32_t piece_size = 1; + int32_t num_of_pieces_in_batch = 2; + InitJobDesc(piece_size, num_of_pieces_in_batch); + + momentum_md_update_kernel->Forward(ctx, BnInOp2BlobPtr); + KTCommon::SyncStream(&ctx); + + KTCommon::CheckResult(BnInOp2BlobPtr, "momentum", "momentum_expected"); + KTCommon::CheckResult(BnInOp2BlobPtr, "model", "model_expected"); +} + +} // namespace + +} // namespace test + +TEST(MomentumMdUpdateKernel, model_update_cpu) { + test::TestMomentumMdUpdateKernel(); + test::TestMomentumMdUpdateKernel(); +} + +TEST(MomentumMdUpdateKernel, model_update_gpu) { + test::TestMomentumMdUpdateKernel(); + test::TestMomentumMdUpdateKernel(); +} + +} // namespace oneflow -- GitLab