未验证 提交 832a014c 编写于 作者: J joanna.wozna.intel 提交者: GitHub

Add bf16 support for save and load ops (#33173)

* Add bf16 support for save and load ops

* Add bf16 test condition

* Add matmul and chagne fluid.io to paddle.static

* Reduce the test duration
上级 3af16297
......@@ -87,6 +87,8 @@ REGISTER_OP_CPU_KERNEL(
load_combine,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::LoadCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -69,6 +69,8 @@ REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load, ops::LoadOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
ops::LoadOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -102,5 +102,7 @@ REGISTER_OP_CPU_KERNEL(
save_combine,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveCombineOpKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
USE_CPU_ONLY_OP(save_combine);
......@@ -76,33 +77,34 @@ void CheckValues(T* expect, U* actual, const paddle::framework::LoD& expect_lod,
// Here, we create 4 LoDTensors and use save_combine_op to first save these
// in a single file. Then, we use load_combine_op to load these sequentially
TEST(SaveLoadCombineOp, CPU) {
template <typename T, typename U>
void SaveLoadCombineOp() {
paddle::framework::Scope scope;
paddle::platform::CPUPlace place;
std::vector<int> lod1 = {0, 1, 2, 3, 10};
int numel1 = 100;
paddle::framework::LoD expect_lod1;
int* expect1 = CreateForSaveCombineOp<int, int>(10, 10, lod1, "test_var1",
place, &scope, &expect_lod1);
T* expect1 = CreateForSaveCombineOp<T, U>(10, 10, lod1, "test_var1", place,
&scope, &expect_lod1);
std::vector<int> lod2 = {0, 2, 5, 10};
int numel2 = 200;
paddle::framework::LoD expect_lod2;
int* expect2 = CreateForSaveCombineOp<int, int>(10, 20, lod2, "test_var2",
place, &scope, &expect_lod2);
T* expect2 = CreateForSaveCombineOp<T, U>(10, 20, lod2, "test_var2", place,
&scope, &expect_lod2);
std::vector<int> lod3 = {0, 2, 3, 20};
int numel3 = 4000;
paddle::framework::LoD expect_lod3;
int* expect3 = CreateForSaveCombineOp<int, int>(20, 200, lod3, "test_var3",
place, &scope, &expect_lod3);
T* expect3 = CreateForSaveCombineOp<T, U>(20, 200, lod3, "test_var3", place,
&scope, &expect_lod3);
std::vector<int> lod4 = {0, 1, 20};
int numel4 = 1000;
paddle::framework::LoD expect_lod4;
int* expect4 = CreateForSaveCombineOp<int, int>(20, 50, lod4, "test_var4",
place, &scope, &expect_lod4);
T* expect4 = CreateForSaveCombineOp<T, U>(20, 50, lod4, "test_var4", place,
&scope, &expect_lod4);
// Set attributes
std::string filename = "check_tensor.ls";
......@@ -128,15 +130,21 @@ TEST(SaveLoadCombineOp, CPU) {
load_combine_op->Run(scope, place);
paddle::framework::LoD actual_lod1, actual_lod2, actual_lod3, actual_lod4;
int* actual1 = GetValuesAfterLoadCombineOp<int>(target1, scope, &actual_lod1);
int* actual2 = GetValuesAfterLoadCombineOp<int>(target2, scope, &actual_lod2);
int* actual3 = GetValuesAfterLoadCombineOp<int>(target3, scope, &actual_lod3);
int* actual4 = GetValuesAfterLoadCombineOp<int>(target4, scope, &actual_lod4);
CheckValues<int, int>(expect1, actual1, expect_lod1, actual_lod1, numel1);
CheckValues<int, int>(expect2, actual2, expect_lod2, actual_lod2, numel2);
CheckValues<int, int>(expect3, actual3, expect_lod3, actual_lod3, numel3);
CheckValues<int, int>(expect4, actual4, expect_lod4, actual_lod4, numel4);
U* actual1 = GetValuesAfterLoadCombineOp<U>(target1, scope, &actual_lod1);
U* actual2 = GetValuesAfterLoadCombineOp<U>(target2, scope, &actual_lod2);
U* actual3 = GetValuesAfterLoadCombineOp<U>(target3, scope, &actual_lod3);
U* actual4 = GetValuesAfterLoadCombineOp<U>(target4, scope, &actual_lod4);
CheckValues<T, U>(expect1, actual1, expect_lod1, actual_lod1, numel1);
CheckValues<T, U>(expect2, actual2, expect_lod2, actual_lod2, numel2);
CheckValues<T, U>(expect3, actual3, expect_lod3, actual_lod3, numel3);
CheckValues<T, U>(expect4, actual4, expect_lod4, actual_lod4, numel4);
}
TEST(SaveLoadCombineOp, CPU) { SaveLoadCombineOp<int, int>(); }
TEST(SaveLoadCombineBF16Op, CPU) {
SaveLoadCombineOp<paddle::platform::bfloat16, paddle::platform::bfloat16>();
}
// FP16 version of SaveLoadCombineOp Test, only altering the saving aspect
......
......@@ -90,6 +90,8 @@ REGISTER_OP_CPU_KERNEL(
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, double>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext,
paddle::platform::float16>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext,
paddle::platform::bfloat16>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, uint8_t>,
ops::SaveOpKernel<paddle::platform::CPUDeviceContext, int8_t>,
......
......@@ -84,8 +84,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16, pure_bf16):
feed=feeder.feed(data),
fetch_list=[avg_cost])
if avg_loss_value[0] < 10.0 or pure_bf16:
if save_dirname is not None and not pure_bf16:
fluid.io.save_inference_model(save_dirname, ['x'],
if save_dirname is not None:
paddle.static.save_inference_model(save_dirname, [x],
[y_predict], exe)
return
if math.isnan(float(avg_loss_value)):
......@@ -127,12 +127,12 @@ def infer(use_cuda, save_dirname=None, use_bf16=False):
inference_scope = fluid.core.Scope()
with fluid.scope_guard(inference_scope):
# Use fluid.io.load_inference_model to obtain the inference program desc,
# Use paddle.static.load_inference_model to obtain the inference program desc,
# the feed_target_names (the names of variables that will be fed
# data using feed operators), and the fetch_targets (variables that
# we want to obtain data from using fetch operators).
[inference_program, feed_target_names,
fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)
fetch_targets] = paddle.static.load_inference_model(save_dirname, exe)
# The input's dimension should be 2-D and the second dim is 13
# The input data should be >= 0
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
import paddle.fluid.framework as framework
from paddle.fluid.optimizer import SGDOptimizer
from paddle.fluid.tests.unittests.test_imperative_base import new_program_scope
from paddle.fluid.tests.unittests.test_static_save_load import PtbModel
import numpy as np
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestSaveLoadBF16(unittest.TestCase):
def set_place(self):
return fluid.CPUPlace()
def test_ptb_rnn_cpu_bfloat16(self):
seed = 90
hidden_size = 10
vocab_size = 500
num_layers = 1
num_steps = 3
init_scale = 0.1
batch_size = 4
batch_num = 100
with new_program_scope():
fluid.default_startup_program().random_seed = seed
fluid.default_main_program().random_seed = seed
ptb_model = PtbModel(
"ptb_model",
hidden_size=hidden_size,
vocab_size=vocab_size,
num_layers=num_layers,
num_steps=num_steps,
init_scale=init_scale)
place = self.set_place()
exe = fluid.Executor(place)
sgd = SGDOptimizer(learning_rate=1e-3)
x = fluid.layers.data(
name="x", shape=[-1, num_steps], dtype='int64')
y = fluid.layers.data(name="y", shape=[-1, 1], dtype='float32')
init_hidden = fluid.layers.data(
name="init_hidden", shape=[1], dtype='float32')
init_cell = fluid.layers.data(
name="init_cell", shape=[1], dtype='float32')
static_loss, static_last_hidden, static_last_cell = ptb_model(
x, y, init_hidden, init_cell)
sgd = paddle.static.amp.bf16.decorate_bf16(
sgd,
amp_lists=paddle.static.amp.bf16.AutoMixedPrecisionListsBF16(
custom_fp32_list={'transpose2', 'concat'}),
use_bf16_guard=False,
use_pure_bf16=True)
sgd.minimize(static_loss, framework.default_startup_program())
out = exe.run(framework.default_startup_program())
for i in range(batch_num):
x_data = np.arange(12).reshape(4, 3).astype('int64')
y_data = np.arange(1, 13).reshape(4, 3).astype('int64')
x_data = x_data.reshape((-1, num_steps, 1))
y_data = y_data.reshape((-1, 1))
init_hidden_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
init_cell_data = np.zeros(
(num_layers, batch_size, hidden_size), dtype='float32')
fetch_list = [static_loss, static_last_hidden, static_last_cell]
out = exe.run(fluid.default_main_program(),
feed={
"x": x_data,
"y": y_data,
"init_hidden": init_hidden_data,
"init_cell": init_cell_data
},
fetch_list=fetch_list)
# get value before save
main_program = framework.default_main_program()
base_map = {}
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been update
self.assertTrue(np.sum(np.abs(t)) != 0)
base_map[var.name] = t
fluid.save(main_program, "./test_1")
# set var to zero
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
ten = fluid.global_scope().find_var(var.name).get_tensor()
ten.set(np.zeros_like(np.array(ten)), place)
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
# make sure all the paramerter or optimizer var have been set to zero
self.assertTrue(np.sum(np.abs(new_t)) == 0)
fluid.load(main_program, "./test_1.pdparams", exe)
for var in main_program.list_vars():
if isinstance(var, framework.Parameter) or var.persistable:
new_t = np.array(fluid.global_scope().find_var(var.name)
.get_tensor())
base_t = base_map[var.name]
self.assertTrue(np.array_equal(new_t, base_t))
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
......@@ -480,6 +480,7 @@ STATIC_MODE_TESTING_LIST = [
'test_squared_l2_norm_op',
'test_stack_op',
'test_static_save_load',
'test_static_save_load_bf16',
'test_sum_op',
'test_switch',
'test_switch_case',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册