From be62127c1f15f776d1952bc1c976029601eff870 Mon Sep 17 00:00:00 2001 From: duduscript Date: Sun, 9 Jul 2017 13:42:11 +0800 Subject: [PATCH] Delete ModelUpdateKernelTest (#188) * basic model update * spell mistake * fix bug * remove a log line * add model_update_kernel_test * fix some problem in model_update * function problem * judge isfinite and change function * fix batch problem * remove model update test --- .../core/kernel/model_update_kernel_test.cpp | 70 ------------------- 1 file changed, 70 deletions(-) delete mode 100644 oneflow/core/kernel/model_update_kernel_test.cpp diff --git a/oneflow/core/kernel/model_update_kernel_test.cpp b/oneflow/core/kernel/model_update_kernel_test.cpp deleted file mode 100644 index 15d96b98fa..0000000000 --- a/oneflow/core/kernel/model_update_kernel_test.cpp +++ /dev/null @@ -1,70 +0,0 @@ -#include "oneflow/core/kernel/model_update_kernel.h" -#include "oneflow/core/kernel/kernel_test_common.h" - -namespace oneflow { - -namespace test { - -namespace { - -template -Kernel* BuildMdUpdateKernel(float learn_rate) { - OperatorConf op_conf; - op_conf.set_name("model_update_test"); - ModelUpdateOpConf* model_update_conf = op_conf.mutable_model_update_conf(); - model_update_conf->set_learn_rate(learn_rate); - auto model_update_op = OpMgr::Singleton()->ConstructOp(op_conf); - OperatorProto op_proto; - model_update_op->ToProto(&op_proto); - auto model_update_kernel = - new MdUpdateKernel(); - model_update_kernel->InitFromOpProto(op_proto); - return model_update_kernel; -} - -template -std::function BuildBnInOp2BlobPtr() { - using KTCommon = KernelTestCommon; - - std::vector dim_vec = {1, 3, 2}; - - auto bn2blob_ptr = new HashMap; - (*bn2blob_ptr)["model"] = KTCommon::CreateBlobWithSameValue(dim_vec, 3); - (*bn2blob_ptr)["model_diffs"] = KTCommon::CreateBlobWithSameValue(dim_vec, 2); - (*bn2blob_ptr)["model_expected"] = - KTCommon::CreateBlobWithSameValue(dim_vec, 1); - return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); }; -} - -template -void TestMdUpdateKernel() { - using KTCommon = KernelTestCommon; - KernelCtx ctx; - KTCommon::BuildKernelCtx(&ctx); - - const float learn_rate = {1.0f}; - auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr(); - auto model_update_kernel = - BuildMdUpdateKernel(learn_rate); - - model_update_kernel->Backward(ctx, BnInOp2BlobPtr); - KTCommon::SyncStream(&ctx); - - KTCommon::CheckResult(BnInOp2BlobPtr, "model", "model_expected"); -} - -} // namespace - -} // namespace test - -TEST(MdUpdateKernel, model_update_cpu) { - test::TestMdUpdateKernel(); - test::TestMdUpdateKernel(); -} - -TEST(MdUpdateKernel, model_update_gpu) { - test::TestMdUpdateKernel(); - test::TestMdUpdateKernel(); -} - -} // namespace oneflow -- GitLab