diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e10fc422fac5ccccc89806af0e4901866eecc7b5..cafd7b11aedbb20e24cc00b5d9c648a39a408540 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -267,7 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() -op_library(clip_by_norm_op DEPS selected_rows_functor) +op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(print_op DEPS lod_tensor) diff --git a/paddle/fluid/operators/clip_by_norm_op.h b/paddle/fluid/operators/clip_by_norm_op.h index 7144524a4c62d29df255afb197c6f5e7c39c28a7..9f99c8a3f953c4e5c7231a85e8ed004e8f24ee56 100644 --- a/paddle/fluid/operators/clip_by_norm_op.h +++ b/paddle/fluid/operators/clip_by_norm_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" @@ -23,6 +24,7 @@ namespace paddle { namespace operators { using Tensor = framework::Tensor; +using SelectedRows = framework::SelectedRows; template using EigenVector = framework::EigenVector; @@ -41,22 +43,24 @@ class ClipByNormKernel : public framework::OpKernel { output = context.Output("Out"); output->mutable_data(context.GetPlace()); - } else if (in_var->IsType()) { - auto* x = context.Input("X"); + } else if (in_var->IsType()) { + auto* x = context.Input("X"); // merge ids in selected rows first math::scatter::MergeAdd merge_func; - auto* merged_input = const_cast(context.scope()) - .Var() - ->GetMutable(); + SelectedRows* merged_input = + const_cast(context.scope()) + .Var() + ->GetMutable(); merge_func(context.template device_context(), *x, merged_input); input = &(merged_input->value()); - auto* output_selected_rows = context.Output("Out"); - output_selected_rows->set_rows(merged_input.rows()); - output = output_selected_rows->mutable_data(); - output->Resize(framework::make_ddim(merged_input.value().dims())); + SelectedRows* output_selected_rows = context.Output("Out"); + output_selected_rows->set_rows(merged_input->rows()); + output_selected_rows->set_height(merged_input->height()); + output = output_selected_rows->mutable_value(); + output->Resize(merged_input->value().dims()); } else { PADDLE_THROW("Unexpected branch, input variable type is %s", in_var->Type().name()); diff --git a/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py b/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py index 6103c3aafc0bb154194314830c5c8c5d89460cfe..6556c0875e95bd2b7994dec7cc023f4cd56869ef 100644 --- a/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py +++ b/python/paddle/fluid/tests/unittests/test_clip_by_norm_op.py @@ -18,6 +18,8 @@ import unittest import numpy as np from op_test import OpTest +import paddle.fluid.core as core + class TestClipByNormOp(OpTest): def setUp(self): @@ -62,5 +64,41 @@ class TestCase3(TestClipByNormOp): self.max_norm = 1.0 +class TestClipByNormOpWithSelectedRows(OpTest): + def setUp(self): + self.initTestCase() + + self.max_relative_error = 0.006 + + scope = core.Scope() + x_selected_rows = scope.var('X').get_selected_rows() + x_selected_rows.set_rows([1, 1, 2, 0]) + x_tensor = x_selected_rows.get_tensor() + x_tensor = np.random.random((4, 1)).astype("float32") + x_tensor[np.abs(x_tensor) < self.max_relative_error] = 0.5 + + self.op_type = "clip_by_norm" + self.inputs = {'X': x_selected_rows, } + self.attrs = {} + self.attrs['max_norm'] = self.max_norm + y_tensor = np.zeros((3, 1)) + y_tensor[0::1] = np.sum(x_tensor[0::1], x_tensor[1::1]) + y_tensor[1::1] = x_tensor[2::1] + y_tensor[2::1] = x_tensor[3::1] + norm = np.sqrt(np.sum(np.square(y_tensor))) + if norm > self.max_norm: + output = self.max_norm * y_tensor / norm + else: + output = y_tensor + self.outputs = {'Out': output} + + def test_check_output(self): + self.check_output() + + def initTestCase(self): + self.shape = (100, ) + self.max_norm = 1.0 + + if __name__ == '__main__': unittest.main()