Commit 943dedec authored by phlrain

add sgd kernel; test=develop

Parent a4bccde0
@@ -2048,7 +2048,11 @@ void OperatorWithKernel::BuildPhiKernelContext(
       // deal with optional here
       if ((it == ctx.inputs.end() || it->second.size() == 0) &&
-          (input_defs[i].type_index ==
-           std::type_index(typeid(paddle::optional<const phi::DenseTensor&>)))) {
+          (input_defs[i].type_index ==
+               std::type_index(
+                   typeid(paddle::optional<const phi::DenseTensor&>)) ||
+           input_defs[i].type_index ==
+               std::type_index(
+                   typeid(paddle::optional<const phi::SelectedRows&>)))) {
         pt_kernel_context->EmplaceBackInputWithoutSetRange(nullptr);
         auto end_idx = start_idx + 1;
         pt_kernel_context->AssignInputRange(std::make_pair(start_idx, end_idx),
......
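The condition above dispatches on std::type_index: when an optional input is absent, the kernel context still receives a nullptr placeholder so the input ranges stay aligned, and this commit extends that treatment from optional DenseTensor inputs to optional SelectedRows inputs. Below is a minimal, self-contained sketch of the same type_index comparison pattern; the struct names are stand-ins for illustration, not Paddle code.

#include <iostream>
#include <typeindex>
#include <typeinfo>

// Stand-ins for the Paddle types referenced in the hunk above.
struct DenseTensor {};
struct SelectedRows {};
template <typename T>
struct optional_ref {};  // stand-in for paddle::optional<const T&>

// Returns true when a kernel argument of this type may legitimately be
// absent, mirroring the extended condition in BuildPhiKernelContext.
bool IsOptionalTensorArg(const std::type_index& arg_type) {
  return arg_type == std::type_index(typeid(optional_ref<DenseTensor>)) ||
         arg_type == std::type_index(typeid(optional_ref<SelectedRows>));
}

int main() {
  std::cout << IsOptionalTensorArg(typeid(optional_ref<SelectedRows>)) << "\n";  // 1
  std::cout << IsOptionalTensorArg(typeid(DenseTensor)) << "\n";                 // 0
}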
@@ -81,6 +81,12 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
                               default_tensor_layout,
                               default_key.dtype(),
                               arg_type);
+      } else if (arg_type == std::type_index(typeid(
+                                 paddle::optional<const SelectedRows&>))) {
+        args_def->AppendInput(default_key.backend(),
+                              default_tensor_layout,
+                              default_key.dtype(),
+                              arg_type);
       } else if (arg_type ==
                  std::type_index(typeid(const std::vector<DenseTensor>&))) {
         args_def->AppendInput(default_key.backend(),
......
@@ -14,6 +14,7 @@
 #include "paddle/phi/kernels/sgd_kernel.h"
 
+#include "paddle/fluid/framework/mixed_vector.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
 #include "paddle/phi/backends/gpu/gpu_helper.h"
@@ -72,7 +73,6 @@ void SGDDenseKernel(const Context& dev_ctx,
                     bool multi_precision,
                     DenseTensor* param_out,
                     DenseTensor* master_param_out) {
-  LOG(ERROR) << "run here";
   using MPDType = typename paddle::operators::details::MPTypeTrait<T>::Type;
   // do check here
   // if (multi_precision) {
......
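For context, here is a plain CPU sketch of the update the dense SGD kernel performs. It is illustrative only, not the CUDA implementation; the master-parameter handling follows the usual mixed-precision convention implied by the multi_precision and master_param_out arguments above, which is an assumption since the kernel body is not shown in this hunk.

#include <cstddef>
#include <cstdio>
#include <vector>

// One SGD step over a dense parameter. When a higher-precision master copy is
// supplied, the arithmetic runs in the wider type and the result is written
// back to both buffers; otherwise the parameter is updated in place.
void SgdDenseUpdate(std::vector<float>& param,
                    const std::vector<float>& grad,
                    float lr,
                    std::vector<double>* master_param /* may be null */) {
  for (std::size_t i = 0; i < param.size(); ++i) {
    if (master_param != nullptr) {
      const double updated = (*master_param)[i] -
                             static_cast<double>(lr) * static_cast<double>(grad[i]);
      (*master_param)[i] = updated;
      param[i] = static_cast<float>(updated);
    } else {
      param[i] -= lr * grad[i];
    }
  }
}

int main() {
  std::vector<float> w = {1.0f, 2.0f};
  std::vector<float> g = {0.5f, 0.5f};
  SgdDenseUpdate(w, g, /*lr=*/0.1f, /*master_param=*/nullptr);
  std::printf("%.2f %.2f\n", w[0], w[1]);  // 0.95 1.95
}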
@@ -17,9 +17,7 @@
 namespace phi {
 
 KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) {
-  LOG(ERROR) << "11";
   if (ctx.IsDenseTensorInput("Grad")) {
-    LOG(ERROR) << "dense";
     return KernelSignature("sgd",
                            {"Param", "LearningRate", "Grad", "MasterParam"},
                            {"multi_precision"},
......
@@ -24,366 +24,374 @@ import paddle
 paddle.enable_static()

(Every line removed by this hunk is the same code commented out with "# "; only the restored version is shown below.)

class TestSGDOp(OpTest):
    def setUp(self):
        self.op_type = "sgd"
        self.conf()
        w = np.random.random((self.h, self.w)).astype("float32")
        g = np.random.random((self.h, self.w)).astype("float32")
        lr = np.array([0.1]).astype("float32")

        self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr}
        self.outputs = {'ParamOut': w - lr * g}

    def conf(self):
        self.h = 102
        self.w = 105

    def test_check_output(self):
        self.check_output()


class TestSGDOpCase8X(TestSGDOp):
    def conf(self):
        self.h = 10
        self.w = 64


class TestSparseSGDOp(unittest.TestCase):
    def check_with_place(self, place):
        scope = core.Scope()

        # create and initialize Grad Variable
        height = 10
        rows = [0, 4, 7]
        self.conf()

        grad_selected_rows = scope.var('Grad').get_selected_rows()
        grad_selected_rows.set_height(height)
        grad_selected_rows.set_rows(rows)
        np_array = np.ones((len(rows), self.row_numel)).astype("float32")
        np_array[0, 0] = 2.0
        np_array[2, 8] = 4.0

        grad_tensor = grad_selected_rows.get_tensor()
        grad_tensor.set(np_array, place)

        # create and initialize Param Variable
        param = scope.var('Param').get_tensor()
        param_array = np.full((height, self.row_numel), 5.0).astype("float32")
        param.set(param_array, place)

        # create and initialize LearningRate Variable
        lr = scope.var('LearningRate').get_tensor()
        lr_array = np.full((1), 2.0).astype("float32")
        lr.set(lr_array, place)

        # create and run sgd operator
        sgd_op = Operator(
            "sgd",
            Param='Param',
            Grad='Grad',
            ParamOut='Param',
            LearningRate='LearningRate')
        sgd_op.run(scope, place)

        # get and compare result
        result_array = np.array(param)

        # rows[0] = 0, 5.0 - 2.0 * 2.0
        self.assertAlmostEqual(1.0, result_array[rows[0], 0])
        # rows[0] = 0, 5.0 - 2.0 * 1.0
        self.assertAlmostEqual(3.0, result_array[rows[0], 2])
        # 5.0 - 2.0 * 0.0
        self.assertAlmostEqual(5.0, result_array[1, 0])
        # rows[1] = 4, 5.0 - 2.0 * 1.0
        self.assertAlmostEqual(3.0, result_array[rows[1], 10])
        # 5.0 - 2.0 * 0.0
        self.assertAlmostEqual(5.0, result_array[5, 8])
        # rows[2] = 7, 5.0 - 2.0 * 1.0
        self.assertAlmostEqual(3.0, result_array[rows[2], 1])
        # rows[2] = 7, 5.0 - 2.0 * 4.0
        self.assertAlmostEqual(-3.0, result_array[rows[2], 8])

    def test_sparse_sgd(self):
        places = [core.CPUPlace()]
        if core.is_compiled_with_cuda():
            places.append(core.CUDAPlace(0))
        for place in places:
            self.check_with_place(place)

    def conf(self):
        self.row_numel = 12


class TestSparseSGDOpCase8X(TestSparseSGDOp):
    def conf(self):
        self.row_numel = 16


class TestSGDOpOptimizeSelectedRows(unittest.TestCase):
    def check_with_place(self, place):
        scope = core.Scope()

        row_width = 12
        # create and initialize Grad Variable
        grad_height = 10
        grad_rows = [0, 4, 7]

        grad_selected_rows = scope.var('Grad').get_selected_rows()
        grad_selected_rows.set_height(grad_height)
        grad_selected_rows.set_rows(grad_rows)
        grad_array = np.ones((len(grad_rows), row_width)).astype("float32")
        grad_array[0, 0] = 2.0
        grad_array[2, 8] = 4.0

        grad_tensor = grad_selected_rows.get_tensor()
        grad_tensor.set(grad_array, place)

        # create and initialize Param Variable
        # create and initialize W Variable
        param_rows = [0, 1, 2, 3, 4, 5, 6, 7]

        # init Param
        w_selected_rows = scope.var('Param').get_selected_rows()
        w_selected_rows.set_height(len(param_rows))
        w_selected_rows.set_rows(param_rows)
        w_selected_rows.sync_index()
        w_array = np.ones((len(param_rows), row_width)).astype("float32")
        for i in range(len(param_rows)):
            w_array[i] *= i
        w_tensor = w_selected_rows.get_tensor()
        w_tensor.set(w_array, place)

        w_before_optimize = np.array(w_tensor)

        # create and initialize LearningRate Variable
        lr_value = 0.1
        lr = scope.var('LearningRate').get_tensor()
        lr_array = np.full((1), lr_value).astype("float32")
        lr.set(lr_array, place)

        # optimize with Python
        w_after_optimize = np.copy(w_before_optimize)
        for index, id in enumerate(grad_rows):
            w_after_optimize[id] = w_before_optimize[
                id] - lr_value * grad_array[index]

        # create and run sgd operator
        sgd_op = Operator(
            "sgd",
            Param='Param',
            Grad='Grad',
            ParamOut='Param',
            LearningRate='LearningRate')
        sgd_op.run(scope, place)

        # get and compare result
        result_array = np.array(w_tensor)
        assert (result_array == w_after_optimize).all()

    def test_sparse_parameter_sgd(self):
        places = [core.CPUPlace()]
        # do not support GPU kernel currently
        for place in places:
            self.check_with_place(place)


class TestSGDOpWithLargeInput(unittest.TestCase):
    def runTest(self):
        paddle.enable_static()
        data = fluid.layers.fill_constant(shape=[1], value=128, dtype='int64')
        label = fluid.layers.fill_constant(
            shape=[1, 150], value=0.5, dtype='float32')
        emb = fluid.embedding(input=data, size=(10000000, 150), dtype='float32')
        out = fluid.layers.l2_normalize(x=emb, axis=-1)

        cost = fluid.layers.square_error_cost(input=out, label=label)
        avg_cost = fluid.layers.mean(cost)
        sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
        sgd_optimizer.minimize(avg_cost)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        compiled_prog = fluid.compiler.CompiledProgram(
            fluid.default_main_program())
        result = exe.run(compiled_prog, fetch_list=[avg_cost])


class TestSGDV2(unittest.TestCase):
    def test_sgd_dygraph(self):
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear = paddle.nn.Linear(13, 5)
        # This can be any optimizer supported by dygraph.
        adam = paddle.optimizer.SGD(learning_rate=0.01,
                                    parameters=linear.parameters(),
                                    weight_decay=0.01)
        out = linear(a)
        out.backward()
        adam.step()
        adam.clear_gradients()

    def test_sgd(self):
        paddle.enable_static()

        def check_sgd_optimizer(optimizer_attr):
            init_program = paddle.static.Program()
            program = paddle.static.Program()
            block = program.global_block()
            mul_x = block.create_parameter(
                dtype="float32",
                shape=[5, 10],
                lod_level=0,
                name="mul.x",
                optimize_attr=optimizer_attr)
            mul_y = block.create_var(
                dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
            mul_out = block.create_var(
                dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
            mean_out = block.create_var(
                dtype="float32", shape=[1], lod_level=0, name="mean.out")
            block.append_op(
                type="mul",
                inputs={"X": mul_x,
                        "Y": mul_y},
                outputs={"Out": mul_out},
                attrs={"x_num_col_dims": 1})
            block.append_op(
                type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out})
            sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.01)
            opts, _ = sgd_optimizer.minimize(mean_out, init_program)
            return opts

        opts = check_sgd_optimizer({'learning_rate': 1.1})
        self.assertEqual(len(opts), 2)
        self.assertEqual([op.type for op in opts], ["scale", "sgd"])

        opts = check_sgd_optimizer({'learning_rate': 1.0})
        self.assertEqual(len(opts), 1)
        self.assertEqual([op.type for op in opts], ["sgd"])

    def test_raise_error(self):
        self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None)

    def test_sgd_group_dygraph(self):
        paddle.disable_static()
        value = np.arange(26).reshape(2, 13).astype("float32")
        a = paddle.to_tensor(value)
        linear_1 = paddle.nn.Linear(13, 5)
        linear_2 = paddle.nn.Linear(5, 3)
        # This can be any optimizer supported by dygraph.
        adam = paddle.optimizer.SGD(learning_rate=0.01,
                                    parameters=[{
                                        'params': linear_1.parameters()
                                    }, {
                                        'params': linear_2.parameters(),
                                        'weight_decay': 0.001,
                                        'learning_rate': 0.1
                                    }],
                                    weight_decay=0.01)
        out = linear_1(a)
        out = linear_2(out)
        out.backward()
        adam.step()
        adam.clear_gradients()


class TestSGDMultiPrecision2_0(unittest.TestCase):
    def dygraph_sgd_mp(self, mp):
        paddle.disable_static()
        paddle.seed(10)
        paddle.set_device('gpu')
        input = paddle.randn((2, 2))
        model = paddle.nn.Linear(2, 2)
        optimizer = paddle.optimizer.SGD(parameters=model.parameters(),
                                         multi_precision=mp)
        if mp == True:
            model = paddle.amp.decorate(models=model, level='O2')
            scaler = paddle.amp.GradScaler(init_loss_scaling=1024)

        for idx in range(5):
            if mp == True:
                with paddle.amp.auto_cast(level='O2'):
                    output = model(input)
                    loss = paddle.mean(output)
                scaled = scaler.scale(loss)
                scaled.backward()
                scaler.minimize(optimizer, scaled)
                optimizer.clear_grad()
            else:
                output = model(input)
                loss = paddle.mean(output)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()

        return output, model.parameters()

    def static_sgd_mp(self, mp):
        paddle.enable_static()
        paddle.seed(10)
        np.random.seed(10)
        exe = paddle.static.Executor('gpu')
        train_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        optimizer = paddle.optimizer.SGD(multi_precision=mp)

        if mp:
            optimizer = paddle.static.amp.decorate(
                optimizer,
                init_loss_scaling=128.0,
                use_dynamic_loss_scaling=True,
                use_pure_fp16=True,
                use_fp16_guard=False)
        with paddle.static.program_guard(train_program, startup_program):
            if mp:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float16')
            else:
                data = paddle.static.data(
                    shape=[2, 2], name='X', dtype='float32')
            hidden = paddle.static.nn.fc(x=data, size=10)
            loss = paddle.fluid.layers.mean(hidden)
            optimizer.minimize(loss)
        exe.run(startup_program)

        if mp:
            optimizer.amp_init(place='gpu', scope=paddle.static.global_scope())
            x = np.random.random(size=(2, 2)).astype('float16')
        else:
            x = np.random.random(size=(2, 2)).astype('float32')
        out = []
        for idx in range(5):
            loss_data, = exe.run(train_program,
                                 feed={"X": x},
                                 fetch_list=[loss.name])
            out.append(loss_data)
        return out

    def test_main(self):
        if not paddle.is_compiled_with_cuda():
            return
        "Test dygraph mode"
        output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True)
        output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False)
        self.assertEqual(
            np.allclose(
                output1_dy.astype('float32').numpy(),
                output2_dy.astype('float32').numpy(),
                atol=1e-01),
            True)
        for idx in range(len(params1_dy)):
            self.assertEqual(
                np.allclose(
                    params1_dy[idx].astype('float32').numpy(),
                    params2_dy[idx].astype('float32').numpy(),
                    atol=1e-01),
                True)
        "Test static mode"
        output1_st = self.static_sgd_mp(mp=True)
        output2_st = self.static_sgd_mp(mp=False)
        for idx in range(len(output1_st)):
            self.assertEqual(
                np.allclose(
                    output1_st[idx].astype('float32'),
                    output2_st[idx].astype('float32'),
                    atol=1e-01),
                True)


class TestSGDMultiPrecision1_0(unittest.TestCase):
......
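The selected-rows tests above assert concrete values. The standalone sketch below (plain C++, not the Paddle kernel) reproduces that arithmetic: with a learning rate of 2.0 and a parameter tensor filled with 5.0, only the rows listed by the sparse gradient ([0, 4, 7]) are updated, which yields the 1.0, 5.0, and -3.0 values checked by TestSparseSGDOp.

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const int height = 10, row_numel = 12;     // matches TestSparseSGDOp
  const std::vector<long> rows = {0, 4, 7};  // rows carried by the sparse Grad
  std::vector<float> param(height * row_numel, 5.0f);
  std::vector<float> grad(rows.size() * row_numel, 1.0f);
  grad[0 * row_numel + 0] = 2.0f;            // np_array[0, 0] = 2.0
  grad[2 * row_numel + 8] = 4.0f;            // np_array[2, 8] = 4.0
  const float lr = 2.0f;

  // Only the rows named by the SelectedRows gradient are touched.
  for (std::size_t i = 0; i < rows.size(); ++i)
    for (int j = 0; j < row_numel; ++j)
      param[rows[i] * row_numel + j] -= lr * grad[i * row_numel + j];

  std::printf("%.1f %.1f %.1f\n",
              param[0 * row_numel + 0],   // 5.0 - 2.0 * 2.0 = 1.0
              param[1 * row_numel + 0],   // untouched row stays 5.0
              param[7 * row_numel + 8]);  // 5.0 - 2.0 * 4.0 = -3.0
}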