From 832a014c17d6e3b893b697cc944522fda74b0dac Mon Sep 17 00:00:00 2001
From: "joanna.wozna.intel"
Date: Thu, 17 Jun 2021 05:02:50 +0200
Subject: [PATCH] Add bf16 support for save and load ops (#33173)

* Add bf16 support for save and load ops
* Add bf16 test condition
* Add matmul and change fluid.io to paddle.static
* Reduce the test duration
---
 paddle/fluid/operators/load_combine_op.cc     |   2 +
 paddle/fluid/operators/load_op.cc             |   2 +
 paddle/fluid/operators/save_combine_op.cc     |   2 +
 .../operators/save_load_combine_op_test.cc    |  44 +++---
 paddle/fluid/operators/save_op.cc             |   2 +
 .../fluid/tests/book/test_fit_a_line.py       |  10 +-
 .../unittests/test_static_save_load_bf16.py   | 134 ++++++++++++++++++
 tools/static_mode_white_list.py               |   1 +
 8 files changed, 174 insertions(+), 23 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py

diff --git a/paddle/fluid/operators/load_combine_op.cc b/paddle/fluid/operators/load_combine_op.cc
index 63d3f809f26..374bfa73f21 100644
--- a/paddle/fluid/operators/load_combine_op.cc
+++ b/paddle/fluid/operators/load_combine_op.cc
@@ -87,6 +87,8 @@ REGISTER_OP_CPU_KERNEL(
     load_combine,
     ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::bfloat16>,
     ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/load_op.cc b/paddle/fluid/operators/load_op.cc
index 4f2c9a6ca03..ba19aee9b8d 100644
--- a/paddle/fluid/operators/load_op.cc
+++ b/paddle/fluid/operators/load_op.cc
@@ -69,6 +69,8 @@ REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
 REGISTER_OP_CPU_KERNEL(
     load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::LoadOpKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::bfloat16>,
     ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/save_combine_op.cc b/paddle/fluid/operators/save_combine_op.cc
index ec038f16113..6da73c99068 100644
--- a/paddle/fluid/operators/save_combine_op.cc
+++ b/paddle/fluid/operators/save_combine_op.cc
@@ -102,5 +102,7 @@ REGISTER_OP_CPU_KERNEL(
     save_combine,
     ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext,
+                             paddle::platform::bfloat16>,
     ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
diff --git a/paddle/fluid/operators/save_load_combine_op_test.cc b/paddle/fluid/operators/save_load_combine_op_test.cc
index 5594de16b67..493f5081ee4 100644
--- a/paddle/fluid/operators/save_load_combine_op_test.cc
+++ b/paddle/fluid/operators/save_load_combine_op_test.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <string>
 #include "gtest/gtest.h"
 #include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"

 USE_CPU_ONLY_OP(save_combine);
@@ -76,33 +77,34 @@ void CheckValues(T* expect, U* actual, const paddle::framework::LoD& expect_lod,
 // Here, we create 4 LoDTensors and use save_combine_op to first save these
 // in a single file. Then, we use load_combine_op to load these sequentially
-TEST(SaveLoadCombineOp, CPU) {
+template <typename T, typename U>
+void SaveLoadCombineOp() {
   paddle::framework::Scope scope;
   paddle::platform::CPUPlace place;

   std::vector<size_t> lod1 = {0, 1, 2, 3, 10};
   int numel1 = 100;
   paddle::framework::LoD expect_lod1;
-  int* expect1 = CreateForSaveCombineOp<int, int>(10, 10, lod1, "test_var1",
-                                                  place, &scope, &expect_lod1);
+  T* expect1 = CreateForSaveCombineOp<T, U>(10, 10, lod1, "test_var1", place,
+                                            &scope, &expect_lod1);

   std::vector<size_t> lod2 = {0, 2, 5, 10};
   int numel2 = 200;
   paddle::framework::LoD expect_lod2;
-  int* expect2 = CreateForSaveCombineOp<int, int>(10, 20, lod2, "test_var2",
-                                                  place, &scope, &expect_lod2);
+  T* expect2 = CreateForSaveCombineOp<T, U>(10, 20, lod2, "test_var2", place,
+                                            &scope, &expect_lod2);

   std::vector<size_t> lod3 = {0, 2, 3, 20};
   int numel3 = 4000;
   paddle::framework::LoD expect_lod3;
-  int* expect3 = CreateForSaveCombineOp<int, int>(20, 200, lod3, "test_var3",
-                                                  place, &scope, &expect_lod3);
+  T* expect3 = CreateForSaveCombineOp<T, U>(20, 200, lod3, "test_var3", place,
+                                            &scope, &expect_lod3);

   std::vector<size_t> lod4 = {0, 1, 20};
   int numel4 = 1000;
   paddle::framework::LoD expect_lod4;
-  int* expect4 = CreateForSaveCombineOp<int, int>(20, 50, lod4, "test_var4",
-                                                  place, &scope, &expect_lod4);
+  T* expect4 = CreateForSaveCombineOp<T, U>(20, 50, lod4, "test_var4", place,
+                                            &scope, &expect_lod4);

   // Set attributes
   std::string filename = "check_tensor.ls";
@@ -128,15 +130,21 @@ TEST(SaveLoadCombineOp, CPU) {
   load_combine_op->Run(scope, place);

   paddle::framework::LoD actual_lod1, actual_lod2, actual_lod3, actual_lod4;
-  int* actual1 = GetValuesAfterLoadCombineOp<int>(target1, scope, &actual_lod1);
-  int* actual2 = GetValuesAfterLoadCombineOp<int>(target2, scope, &actual_lod2);
-  int* actual3 = GetValuesAfterLoadCombineOp<int>(target3, scope, &actual_lod3);
-  int* actual4 = GetValuesAfterLoadCombineOp<int>(target4, scope, &actual_lod4);
-
-  CheckValues<int, int>(expect1, actual1, expect_lod1, actual_lod1, numel1);
-  CheckValues<int, int>(expect2, actual2, expect_lod2, actual_lod2, numel2);
-  CheckValues<int, int>(expect3, actual3, expect_lod3, actual_lod3, numel3);
-  CheckValues<int, int>(expect4, actual4, expect_lod4, actual_lod4, numel4);
+  U* actual1 = GetValuesAfterLoadCombineOp<U>(target1, scope, &actual_lod1);
+  U* actual2 = GetValuesAfterLoadCombineOp<U>(target2, scope, &actual_lod2);
+  U* actual3 = GetValuesAfterLoadCombineOp<U>(target3, scope, &actual_lod3);
+  U* actual4 = GetValuesAfterLoadCombineOp<U>(target4, scope, &actual_lod4);
+
+  CheckValues<T, U>(expect1, actual1, expect_lod1, actual_lod1, numel1);
+  CheckValues<T, U>(expect2, actual2, expect_lod2, actual_lod2, numel2);
+  CheckValues<T, U>(expect3, actual3, expect_lod3, actual_lod3, numel3);
+  CheckValues<T, U>(expect4, actual4, expect_lod4, actual_lod4, numel4);
+}
+
+TEST(SaveLoadCombineOp, CPU) { SaveLoadCombineOp<int, int>(); }
+
+TEST(SaveLoadCombineBF16Op, CPU) {
+  SaveLoadCombineOp<paddle::platform::bfloat16, paddle::platform::bfloat16>();
 }

 // FP16 version of SaveLoadCombineOp Test, only altering the saving aspect
diff --git a/paddle/fluid/operators/save_op.cc b/paddle/fluid/operators/save_op.cc
index 194274cdd5b..d819c172e4a 100644
--- a/paddle/fluid/operators/save_op.cc
+++ b/paddle/fluid/operators/save_op.cc
@@ -90,6 +90,8 @@ REGISTER_OP_CPU_KERNEL(
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, float>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SaveOpKernel<paddle::platform::CPUDeviceContext,
+                      paddle::platform::bfloat16>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
     ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int64_t>,
diff --git a/python/paddle/fluid/tests/book/test_fit_a_line.py b/python/paddle/fluid/tests/book/test_fit_a_line.py
index 12952462270..65542e2096c 100644
--- a/python/paddle/fluid/tests/book/test_fit_a_line.py
+++ b/python/paddle/fluid/tests/book/test_fit_a_line.py
@@ -84,9 +84,9 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
                                   feed=feeder.feed(data),
                                   fetch_list=[avg_cost])
         if avg_loss_value[0] < 10.0 or pure_bf16:
-            if save_dirname is not None and not pure_bf16:
-                fluid.io.save_inference_model(save_dirname, ['x'],
-                                              [y_predict], exe)
+            if save_dirname is not None:
+                paddle.static.save_inference_model(save_dirname, [x],
+                                                   [y_predict], exe)
             return
         if math.isnan(float(avg_loss_value)):
             sys.exit("got NaN loss, training failed.")
@@ -127,12 +127,12 @@ def infer(use_cuda, save_dirname=None, use_bf16=False):
     inference_scope = fluid.core.Scope()

     with fluid.scope_guard(inference_scope):
-        # Use fluid.io.load_inference_model to obtain the inference program desc,
+        # Use paddle.static.load_inference_model to obtain the inference program desc,
         # the feed_target_names (the names of variables that will be fed
         # data using feed operators), and the fetch_targets (variables that
         # we want to obtain data from using fetch operators).
         [inference_program, feed_target_names,
-         fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
+         fetch_targets] = paddle.static.load_inference_model(save_dirname, exe)

         # The input's dimension should be 2-D and the second dim is 13
         # The input data should be >= 0
diff --git a/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py b/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py
new file mode 100644
index 00000000000..8d665a17468
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_static_save_load_bf16.py
@@ -0,0 +1,134 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import paddle
+import paddle.fluid.core as core
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+from paddle.fluid.optimizer import SGDOptimizer
+from paddle.fluid.tests.unittests.test_imperative_base import new_program_scope
+from paddle.fluid.tests.unittests.test_static_save_load import PtbModel
+import numpy as np
+
+
+@unittest.skipIf(not core.supports_bfloat16(),
+                 "place does not support BF16 evaluation")
+class TestSaveLoadBF16(unittest.TestCase):
+    def set_place(self):
+        return fluid.CPUPlace()
+
+    def test_ptb_rnn_cpu_bfloat16(self):
+        seed = 90
+        hidden_size = 10
+        vocab_size = 500
+        num_layers = 1
+        num_steps = 3
+        init_scale = 0.1
+        batch_size = 4
+        batch_num = 100
+
+        with new_program_scope():
+            fluid.default_startup_program().random_seed = seed
+            fluid.default_main_program().random_seed = seed
+            ptb_model = PtbModel(
+                "ptb_model",
+                hidden_size=hidden_size,
+                vocab_size=vocab_size,
+                num_layers=num_layers,
+                num_steps=num_steps,
+                init_scale=init_scale)
+
+            place = self.set_place()
+            exe = fluid.Executor(place)
+            sgd = SGDOptimizer(learning_rate=1e-3)
+            x = fluid.layers.data(
+                name="x", shape=[-1, num_steps], dtype='int64')
+            y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
+            init_hidden = fluid.layers.data(
+                name="init_hidden", shape=[1], dtype='float32')
+            init_cell = fluid.layers.data(
+                name="init_cell", shape=[1], dtype='float32')
+
+            static_loss, static_last_hidden, static_last_cell = ptb_model(
+                x, y, init_hidden, init_cell)
+
+            sgd = paddle.static.amp.bf16.decorate_bf16(
+                sgd,
+                amp_lists=paddle.static.amp.bf16.AutoMixedPrecisionListsBF16(
+                    custom_fp32_list={'transpose2', 'concat'}),
+                use_bf16_guard=False,
+                use_pure_bf16=True)
+
+            sgd.minimize(static_loss, framework.default_startup_program())
+            out = exe.run(framework.default_startup_program())
+
+            for i in range(batch_num):
+                x_data = np.arange(12).reshape(4, 3).astype('int64')
+                y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
+                x_data = x_data.reshape((-1, num_steps, 1))
+                y_data = y_data.reshape((-1, 1))
+                init_hidden_data = np.zeros(
+                    (num_layers, batch_size, hidden_size), dtype='float32')
+                init_cell_data = np.zeros(
+                    (num_layers, batch_size, hidden_size), dtype='float32')
+                fetch_list = [static_loss, static_last_hidden, static_last_cell]
+                out = exe.run(fluid.default_main_program(),
+                              feed={
+                                  "x": x_data,
+                                  "y": y_data,
+                                  "init_hidden": init_hidden_data,
+                                  "init_cell": init_cell_data
+                              },
+                              fetch_list=fetch_list)
+
+            # get value before save
+            main_program = framework.default_main_program()
+            base_map = {}
+            for var in main_program.list_vars():
+                if isinstance(var, framework.Parameter) or var.persistable:
+                    t = np.array(fluid.global_scope().find_var(var.name)
+                                 .get_tensor())
+                    # make sure all the parameter or optimizer vars have been updated
+                    self.assertTrue(np.sum(np.abs(t)) != 0)
+                    base_map[var.name] = t
+
+            fluid.save(main_program, "./test_1")
+
+            # set var to zero
+            for var in main_program.list_vars():
+                if isinstance(var, framework.Parameter) or var.persistable:
+                    ten = fluid.global_scope().find_var(var.name).get_tensor()
+                    ten.set(np.zeros_like(np.array(ten)), place)
+
+                    new_t = np.array(fluid.global_scope().find_var(var.name)
+                                     .get_tensor())
+                    # make sure all the parameter or optimizer vars have been set to zero
+                    self.assertTrue(np.sum(np.abs(new_t)) == 0)
+
+            fluid.load(main_program, "./test_1.pdparams", exe)
+
+            for var in main_program.list_vars():
+                if isinstance(var, framework.Parameter) or var.persistable:
+                    new_t = np.array(fluid.global_scope().find_var(var.name)
+                                     .get_tensor())
+                    base_t = base_map[var.name]
+                    self.assertTrue(np.array_equal(new_t, base_t))
+
+
+if __name__ == '__main__':
+    paddle.enable_static()
+    unittest.main()
diff --git a/tools/static_mode_white_list.py b/tools/static_mode_white_list.py
index bc6c2ce0ea2..075d1a16927 100644
--- a/tools/static_mode_white_list.py
+++ b/tools/static_mode_white_list.py
@@ -480,6 +480,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_squared_l2_norm_op',
     'test_stack_op',
     'test_static_save_load',
+    'test_static_save_load_bf16',
     'test_sum_op',
     'test_switch',
     'test_switch_case',
-- 
GitLab
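
What the new unit test exercises boils down to a short round trip: decorate an optimizer for pure-bf16 execution, run a few training steps on CPU, persist every persistable variable with fluid.save, zero them out, and confirm that fluid.load restores the saved values exactly. The following is a trimmed, untested sketch of that flow, not part of the patch: it assumes a Paddle build where core.supports_bfloat16() is True and swaps the test's PtbModel for a toy single-fc network.

    # Sketch only: bf16 save/load round trip on CPU. Assumes
    # core.supports_bfloat16() is True; the fc network is illustrative.
    import numpy as np
    import paddle
    import paddle.fluid as fluid
    import paddle.fluid.core as core

    paddle.enable_static()
    assert core.supports_bfloat16()
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)

    x = fluid.layers.data(name='x', shape=[-1, 13], dtype='float32')
    loss = fluid.layers.reduce_mean(fluid.layers.fc(input=x, size=1))
    sgd = fluid.optimizer.SGD(learning_rate=1e-3)
    # Rewrite the program to run in bf16 where supported, as the test does.
    sgd = paddle.static.amp.bf16.decorate_bf16(
        sgd, use_bf16_guard=False, use_pure_bf16=True)
    sgd.minimize(loss, fluid.default_startup_program())

    exe.run(fluid.default_startup_program())
    exe.run(fluid.default_main_program(),
            feed={'x': np.random.rand(4, 13).astype('float32')},
            fetch_list=[loss])

    main_program = fluid.default_main_program()
    scope = fluid.global_scope()
    names = [v.name for v in main_program.list_vars() if v.persistable]
    before = {n: np.array(scope.find_var(n).get_tensor()) for n in names}

    # save_op/save_combine_op now have bf16 CPU kernels, so saving a
    # program that holds bf16 parameter tensors no longer fails.
    fluid.save(main_program, "./bf16_ckpt")

    for n in names:  # wipe the scope so the load is observable
        t = scope.find_var(n).get_tensor()
        t.set(np.zeros_like(np.array(t)), place)

    fluid.load(main_program, "./bf16_ckpt.pdparams", exe)

    for n in names:  # loaded tensors must match the saved ones exactly
        assert np.array_equal(np.array(scope.find_var(n).get_tensor()),
                              before[n])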
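
The test_fit_a_line.py hunks also finish a small API migration: the legacy fluid.io.save_inference_model / fluid.io.load_inference_model calls become paddle.static.save_inference_model / paddle.static.load_inference_model, and the feed list now holds the variable itself ([x]) rather than its name (['x']). A minimal sketch of the new-style calls, assuming Paddle 2.x static mode (the path prefix and one-layer network here are illustrative only):

    import numpy as np
    import paddle

    paddle.enable_static()
    exe = paddle.static.Executor(paddle.CPUPlace())

    x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
    y_predict = paddle.static.nn.fc(x, size=1)
    exe.run(paddle.static.default_startup_program())

    # Save: feed/fetch are passed as variables, not string names.
    paddle.static.save_inference_model('./fit_a_line', [x], [y_predict], exe)

    # Load: returns the pruned program, feed names, and fetch targets.
    [inference_program, feed_target_names,
     fetch_targets] = paddle.static.load_inference_model('./fit_a_line', exe)

    results = exe.run(inference_program,
                      feed={feed_target_names[0]:
                            np.random.rand(1, 13).astype('float32')},
                      fetch_list=fetch_targets)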