From dc534fc19525b2671a9620863daa7ace47a37c00 Mon Sep 17 00:00:00 2001 From: tangwei12 Date: Fri, 11 May 2018 16:44:10 +0800 Subject: [PATCH] add checkpoint save op test --- paddle/fluid/operators/cpkt_save_op_test.cc | 44 +++++++++++++++++++++ 1 file changed, 44 insertions(+) create mode 100644 paddle/fluid/operators/cpkt_save_op_test.cc diff --git a/paddle/fluid/operators/cpkt_save_op_test.cc b/paddle/fluid/operators/cpkt_save_op_test.cc new file mode 100644 index 00000000000..3e620a0e9cb --- /dev/null +++ b/paddle/fluid/operators/cpkt_save_op_test.cc @@ -0,0 +1,44 @@ +/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/fluid/framework/op_registry.h" + +TEST(CkptSaveOp, CPU) { + paddle::framework::Scope scope; + paddle::platform::CPUPlace place; + + auto var = scope.Var("test_var"); + auto tensor = var->GetMutable(); + tensor->Resize({3, 10}); + paddle::framework::LoD expect_lod; + expect_lod.resize(1); + expect_lod[0].push_back(0); + expect_lod[0].push_back(1); + expect_lod[0].push_back(2); + expect_lod[0].push_back(3); + + tensor->set_lod(expect_lod); + float* expect = tensor->mutable_data(place); + for (int64_t i = 0; i < tensor->numel(); ++i) { + expect[i] = static_cast(paddle::platform::float16(i)); + } + + paddle::framework::AttributeMap attrs; + attrs.insert({"file_path", std::string("tensor.save")}); + + auto save_op = paddle::framework::OpRegistry::CreateOp( + "ckpt_save", {{"X", {"test_var"}}}, {}, attrs); + save_op->Run(scope, place); +} -- GitLab