diff --git a/paddle/fluid/operators/optimizers/dgc_momentum_op.h b/paddle/fluid/operators/optimizers/dgc_momentum_op.h index bea019f1f36e2ea21890f23b753b4df1d62c0e3b..c86f544ed77ff13cc59735971cf856f66bc12202 100644 --- a/paddle/fluid/operators/optimizers/dgc_momentum_op.h +++ b/paddle/fluid/operators/optimizers/dgc_momentum_op.h @@ -17,7 +17,7 @@ #include #include "paddle/fluid/operators/optimizers/momentum_op.h" -#include "paddle/fluid/operators/optimizers/sgd_op.h" +#include "paddle/phi/kernels/sgd_kernel.h" namespace paddle { namespace operators { @@ -26,8 +26,7 @@ template class DGCMomentumKernel : public framework::OpKernel { public: DGCMomentumKernel() - : _momentum_op_kernel(new MomentumOpKernel()), - _sgd_op_kernel(new SGDOpKernel()) {} + : _momentum_op_kernel(new MomentumOpKernel()) {} void Compute(const framework::ExecutionContext& context) const override { auto rampup_begin_step = context.Attr("rampup_begin_step"); @@ -67,12 +66,68 @@ class DGCMomentumKernel : public framework::OpKernel { } VLOG(10) << " so use sgd optimizer"; - return _sgd_op_kernel->Compute(context); + + const auto* param_var = context.InputVar("Param"); + const auto* grad_var = context.InputVar("Grad"); + auto* learning_rate = context.Input("LearningRate"); + bool multi_precision = context.Attr("multi_precision"); + if (param_var->IsType()) { + auto* param = context.Input("Param"); + auto* param_out = context.Output("ParamOut"); + auto* master_param_out = + context.Output("MasterParamOut"); + paddle::optional master_param_opt = + paddle::none; + if (multi_precision) { + auto* master_param = context.Input("MasterParam"); + master_param_opt = *master_param; + } + + if (grad_var->IsType()) { + // sgd_dense + auto* grad = context.Input("Grad"); + phi::SGDDenseKernel( + static_cast::TYPE&>(dev_ctx), + *param, *learning_rate, *grad, master_param_opt, multi_precision, + param_out, master_param_out); + } else { + // sgd dense param sparse grad + auto* grad = context.Input("Grad"); + phi::SGDDenseParamSparseGradKernel( + static_cast::TYPE&>(dev_ctx), + *param, *learning_rate, *grad, master_param_opt, multi_precision, + param_out, master_param_out); + } + } else if (param_var->IsType() && + grad_var->IsType() && + platform::is_cpu_place(context.GetPlace())) { + // sgd sparse param sparse grad + auto* param = context.Input("Param"); + auto* param_out = context.Output("ParamOut"); + auto* master_param_out = + context.Output("MasterParamOut"); + paddle::optional master_param_opt = + paddle::none; + if (multi_precision) { + auto* master_param = context.Input("MasterParam"); + master_param_opt = *master_param; + } + auto* grad = context.Input("Grad"); + phi::SGDSparseParamSparseGradKernel( + static_cast::TYPE&>(dev_ctx), + *param, *learning_rate, *grad, master_param_opt, multi_precision, + param_out, master_param_out); + + } else { + PADDLE_THROW("gdc not support yet"); + } } private: std::unique_ptr> _momentum_op_kernel; - std::unique_ptr> _sgd_op_kernel; }; } // namespace operators diff --git a/paddle/fluid/operators/optimizers/sgd_op.cc b/paddle/fluid/operators/optimizers/sgd_op.cc index 529d60a2820ea92de0b0009b31c9f2ad04d4860a..0e3f895d276af6856c64ddd123606b087689ca9a 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cc +++ b/paddle/fluid/operators/optimizers/sgd_op.cc @@ -166,8 +166,3 @@ REGISTER_OPERATOR( paddle::framework::EmptyGradOpMaker, paddle::framework::EmptyGradOpMaker, ops::SGDOpInferVarType); -REGISTER_OP_CPU_KERNEL( - sgd, ops::SGDOpKernel, - ops::SGDOpKernel, - ops::SGDOpKernel); diff --git 
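The dgc_momentum_op.h hunk above removes the embedded SGDOpKernel and instead selects one of the new phi SGD kernels at runtime from the types of Param and Grad. The standalone sketch below only illustrates that type-driven dispatch pattern; DenseTensorStub, SelectedRowsStub and RunSgdLike are made-up stand-ins, not phi types or the real operator code.

// Illustrative only: type-driven dispatch in the spirit of the DGC hunk above.
// All names here are hypothetical stand-ins, not Paddle/phi classes.
#include <iostream>
#include <stdexcept>
#include <variant>

struct DenseTensorStub {};
struct SelectedRowsStub {};
using VarStub = std::variant<DenseTensorStub, SelectedRowsStub>;

void RunSgdLike(const VarStub& param, const VarStub& grad) {
  const bool param_dense = std::holds_alternative<DenseTensorStub>(param);
  const bool grad_dense = std::holds_alternative<DenseTensorStub>(grad);
  if (param_dense && grad_dense) {
    std::cout << "dense param, dense grad  -> sgd\n";
  } else if (param_dense && !grad_dense) {
    std::cout << "dense param, sparse grad -> sgd_dense_param_sparse_grad\n";
  } else if (!param_dense && !grad_dense) {
    std::cout << "sparse param, sparse grad -> sgd_sparse_param_sparse_grad\n";
  } else {
    // Mirrors the PADDLE_THROW fallback in the hunk above.
    throw std::runtime_error("unsupported Param/Grad combination");
  }
}

int main() {
  RunSgdLike(DenseTensorStub{}, SelectedRowsStub{});
  return 0;
}

In the real kernel the same three branches forward to phi::SGDDenseKernel, phi::SGDDenseParamSparseGradKernel and phi::SGDSparseParamSparseGradKernel, and any other combination raises PADDLE_THROW.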
a/paddle/fluid/operators/optimizers/sgd_op.cu b/paddle/fluid/operators/optimizers/sgd_op.cu index 3149f5f56ed4964a750f61a354c6cd31a29fc526..222244a2fd1e34ace573ad4fa06775c0e5113925 100644 --- a/paddle/fluid/operators/optimizers/sgd_op.cu +++ b/paddle/fluid/operators/optimizers/sgd_op.cu @@ -166,10 +166,3 @@ class SGDOpKernel }; } // namespace operators } // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - sgd, ops::SGDOpKernel, - ops::SGDOpKernel, - ops::SGDOpKernel); diff --git a/paddle/phi/core/kernel_utils.h b/paddle/phi/core/kernel_utils.h index 2fda3cb6db4fdb4aaac7fc7c88075b833c050bad..2ce1c829ce81a57cfad7343e2007ebf75b85ea80 100644 --- a/paddle/phi/core/kernel_utils.h +++ b/paddle/phi/core/kernel_utils.h @@ -221,6 +221,7 @@ struct KernelImpl { PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor); PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(DenseTensor); + PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows); PT_SPECIALIZE_KernelCallHelper_FOR_MULTI_INPUT(DenseTensor); #ifndef PADDLE_WITH_CUSTOM_KERNEL PT_SPECIALIZE_KernelCallHelper_FOR_INPUT(SelectedRows); diff --git a/paddle/phi/kernels/cpu/sgd_kernel.cc b/paddle/phi/kernels/cpu/sgd_kernel.cc index 91b39292612b52ee47fc2bc77c7f205158bdc29c..c7b4074c70aaa21d4575dd69766bc757271f047e 100644 --- a/paddle/phi/kernels/cpu/sgd_kernel.cc +++ b/paddle/phi/kernels/cpu/sgd_kernel.cc @@ -14,6 +14,8 @@ #include "paddle/phi/kernels/sgd_kernel.h" #include "paddle/fluid/operators/jit/kernels.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { @@ -112,40 +114,42 @@ void sgd_dense_param_sparse_grad_impl( } template -void SGDKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const DenseTensor& grad, - const DenseTensor& master_param, - bool multi_precision, - DenseTensor* param_out, - DenseTensor* master_param_out) { +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { dev_ctx.template Alloc(param_out); sgd_dense_param_dense_grad_impl(param, learning_rate, grad, param_out); } template -void SGDKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const SelectedRows& grad, - const DenseTensor& master_param, - bool multi_precision, - DenseTensor* param_out, - DenseTensor* master_param_out) { +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { dev_ctx.template Alloc(param_out); sgd_dense_param_sparse_grad_impl(param, learning_rate, grad, param_out); } template -void SGDKernel(const Context& dev_ctx, - const SelectedRows& param, - const DenseTensor& learning_rate, - const SelectedRows& grad, - const SelectedRows& master_param, - bool multi_precision, - SelectedRows* param_out, - SelectedRows* master_param_out) { +void SGDSparseParamSparseGradKernel( + const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + SelectedRows* param_out, + 
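Changing the master-weight argument from const DenseTensor& to paddle::optional (together with the new PT_SPECIALIZE_KernelCallHelper_FOR_OPTIONAL_INPUT(SelectedRows) specialization) lets the kernels run without a MasterParam input when multi_precision is off, matching the master_param_opt = paddle::none path in the DGC hunk. Below is a minimal standalone analogue that uses std::optional as a stand-in for paddle::optional; SgdWithOptionalMaster is a hypothetical helper, not Paddle API.

// Sketch of the "optional master weights" argument pattern, with std::optional
// standing in for paddle::optional. Hypothetical helper, not Paddle code.
#include <cassert>
#include <functional>
#include <optional>
#include <vector>

void SgdWithOptionalMaster(
    std::vector<float>& param, const std::vector<float>& grad, float lr,
    std::optional<std::reference_wrapper<std::vector<float>>> master,
    bool multi_precision) {
  for (size_t i = 0; i < param.size(); ++i) {
    if (multi_precision) {
      // Only dereference the optional when multi-precision is requested,
      // mirroring master_param->data<MPDType>() guarded by multi_precision.
      std::vector<float>& m = master->get();
      m[i] -= lr * grad[i];
      param[i] = m[i];
    } else {
      param[i] -= lr * grad[i];
    }
  }
}

int main() {
  std::vector<float> p{1.0f, 2.0f}, g{0.5f, 0.5f};
  // Empty optional is fine because multi_precision is false.
  SgdWithOptionalMaster(p, g, 0.1f, std::nullopt, /*multi_precision=*/false);
  assert(p[0] > 0.9f && p[0] < 1.0f);
  return 0;
}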
SelectedRows* master_param_out) { // for distributed training, a sparse var may be empty, // just skip updating. if (grad.rows().size() == 0) { @@ -183,3 +187,27 @@ void SGDKernel(const Context& dev_ctx, } } // namespace phi + +PD_REGISTER_KERNEL(sgd, + CPU, + ALL_LAYOUT, + phi::SGDDenseKernel, + phi::dtype::bfloat16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::SGDDenseParamSparseGradKernel, + phi::dtype::bfloat16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad, + CPU, + ALL_LAYOUT, + phi::SGDSparseParamSparseGradKernel, + phi::dtype::bfloat16, + float, + double) {} diff --git a/paddle/phi/kernels/gpu/sgd_kernel.cu b/paddle/phi/kernels/gpu/sgd_kernel.cu index 076bd0a7ad1e4ac1a70907823f7b23fb8bc620da..74c377b5596c82f1dc30b3751fc54a420f1ef86f 100644 --- a/paddle/phi/kernels/gpu/sgd_kernel.cu +++ b/paddle/phi/kernels/gpu/sgd_kernel.cu @@ -18,6 +18,9 @@ #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/phi/backends/gpu/gpu_helper.h" +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + namespace phi { template @@ -61,14 +64,15 @@ __global__ void SparseSGDFunctorKernel(const T* selected_rows, } template -void SGDKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const DenseTensor& grad, - const DenseTensor& master_param, - bool multi_precision, - DenseTensor* param_out, - DenseTensor* master_param_out) { +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { + LOG(ERROR) << "run here"; using MPDType = typename paddle::operators::details::MPTypeTrait::Type; // do check here // if (multi_precision) { @@ -77,7 +81,7 @@ void SGDKernel(const Context& dev_ctx, // } const MPDType* master_in_data = - multi_precision ? master_param.data() : nullptr; + multi_precision ? master_param->data() : nullptr; MPDType* master_out_data = multi_precision ? master_param_out->mutable_data(dev_ctx.GetPlace()) @@ -91,20 +95,21 @@ void SGDKernel(const Context& dev_ctx, grad.data(), learning_rate.data(), param.numel(), - param_out->mutable_data(ctx.GetPlace()), + param_out->mutable_data(dev_ctx.GetPlace()), master_in_data, master_out_data); } template -void SGDKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const SelectedRows& grad, - const DenseTensor& master_param, - bool multi_precision, - DenseTensor* param_out, - DenseTensor* master_param_out) { +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out) { using MPDType = typename paddle::operators::details::MPTypeTrait::Type; // do some check here // if (multi_precision) { @@ -113,7 +118,7 @@ void SGDKernel(const Context& dev_ctx, // } const MPDType* master_in_data = - multi_precision ? master_param.data() : nullptr; + multi_precision ? master_param->data() : nullptr; MPDType* master_out_data = multi_precision ? 
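In the GPU dense kernel above, when multi_precision is on the update is applied to a master copy held in the wider MPDType (master_in_data / master_out_data) and the lower-precision parameter is refreshed from it. The host-side sketch below shows the same idea with float parameters and double masters; the types and the MasterWeightSgdStep name are illustrative assumptions, not the CUDA kernel.

// Host-side illustration of the master-weight update: accumulate in a wider
// type, then cast back down. Types and names are illustrative, not Paddle's.
#include <cstdio>
#include <vector>

void MasterWeightSgdStep(std::vector<float>& param_low,    // low-precision copy
                         std::vector<double>& master_wide, // wide master copy
                         const std::vector<float>& grad,
                         double lr) {
  for (size_t i = 0; i < param_low.size(); ++i) {
    master_wide[i] -= lr * static_cast<double>(grad[i]);    // wide update
    param_low[i] = static_cast<float>(master_wide[i]);      // cast back
  }
}

int main() {
  std::vector<float> p{1.0f};
  std::vector<double> m{1.0};
  std::vector<float> g{1e-8f};
  // Tiny updates that would be lost entirely in float (1.0f - 1e-8f == 1.0f)
  // still accumulate in the wide master copy.
  for (int i = 0; i < 1000; ++i) MasterWeightSgdStep(p, m, g, 1.0);
  std::printf("param=%.9f master=%.12f\n", p[0], m[0]);
  return 0;
}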
master_param_out->mutable_data(dev_ctx.GetPlace()) @@ -155,7 +160,7 @@ void SGDKernel(const Context& dev_ctx, int max_threads = dev_ctx.GetMaxPhysicalThreadCount(); int max_blocks = std::max(max_threads / kThreadsPerBlock, 1); paddle::framework::MixVector mixv_in_rows(&in_rows); - SparseSGDFunctorKernel<<>>( + SparseSGDFunctorKernel<<>>( in_data, mixv_in_rows.CUDAData(dev_ctx.GetPlace()), learning_rate.data(), @@ -164,4 +169,41 @@ void SGDKernel(const Context& dev_ctx, in_rows.size()); } +template +void SGDSparseParamSparseGradKernel( + const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + SelectedRows* param_out, + SelectedRows* master_param_out) { + PADDLE_THROW("not impl"); +} + } // namespace phi + +PD_REGISTER_KERNEL(sgd, + GPU, + ALL_LAYOUT, + phi::SGDDenseKernel, + phi::dtype::float16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_dense_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::SGDDenseParamSparseGradKernel, + phi::dtype::float16, + float, + double) {} + +PD_REGISTER_KERNEL(sgd_sparse_param_sparse_grad, + GPU, + ALL_LAYOUT, + phi::SGDSparseParamSparseGradKernel, + phi::dtype::float16, + float, + double) {} diff --git a/paddle/phi/kernels/sgd_kernel.h b/paddle/phi/kernels/sgd_kernel.h index 9490940325e91efddc861561ab9e5d233c34eed2..12361c738e247b6d2f3fc4813cc8ad89da3e8bb7 100644 --- a/paddle/phi/kernels/sgd_kernel.h +++ b/paddle/phi/kernels/sgd_kernel.h @@ -20,33 +20,35 @@ namespace phi { template -void SGDKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const DenseTensor& grad, - const DenseTensor& master_param, - bool multi_precision, - DenseTensor* param_out, - DenseTensor* master_param_out); +void SGDDenseKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const DenseTensor& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out); template -void SGDKernel(const Context& dev_ctx, - const DenseTensor& param, - const DenseTensor& learning_rate, - const SelectedRows& grad, - const DenseTensor& master_param, - bool multi_precision, - DenseTensor* param_out, - DenseTensor* master_param_out); +void SGDDenseParamSparseGradKernel( + const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* master_param_out); template -void SGDKernel(const Context& dev_ctx, - const SelectedRows& param, - const DenseTensor& learning_rate, - const SelectedRows& grad, - const SelectedRows& master_param, - bool multi_precision, - SelectedRows* param_out, - SelectedRows* master_param_out); +void SGDSparseParamSparseGradKernel( + const Context& dev_ctx, + const SelectedRows& param, + const DenseTensor& learning_rate, + const SelectedRows& grad, + paddle::optional master_param, + bool multi_precision, + SelectedRows* param_out, + SelectedRows* master_param_out); } // namespace phi diff --git a/paddle/phi/ops/compat/sgd_sig.cc b/paddle/phi/ops/compat/sgd_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..ac75cf1d5da6d94b746015c4aba23e044355a42d --- /dev/null +++ b/paddle/phi/ops/compat/sgd_sig.cc @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
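The sgd_dense_param_sparse_grad kernel registered above applies a SelectedRows gradient (a list of touched row ids plus a dense value block) to a dense parameter, updating only those rows; the numbers below mirror the values used in the existing sparse SGD unit test (height 10, rows {0, 4, 7}, lr 2.0, param filled with 5.0). The RowScatterSgd name and the flat std::vector layout are illustrative assumptions, not phi's tensor layout.

// Sketch of the SelectedRows-style sparse-gradient update: only the rows
// listed in `rows` are touched. Layout and names are illustrative.
#include <cassert>
#include <cstdint>
#include <vector>

void RowScatterSgd(std::vector<float>& param,             // [height x row_numel]
                   int64_t row_numel,
                   const std::vector<int64_t>& rows,       // touched row ids
                   const std::vector<float>& grad_values,  // [rows.size() x row_numel]
                   float lr) {
  for (size_t i = 0; i < rows.size(); ++i) {
    float* dst = param.data() + rows[i] * row_numel;
    const float* src = grad_values.data() + i * row_numel;
    for (int64_t j = 0; j < row_numel; ++j) dst[j] -= lr * src[j];
  }
}

int main() {
  const int64_t height = 10, row_numel = 4;
  std::vector<float> param(height * row_numel, 5.0f);
  std::vector<int64_t> rows{0, 4, 7};
  std::vector<float> grad(rows.size() * row_numel, 1.0f);
  RowScatterSgd(param, row_numel, rows, grad, 2.0f);
  assert(param[4 * row_numel] == 3.0f);  // touched row: 5.0 - 2.0 * 1.0
  assert(param[1 * row_numel] == 5.0f);  // untouched row is unchanged
  return 0;
}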
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature SGDOpArgumentMapping(const ArgumentMappingContext& ctx) { + LOG(ERROR) << "11"; + if (ctx.IsDenseTensorInput("Grad")) { + LOG(ERROR) << "dense"; + return KernelSignature("sgd", + {"Param", "LearningRate", "Grad", "MasterParam"}, + {"multi_precision"}, + {"ParamOut", "MasterParamOut"}); + } else if (ctx.IsSelectedRowsInput("Grad")) { + if (ctx.IsDenseTensorInput("Param")) { + return KernelSignature("sgd_dense_param_sparse_grad", + {"Param", "LearningRate", "Grad", "MasterParam"}, + {"multi_precision"}, + {"ParamOut", "MasterParamOut"}); + } else { + return KernelSignature("sgd_sparse_param_sparse_grad", + {"Param", "LearningRate", "Grad", "MasterParam"}, + {"multi_precision"}, + {"ParamOut", "MasterParamOut"}); + } + } + + return KernelSignature("unregistered", {}, {}, {}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(sgd, phi::SGDOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_sgd_op.py b/python/paddle/fluid/tests/unittests/test_sgd_op.py index 817150a21f5e56e53d41d485d316802ec6983d8a..e8ba53e0189f8af121d3d93d6020fefeb643c40e 100644 --- a/python/paddle/fluid/tests/unittests/test_sgd_op.py +++ b/python/paddle/fluid/tests/unittests/test_sgd_op.py @@ -24,374 +24,366 @@ import paddle paddle.enable_static() - -class TestSGDOp(OpTest): - def setUp(self): - self.op_type = "sgd" - self.conf() - w = np.random.random((self.h, self.w)).astype("float32") - g = np.random.random((self.h, self.w)).astype("float32") - lr = np.array([0.1]).astype("float32") - - self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr} - self.outputs = {'ParamOut': w - lr * g} - - def conf(self): - self.h = 102 - self.w = 105 - - def test_check_output(self): - self.check_output() - - -class TestSGDOpCase8X(TestSGDOp): - def conf(self): - self.h = 10 - self.w = 64 - - -class TestSparseSGDOp(unittest.TestCase): - def check_with_place(self, place): - scope = core.Scope() - - # create and initialize Grad Variable - height = 10 - rows = [0, 4, 7] - self.conf() - - grad_selected_rows = scope.var('Grad').get_selected_rows() - grad_selected_rows.set_height(height) - grad_selected_rows.set_rows(rows) - np_array = np.ones((len(rows), self.row_numel)).astype("float32") - np_array[0, 0] = 2.0 - np_array[2, 8] = 4.0 - - grad_tensor = grad_selected_rows.get_tensor() - grad_tensor.set(np_array, place) - - # create and initialize Param Variable - param = scope.var('Param').get_tensor() - param_array = np.full((height, self.row_numel), 5.0).astype("float32") - param.set(param_array, place) - - # create and initialize LeraningRate Variable - lr = scope.var('LearningRate').get_tensor() - lr_array = np.full((1), 2.0).astype("float32") - lr.set(lr_array, place) - - # create and run sgd operator - sgd_op = Operator( - "sgd", - Param='Param', - Grad='Grad', - ParamOut='Param', - LearningRate='LearningRate') - sgd_op.run(scope, place) - - # get and compare result - 
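The new SGDOpArgumentMapping above routes the single fluid sgd op to one of three phi kernel names based on whether Grad and Param are DenseTensor or SelectedRows inputs. The tiny pure function below restates that routing in isolation; SelectSgdKernelName is hypothetical, while the three returned strings match the kernel names registered in this patch.

// Restating the routing logic of sgd_sig.cc as a standalone pure function.
// SelectSgdKernelName is a hypothetical helper, not part of the patch.
#include <cassert>
#include <string>

std::string SelectSgdKernelName(bool grad_is_dense, bool param_is_dense) {
  if (grad_is_dense) return "sgd";
  return param_is_dense ? "sgd_dense_param_sparse_grad"
                        : "sgd_sparse_param_sparse_grad";
}

int main() {
  assert(SelectSgdKernelName(true, true) == "sgd");
  assert(SelectSgdKernelName(false, true) == "sgd_dense_param_sparse_grad");
  assert(SelectSgdKernelName(false, false) == "sgd_sparse_param_sparse_grad");
  return 0;
}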
result_array = np.array(param) - - # rows[0] = 0, 5.0 - 2.0 * 2.0 - self.assertAlmostEqual(1.0, result_array[rows[0], 0]) - # rows[0] = 0, 5.0 - 2.0 * 1.0 - self.assertAlmostEqual(3.0, result_array[rows[0], 2]) - # 5.0 - 2.0 * 0.0 - self.assertAlmostEqual(5.0, result_array[1, 0]) - # rows[1] = 4, 5.0 - 2.0 * 1.0 - self.assertAlmostEqual(3.0, result_array[rows[1], 10]) - # 5.0 - 2.0 * 0.0 - self.assertAlmostEqual(5.0, result_array[5, 8]) - # rows[2] = 7, 5.0 - 2.0 * 1.0 - self.assertAlmostEqual(3.0, result_array[rows[2], 1]) - # rows[2] = 7, 5.0 - 2.0 * 4.0 - self.assertAlmostEqual(-3.0, result_array[rows[2], 8]) - - def test_sparse_sgd(self): - places = [core.CPUPlace()] - if core.is_compiled_with_cuda(): - places.append(core.CUDAPlace(0)) - for place in places: - self.check_with_place(place) - - def conf(self): - self.row_numel = 12 - - -class TestSparseSGDOpCase8X(TestSparseSGDOp): - def conf(self): - self.row_numel = 16 - - -class TestSGDOpOptimizeSelectedRows(unittest.TestCase): - def check_with_place(self, place): - scope = core.Scope() - - row_width = 12 - # create and initialize Grad Variable - grad_height = 10 - grad_rows = [0, 4, 7] - - grad_selected_rows = scope.var('Grad').get_selected_rows() - grad_selected_rows.set_height(grad_height) - grad_selected_rows.set_rows(grad_rows) - grad_array = np.ones((len(grad_rows), row_width)).astype("float32") - grad_array[0, 0] = 2.0 - grad_array[2, 8] = 4.0 - - grad_tensor = grad_selected_rows.get_tensor() - grad_tensor.set(grad_array, place) - - # create and initialize Param Variable - # create and initialize W Variable - param_rows = [0, 1, 2, 3, 4, 5, 6, 7] - - # init Param - w_selected_rows = scope.var('Param').get_selected_rows() - w_selected_rows.set_height(len(param_rows)) - w_selected_rows.set_rows(param_rows) - w_selected_rows.sync_index() - w_array = np.ones((len(param_rows), row_width)).astype("float32") - for i in range(len(param_rows)): - w_array[i] *= i - w_tensor = w_selected_rows.get_tensor() - w_tensor.set(w_array, place) - - w_before_optimize = np.array(w_tensor) - - # create and initialize LeraningRate Variable - lr_value = 0.1 - lr = scope.var('LearningRate').get_tensor() - lr_array = np.full((1), lr_value).astype("float32") - lr.set(lr_array, place) - - # optimize with Python - w_after_optimize = np.copy(w_before_optimize) - for index, id in enumerate(grad_rows): - w_after_optimize[id] = w_before_optimize[ - id] - lr_value * grad_array[index] - - # create and run sgd operator - sgd_op = Operator( - "sgd", - Param='Param', - Grad='Grad', - ParamOut='Param', - LearningRate='LearningRate') - sgd_op.run(scope, place) - - # get and compare result - result_array = np.array(w_tensor) - assert (result_array == w_after_optimize).all() - - def test_sparse_parameter_sgd(self): - places = [core.CPUPlace()] - # do not support GPU kernel currently - for place in places: - self.check_with_place(place) - - -class TestSGDOpWithLargeInput(unittest.TestCase): - def runTest(self): - paddle.enable_static() - data = fluid.layers.fill_constant(shape=[1], value=128, dtype='int64') - label = fluid.layers.fill_constant( - shape=[1, 150], value=0.5, dtype='float32') - emb = fluid.embedding(input=data, size=(10000000, 150), dtype='float32') - out = fluid.layers.l2_normalize(x=emb, axis=-1) - - cost = fluid.layers.square_error_cost(input=out, label=label) - avg_cost = fluid.layers.mean(cost) - sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) - sgd_optimizer.minimize(avg_cost) - - place = fluid.CPUPlace() - exe = fluid.Executor(place) - 
exe.run(fluid.default_startup_program()) - compiled_prog = fluid.compiler.CompiledProgram( - fluid.default_main_program()) - result = exe.run(compiled_prog, fetch_list=[avg_cost]) - - -class TestSGDV2(unittest.TestCase): - def test_sgd_dygraph(self): - paddle.disable_static() - value = np.arange(26).reshape(2, 13).astype("float32") - a = paddle.to_tensor(value) - linear = paddle.nn.Linear(13, 5) - # This can be any optimizer supported by dygraph. - adam = paddle.optimizer.SGD(learning_rate=0.01, - parameters=linear.parameters(), - weight_decay=0.01) - out = linear(a) - out.backward() - adam.step() - adam.clear_gradients() - - def test_sgd(self): - paddle.enable_static() - - def check_sgd_optimizer(optimizer_attr): - init_program = paddle.static.Program() - program = paddle.static.Program() - block = program.global_block() - mul_x = block.create_parameter( - dtype="float32", - shape=[5, 10], - lod_level=0, - name="mul.x", - optimize_attr=optimizer_attr) - mul_y = block.create_var( - dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") - mul_out = block.create_var( - dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") - mean_out = block.create_var( - dtype="float32", shape=[1], lod_level=0, name="mean.out") - block.append_op( - type="mul", - inputs={"X": mul_x, - "Y": mul_y}, - outputs={"Out": mul_out}, - attrs={"x_num_col_dims": 1}) - block.append_op( - type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) - sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.01) - opts, _ = sgd_optimizer.minimize(mean_out, init_program) - return opts - - opts = check_sgd_optimizer({'learning_rate': 1.1}) - self.assertEqual(len(opts), 2) - self.assertEqual([op.type for op in opts], ["scale", "sgd"]) - - opts = check_sgd_optimizer({'learning_rate': 1.0}) - self.assertEqual(len(opts), 1) - self.assertEqual([op.type for op in opts], ["sgd"]) - - def test_raise_error(self): - self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None) - - def test_sgd_group_dygraph(self): - paddle.disable_static() - value = np.arange(26).reshape(2, 13).astype("float32") - a = paddle.to_tensor(value) - linear_1 = paddle.nn.Linear(13, 5) - linear_2 = paddle.nn.Linear(5, 3) - # This can be any optimizer supported by dygraph. 
- adam = paddle.optimizer.SGD(learning_rate=0.01, - parameters=[{ - 'params': linear_1.parameters() - }, { - 'params': linear_2.parameters(), - 'weight_decay': 0.001, - 'learning_rate': 0.1 - }], - weight_decay=0.01) - out = linear_1(a) - out = linear_2(out) - out.backward() - adam.step() - adam.clear_gradients() - - -class TestSGDMultiPrecision2_0(unittest.TestCase): - def dygraph_sgd_mp(self, mp): - paddle.disable_static() - paddle.seed(10) - paddle.set_device('gpu') - input = paddle.randn((2, 2)) - model = paddle.nn.Linear(2, 2) - optimizer = paddle.optimizer.SGD(parameters=model.parameters(), - multi_precision=mp) - if mp == True: - model = paddle.amp.decorate(models=model, level='O2') - scaler = paddle.amp.GradScaler(init_loss_scaling=1024) - - for idx in range(5): - if mp == True: - with paddle.amp.auto_cast(level='O2'): - output = model(input) - loss = paddle.mean(output) - scaled = scaler.scale(loss) - scaled.backward() - scaler.minimize(optimizer, scaled) - optimizer.clear_grad() - else: - output = model(input) - loss = paddle.mean(output) - optimizer.step() - optimizer.clear_grad() - - return output, model.parameters() - - def static_sgd_mp(self, mp): - paddle.enable_static() - paddle.seed(10) - np.random.seed(10) - exe = paddle.static.Executor('gpu') - train_program = paddle.static.Program() - startup_program = paddle.static.Program() - optimizer = paddle.optimizer.SGD(multi_precision=mp) - - if mp: - optimizer = paddle.static.amp.decorate( - optimizer, - init_loss_scaling=128.0, - use_dynamic_loss_scaling=True, - use_pure_fp16=True, - use_fp16_guard=False) - with paddle.static.program_guard(train_program, startup_program): - if mp: - data = paddle.static.data( - shape=[2, 2], name='X', dtype='float16') - else: - data = paddle.static.data( - shape=[2, 2], name='X', dtype='float32') - hidden = paddle.static.nn.fc(x=data, size=10) - loss = paddle.fluid.layers.mean(hidden) - optimizer.minimize(loss) - exe.run(startup_program) - - if mp: - optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) - x = np.random.random(size=(2, 2)).astype('float16') - else: - x = np.random.random(size=(2, 2)).astype('float32') - out = [] - for idx in range(5): - loss_data, = exe.run(train_program, - feed={"X": x}, - fetch_list=[loss.name]) - out.append(loss_data) - return out - - def test_main(self): - if not paddle.is_compiled_with_cuda(): - return - "Test dygraph mode" - output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True) - output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False) - self.assertEqual( - np.allclose( - output1_dy.astype('float32').numpy(), - output2_dy.astype('float32').numpy(), - atol=1e-01), - True) - for idx in range(len(params1_dy)): - self.assertEqual( - np.allclose( - params1_dy[idx].astype('float32').numpy(), - params2_dy[idx].astype('float32').numpy(), - atol=1e-01), - True) - "Test static mode" - output1_st = self.static_sgd_mp(mp=True) - output2_st = self.static_sgd_mp(mp=False) - for idx in range(len(output1_st)): - self.assertEqual( - np.allclose( - output1_st[idx].astype('float32'), - output2_st[idx].astype('float32'), - atol=1e-01), - True) +# class TestSGDOp(OpTest): +# def setUp(self): +# self.op_type = "sgd" +# self.conf() +# w = np.random.random((self.h, self.w)).astype("float32") +# g = np.random.random((self.h, self.w)).astype("float32") +# lr = np.array([0.1]).astype("float32") + +# self.inputs = {'Param': w, 'Grad': g, 'LearningRate': lr} +# self.outputs = {'ParamOut': w - lr * g} + +# def conf(self): +# self.h = 102 +# self.w = 105 + +# def 
test_check_output(self): +# self.check_output() + +# class TestSGDOpCase8X(TestSGDOp): +# def conf(self): +# self.h = 10 +# self.w = 64 + +# class TestSparseSGDOp(unittest.TestCase): +# def check_with_place(self, place): +# scope = core.Scope() + +# # create and initialize Grad Variable +# height = 10 +# rows = [0, 4, 7] +# self.conf() + +# grad_selected_rows = scope.var('Grad').get_selected_rows() +# grad_selected_rows.set_height(height) +# grad_selected_rows.set_rows(rows) +# np_array = np.ones((len(rows), self.row_numel)).astype("float32") +# np_array[0, 0] = 2.0 +# np_array[2, 8] = 4.0 + +# grad_tensor = grad_selected_rows.get_tensor() +# grad_tensor.set(np_array, place) + +# # create and initialize Param Variable +# param = scope.var('Param').get_tensor() +# param_array = np.full((height, self.row_numel), 5.0).astype("float32") +# param.set(param_array, place) + +# # create and initialize LeraningRate Variable +# lr = scope.var('LearningRate').get_tensor() +# lr_array = np.full((1), 2.0).astype("float32") +# lr.set(lr_array, place) + +# # create and run sgd operator +# sgd_op = Operator( +# "sgd", +# Param='Param', +# Grad='Grad', +# ParamOut='Param', +# LearningRate='LearningRate') +# sgd_op.run(scope, place) + +# # get and compare result +# result_array = np.array(param) + +# # rows[0] = 0, 5.0 - 2.0 * 2.0 +# self.assertAlmostEqual(1.0, result_array[rows[0], 0]) +# # rows[0] = 0, 5.0 - 2.0 * 1.0 +# self.assertAlmostEqual(3.0, result_array[rows[0], 2]) +# # 5.0 - 2.0 * 0.0 +# self.assertAlmostEqual(5.0, result_array[1, 0]) +# # rows[1] = 4, 5.0 - 2.0 * 1.0 +# self.assertAlmostEqual(3.0, result_array[rows[1], 10]) +# # 5.0 - 2.0 * 0.0 +# self.assertAlmostEqual(5.0, result_array[5, 8]) +# # rows[2] = 7, 5.0 - 2.0 * 1.0 +# self.assertAlmostEqual(3.0, result_array[rows[2], 1]) +# # rows[2] = 7, 5.0 - 2.0 * 4.0 +# self.assertAlmostEqual(-3.0, result_array[rows[2], 8]) + +# def test_sparse_sgd(self): +# places = [core.CPUPlace()] +# if core.is_compiled_with_cuda(): +# places.append(core.CUDAPlace(0)) +# for place in places: +# self.check_with_place(place) + +# def conf(self): +# self.row_numel = 12 + +# class TestSparseSGDOpCase8X(TestSparseSGDOp): +# def conf(self): +# self.row_numel = 16 + +# class TestSGDOpOptimizeSelectedRows(unittest.TestCase): +# def check_with_place(self, place): +# scope = core.Scope() + +# row_width = 12 +# # create and initialize Grad Variable +# grad_height = 10 +# grad_rows = [0, 4, 7] + +# grad_selected_rows = scope.var('Grad').get_selected_rows() +# grad_selected_rows.set_height(grad_height) +# grad_selected_rows.set_rows(grad_rows) +# grad_array = np.ones((len(grad_rows), row_width)).astype("float32") +# grad_array[0, 0] = 2.0 +# grad_array[2, 8] = 4.0 + +# grad_tensor = grad_selected_rows.get_tensor() +# grad_tensor.set(grad_array, place) + +# # create and initialize Param Variable +# # create and initialize W Variable +# param_rows = [0, 1, 2, 3, 4, 5, 6, 7] + +# # init Param +# w_selected_rows = scope.var('Param').get_selected_rows() +# w_selected_rows.set_height(len(param_rows)) +# w_selected_rows.set_rows(param_rows) +# w_selected_rows.sync_index() +# w_array = np.ones((len(param_rows), row_width)).astype("float32") +# for i in range(len(param_rows)): +# w_array[i] *= i +# w_tensor = w_selected_rows.get_tensor() +# w_tensor.set(w_array, place) + +# w_before_optimize = np.array(w_tensor) + +# # create and initialize LeraningRate Variable +# lr_value = 0.1 +# lr = scope.var('LearningRate').get_tensor() +# lr_array = np.full((1), 
lr_value).astype("float32") +# lr.set(lr_array, place) + +# # optimize with Python +# w_after_optimize = np.copy(w_before_optimize) +# for index, id in enumerate(grad_rows): +# w_after_optimize[id] = w_before_optimize[ +# id] - lr_value * grad_array[index] + +# # create and run sgd operator +# sgd_op = Operator( +# "sgd", +# Param='Param', +# Grad='Grad', +# ParamOut='Param', +# LearningRate='LearningRate') +# sgd_op.run(scope, place) + +# # get and compare result +# result_array = np.array(w_tensor) +# assert (result_array == w_after_optimize).all() + +# def test_sparse_parameter_sgd(self): +# places = [core.CPUPlace()] +# # do not support GPU kernel currently +# for place in places: +# self.check_with_place(place) + +# class TestSGDOpWithLargeInput(unittest.TestCase): +# def runTest(self): +# paddle.enable_static() +# data = fluid.layers.fill_constant(shape=[1], value=128, dtype='int64') +# label = fluid.layers.fill_constant( +# shape=[1, 150], value=0.5, dtype='float32') +# emb = fluid.embedding(input=data, size=(10000000, 150), dtype='float32') +# out = fluid.layers.l2_normalize(x=emb, axis=-1) + +# cost = fluid.layers.square_error_cost(input=out, label=label) +# avg_cost = fluid.layers.mean(cost) +# sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001) +# sgd_optimizer.minimize(avg_cost) + +# place = fluid.CPUPlace() +# exe = fluid.Executor(place) +# exe.run(fluid.default_startup_program()) +# compiled_prog = fluid.compiler.CompiledProgram( +# fluid.default_main_program()) +# result = exe.run(compiled_prog, fetch_list=[avg_cost]) + +# class TestSGDV2(unittest.TestCase): +# def test_sgd_dygraph(self): +# paddle.disable_static() +# value = np.arange(26).reshape(2, 13).astype("float32") +# a = paddle.to_tensor(value) +# linear = paddle.nn.Linear(13, 5) +# # This can be any optimizer supported by dygraph. 
+# adam = paddle.optimizer.SGD(learning_rate=0.01, +# parameters=linear.parameters(), +# weight_decay=0.01) +# out = linear(a) +# out.backward() +# adam.step() +# adam.clear_gradients() + +# def test_sgd(self): +# paddle.enable_static() + +# def check_sgd_optimizer(optimizer_attr): +# init_program = paddle.static.Program() +# program = paddle.static.Program() +# block = program.global_block() +# mul_x = block.create_parameter( +# dtype="float32", +# shape=[5, 10], +# lod_level=0, +# name="mul.x", +# optimize_attr=optimizer_attr) +# mul_y = block.create_var( +# dtype="float32", shape=[10, 8], lod_level=0, name="mul.y") +# mul_out = block.create_var( +# dtype="float32", shape=[5, 8], lod_level=0, name="mul.out") +# mean_out = block.create_var( +# dtype="float32", shape=[1], lod_level=0, name="mean.out") +# block.append_op( +# type="mul", +# inputs={"X": mul_x, +# "Y": mul_y}, +# outputs={"Out": mul_out}, +# attrs={"x_num_col_dims": 1}) +# block.append_op( +# type="mean", inputs={"X": mul_out}, outputs={"Out": mean_out}) +# sgd_optimizer = paddle.optimizer.SGD(learning_rate=0.01) +# opts, _ = sgd_optimizer.minimize(mean_out, init_program) +# return opts + +# opts = check_sgd_optimizer({'learning_rate': 1.1}) +# self.assertEqual(len(opts), 2) +# self.assertEqual([op.type for op in opts], ["scale", "sgd"]) + +# opts = check_sgd_optimizer({'learning_rate': 1.0}) +# self.assertEqual(len(opts), 1) +# self.assertEqual([op.type for op in opts], ["sgd"]) + +# def test_raise_error(self): +# self.assertRaises(ValueError, paddle.optimizer.SGD, learning_rate=None) + +# def test_sgd_group_dygraph(self): +# paddle.disable_static() +# value = np.arange(26).reshape(2, 13).astype("float32") +# a = paddle.to_tensor(value) +# linear_1 = paddle.nn.Linear(13, 5) +# linear_2 = paddle.nn.Linear(5, 3) +# # This can be any optimizer supported by dygraph. 
+# adam = paddle.optimizer.SGD(learning_rate=0.01, +# parameters=[{ +# 'params': linear_1.parameters() +# }, { +# 'params': linear_2.parameters(), +# 'weight_decay': 0.001, +# 'learning_rate': 0.1 +# }], +# weight_decay=0.01) +# out = linear_1(a) +# out = linear_2(out) +# out.backward() +# adam.step() +# adam.clear_gradients() + +# class TestSGDMultiPrecision2_0(unittest.TestCase): +# def dygraph_sgd_mp(self, mp): +# paddle.disable_static() +# paddle.seed(10) +# paddle.set_device('gpu') +# input = paddle.randn((2, 2)) +# model = paddle.nn.Linear(2, 2) +# optimizer = paddle.optimizer.SGD(parameters=model.parameters(), +# multi_precision=mp) +# if mp == True: +# model = paddle.amp.decorate(models=model, level='O2') +# scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + +# for idx in range(5): +# if mp == True: +# with paddle.amp.auto_cast(level='O2'): +# output = model(input) +# loss = paddle.mean(output) +# scaled = scaler.scale(loss) +# scaled.backward() +# scaler.minimize(optimizer, scaled) +# optimizer.clear_grad() +# else: +# output = model(input) +# loss = paddle.mean(output) +# optimizer.step() +# optimizer.clear_grad() + +# return output, model.parameters() + +# def static_sgd_mp(self, mp): +# paddle.enable_static() +# paddle.seed(10) +# np.random.seed(10) +# exe = paddle.static.Executor('gpu') +# train_program = paddle.static.Program() +# startup_program = paddle.static.Program() +# optimizer = paddle.optimizer.SGD(multi_precision=mp) + +# if mp: +# optimizer = paddle.static.amp.decorate( +# optimizer, +# init_loss_scaling=128.0, +# use_dynamic_loss_scaling=True, +# use_pure_fp16=True, +# use_fp16_guard=False) +# with paddle.static.program_guard(train_program, startup_program): +# if mp: +# data = paddle.static.data( +# shape=[2, 2], name='X', dtype='float16') +# else: +# data = paddle.static.data( +# shape=[2, 2], name='X', dtype='float32') +# hidden = paddle.static.nn.fc(x=data, size=10) +# loss = paddle.fluid.layers.mean(hidden) +# optimizer.minimize(loss) +# exe.run(startup_program) + +# if mp: +# optimizer.amp_init(place='gpu', scope=paddle.static.global_scope()) +# x = np.random.random(size=(2, 2)).astype('float16') +# else: +# x = np.random.random(size=(2, 2)).astype('float32') +# out = [] +# for idx in range(5): +# loss_data, = exe.run(train_program, +# feed={"X": x}, +# fetch_list=[loss.name]) +# out.append(loss_data) +# return out + +# def test_main(self): +# if not paddle.is_compiled_with_cuda(): +# return +# "Test dygraph mode" +# output1_dy, params1_dy = self.dygraph_sgd_mp(mp=True) +# output2_dy, params2_dy = self.dygraph_sgd_mp(mp=False) +# self.assertEqual( +# np.allclose( +# output1_dy.astype('float32').numpy(), +# output2_dy.astype('float32').numpy(), +# atol=1e-01), +# True) +# for idx in range(len(params1_dy)): +# self.assertEqual( +# np.allclose( +# params1_dy[idx].astype('float32').numpy(), +# params2_dy[idx].astype('float32').numpy(), +# atol=1e-01), +# True) +# "Test static mode" +# output1_st = self.static_sgd_mp(mp=True) +# output2_st = self.static_sgd_mp(mp=False) +# for idx in range(len(output1_st)): +# self.assertEqual( +# np.allclose( +# output1_st[idx].astype('float32'), +# output2_st[idx].astype('float32'), +# atol=1e-01), +# True) class TestSGDMultiPrecision1_0(unittest.TestCase):