提交 6939c381 编写于 作者: Z Zhenyu Tan 提交者: TensorFlower Gardener

Internal Cleanup.

PiperOrigin-RevId: 225217785
上级 57eb92b7
......@@ -194,50 +194,6 @@ static void BM_Adam(int iters, int params) {
}
BENCHMARK(BM_Adam)->Arg(128 << 10)->Arg(256 << 10);
// Builds the init and train graphs used to benchmark the
// "ApplyAdamWithAmsgrad" kernel on a single parameter vector of length n.
//
// *init_g zero-initializes the variable and the first/second moment
// accumulators; *train_g applies one AMSGrad update with fixed
// hyperparameters (lr=0.01, beta1=0.9, beta2=0.99, epsilon=1e-8) and a
// random gradient. Ownership of both graphs transfers to the caller.
//
// Fix: dropped the unused local `TensorShape shape({n});` that was never
// referenced anywhere in the function.
static void AdamWithAmsgrad(int32 n, Graph** init_g, Graph** train_g) {
  {
    Graph* g = new Graph(OpRegistry::Global());
    auto var = Var(g, n);
    auto m = Var(g, n);
    auto v = Var(g, n);
    auto zero = Zeros(g, n);
    test::graph::Assign(g, var, zero);
    test::graph::Assign(g, m, zero);
    test::graph::Assign(g, v, zero);
    // NOTE(review): the train graph below also reads a `vhat` accumulator,
    // but no corresponding variable is initialized here — confirm whether
    // the benchmark harness tolerates (or requires) an uninitialized vhat.
    *init_g = g;
  }
  {
    Graph* g = new Graph(OpRegistry::Global());
    auto var = Var(g, n);
    auto m = Var(g, n);
    auto v = Var(g, n);
    auto vhat = Var(g, n);
    auto beta1_power = Scalar(g, 0.9);
    auto beta2_power = Scalar(g, 0.99);
    auto lr = Scalar(g, 0.01);
    auto beta1 = Scalar(g, 0.9);
    auto beta2 = Scalar(g, 0.99);
    auto epsilon = Scalar(g, 1e-8);
    auto grad = Random(g, n);
    // Inputs must follow the op's declared order:
    // var, m, v, vhat, beta1_power, beta2_power, lr, beta1, beta2, eps, grad.
    test::graph::Multi(g, "ApplyAdamWithAmsgrad",
                       {var, m, v, vhat, beta1_power, beta2_power, lr, beta1,
                        beta2, epsilon, grad});
    *train_g = g;
  }
}
static void BM_AdamWithAmsgrad(int iters, int params) {
const int64 tot = static_cast<int64>(iters) * params;
testing::ItemsProcessed(tot);
testing::BytesProcessed(tot * sizeof(float));
Graph* init;
Graph* train;
AdamWithAmsgrad(params, &init, &train);
test::Benchmark("cpu", train, GetOptions(), init).Run(iters);
}
BENCHMARK(BM_AdamWithAmsgrad)->Arg(128 << 10)->Arg(256 << 10);
static void RMSProp(int32 n, Graph** init_g, Graph** train_g) {
TensorShape shape({n});
{
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册