diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index b61bca8c3d1e2a63faed43b437459f4e9318e46e..e10fc422fac5ccccc89806af0e4901866eecc7b5 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -229,7 +229,7 @@ if(WITH_DISTRIBUTE) op_library(${dist_op} DEPS ${DISTRIBUTE_DEPS}) set_source_files_properties(${dist_op}.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) endforeach() - + #set_source_files_properties(send_recv_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) #cc_test(test_send_recv SRCS send_recv_op_test.cc DEPS prefetch_op send_op # listen_and_serv_op sum_op executor SERIAL) @@ -267,6 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND) else() set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) endif() +op_library(clip_by_norm_op DEPS selected_rows_functor) op_library(sum_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor) op_library(print_op DEPS lod_tensor) diff --git a/paddle/fluid/operators/clip_by_norm_op.h b/paddle/fluid/operators/clip_by_norm_op.h index 5af0eb0b2ada66d5ae7d521d80e213f9e61f826f..83461159134f3a77613eec70f735a2e1402131ce 100644 --- a/paddle/fluid/operators/clip_by_norm_op.h +++ b/paddle/fluid/operators/clip_by_norm_op.h @@ -16,6 +16,7 @@ limitations under the License. */ #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/platform/transform.h" namespace paddle { @@ -31,10 +32,31 @@ class ClipByNormKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto max_norm = context.Attr("max_norm"); - auto* input = context.Input("X"); + auto in_var = context.InputVar("X"); auto* output = context.Output("Out"); output->mutable_data(context.GetPlace()); + const Tensor* input = nullptr; + if (in_var->IsType()) { + input = context.Input("X"); + } else if (in_var->IsType()) { + auto* x = context.Input("X"); + + // merge ids in selected rows first + math::scatter::MergeAdd merge_func; + auto* merged_input = const_cast(context.scope()) + .Var() + ->GetMutable(); + merge_func(context.template device_context(), *x, + merged_input); + input = &(merged_input->value()); + } else { + PADDLE_THROW("Unexpected branch, input variable type is %s", + in_var->Type().name()); + } + + PADDLE_ENFORCE_NOT_NULL(input); + auto x = EigenVector::Flatten(*input); auto out = EigenVector::Flatten(*output); auto x_norm = x.square().sum().sqrt();