提交 75453ebd 编写于 作者: duduscript's avatar duduscript 提交者: chengtbf

RMSProp model update kernel test and fix a bug (#206)

* basic model update

* spell mistake

* fix bug

* remove a log line

* add model_update_kernel_test

* fix some problem in model_update

* function problem

* judge isfinite and change function

* fix batch problem

* remove model update test

* model_update_kernel_test

* change a name

* add two function in kernel_util

* add a function in kernel_util

* remove functions in kernel_util

* add op in conf

* conflict

* RMSProp model update kernel

* remove conflict

* fix comment error

* add const

* conflict

* conflict

* add operator

* fix some problem in rmsprop model update kernel

* format change

* momentum model update kernel test and fix a bug

* add test for rmsprop model update kernel
上级 e3066df5
......@@ -57,8 +57,7 @@ class RMSPropMdUpdateKernelUtil<DeviceType::kCPU, FloatingPointType> final {
const FloatingPointType alpha) {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
for (int64_t i = 0; i < n; ++i) {
model[i] -=
alpha * model_diff[i] / (std::sqrt(mean_square[i]) + epsilon);
model[i] -= alpha * model_diff[i] / std::sqrt(mean_square[i] + epsilon);
}
});
}
......
......@@ -17,7 +17,7 @@ __global__ void UpdateMeanSquareGpu(const int64_t n,
}
}
// model -= alpha * model_diff / (sqrt(mean_square) + epsilon)
// model -= alpha * model_diff / sqrt(mean_square + epsilon)
template<typename FloatingPointType>
__global__ void UpdateModelGpu(const int64_t n, FloatingPointType* model,
const FloatingPointType* model_diff,
......@@ -25,7 +25,7 @@ __global__ void UpdateModelGpu(const int64_t n, FloatingPointType* model,
const FloatingPointType epsilon,
const FloatingPointType alpha) {
CUDA_1D_KERNEL_LOOP(i, n) {
model[i] -= alpha * model_diff[i] / (std::sqrt(mean_square[i]) + epsilon);
model[i] -= alpha * model_diff[i] / std::sqrt(mean_square[i] + epsilon);
}
}
......
#include "oneflow/core/kernel/rmsprop_model_update_kernel.h"
#include "oneflow/core/kernel/kernel_test_common.h"
namespace oneflow {
namespace test {
namespace {
template<DeviceType device_type, typename FloatingPointType>
Kernel* BuildRMSPropMdUpdateKernel(float learning_rate, float decay_rate,
float epsilon) {
OperatorConf op_conf;
op_conf.set_name("rmsprop_model_update_test");
RMSPropModelUpdateOpConf* rmsprop_md_update_conf =
op_conf.mutable_rmsprop_model_update_conf();
rmsprop_md_update_conf->set_learning_rate(learning_rate);
rmsprop_md_update_conf->set_decay_rate(decay_rate);
rmsprop_md_update_conf->set_epsilon(epsilon);
auto rmsprop_md_update_op = OpMgr::Singleton()->ConstructOp(op_conf);
OperatorProto op_proto;
rmsprop_md_update_op->ToProto(&op_proto);
auto rmsprop_md_update_kernel =
new RMSPropMdUpdateKernel<device_type, FloatingPointType>();
rmsprop_md_update_kernel->InitFromOpProto(op_proto);
return rmsprop_md_update_kernel;
}
void InitJobDesc(int32_t piece_size, int32_t num_of_pieces_in_batch) {
JobConf job_conf;
job_conf.set_piece_size(piece_size);
job_conf.set_num_of_pieces_in_batch(num_of_pieces_in_batch);
JobDesc::Singleton()->InitFromJobConf(job_conf);
}
template<DeviceType device_type, typename FloatingPointType>
std::function<Blob*(const std::string&)> BuildBnInOp2BlobPtr() {
using KTCommon = KernelTestCommon<device_type, FloatingPointType>;
std::vector<int64_t> dim_vec = {1, 3, 2};
auto bn2blob_ptr = new HashMap<std::string, Blob*>;
(*bn2blob_ptr)["model"] = KTCommon::CreateBlobWithSameValue(dim_vec, 3);
(*bn2blob_ptr)["mean_square"] = KTCommon::CreateBlobWithSameValue(dim_vec, 2);
(*bn2blob_ptr)["model_diffs"] = KTCommon::CreateBlobWithSameValue(dim_vec, 2);
(*bn2blob_ptr)["model_expected"] =
KTCommon::CreateBlobWithSameValue(dim_vec, 2);
(*bn2blob_ptr)["mean_square_expected"] =
KTCommon::CreateBlobWithSameValue(dim_vec, 3);
return [bn2blob_ptr](const std::string& bn) { return bn2blob_ptr->at(bn); };
}
template<DeviceType device_type, typename FloatingPointType>
void TestRMSPropMdUpdateKernel() {
using KTCommon = KernelTestCommon<device_type, FloatingPointType>;
KernelCtx ctx;
KTCommon::BuildKernelCtx(&ctx);
const float learning_rate = {2.0f};
const float decay_rate = {0.5f};
const float epsilon = {1.0f};
auto BnInOp2BlobPtr = BuildBnInOp2BlobPtr<device_type, FloatingPointType>();
auto rmsprop_md_update_kernel =
BuildRMSPropMdUpdateKernel<device_type, FloatingPointType>(
learning_rate, decay_rate, epsilon);
int32_t piece_size = 1;
int32_t num_of_pieces_in_batch = 2;
InitJobDesc(piece_size, num_of_pieces_in_batch);
rmsprop_md_update_kernel->Forward(ctx, BnInOp2BlobPtr);
KTCommon::SyncStream(&ctx);
KTCommon::CheckResult(BnInOp2BlobPtr, "mean_square", "mean_square_expected");
KTCommon::CheckResult(BnInOp2BlobPtr, "model", "model_expected");
}
} // namespace
} // namespace test
TEST(RMSPropMdUpdateKernel, model_update_cpu) {
test::TestRMSPropMdUpdateKernel<DeviceType::kCPU, float>();
test::TestRMSPropMdUpdateKernel<DeviceType::kCPU, double>();
}
TEST(RMSPropMdUpdateKernel, model_update_gpu) {
test::TestRMSPropMdUpdateKernel<DeviceType::kGPU, float>();
test::TestRMSPropMdUpdateKernel<DeviceType::kGPU, double>();
}
} // namespace oneflow
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册