提交 67308822 编写于 作者: M minqiyang

Add selected_rows merge for clip_by_norm op

test=develop
上级 2f5a7cc4
...@@ -267,6 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND) ...@@ -267,6 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
else() else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif() endif()
op_library(clip_by_norm_op DEPS selected_rows_functor)
op_library(sum_op DEPS selected_rows_functor) op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor)
op_library(print_op DEPS lod_tensor) op_library(print_op DEPS lod_tensor)
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
namespace paddle { namespace paddle {
...@@ -31,10 +32,31 @@ class ClipByNormKernel : public framework::OpKernel<T> { ...@@ -31,10 +32,31 @@ class ClipByNormKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto max_norm = context.Attr<T>("max_norm"); auto max_norm = context.Attr<T>("max_norm");
auto* input = context.Input<Tensor>("X"); auto in_var = context.InputVar("X");
auto* output = context.Output<Tensor>("Out"); auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
const Tensor* input = nullptr;
if (in_var->IsType<framework::LoDTensor>()) {
input = context.Input<Tensor>("X");
} else if (in_var->IsType<framework::SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X");
// merge ids in selected rows first
math::scatter::MergeAdd<DeviceContext, T> merge_func;
auto* merged_input = const_cast<framework::Scope&>(context.scope())
.Var()
->GetMutable<framework::SelectedRows>();
merge_func(context.template device_context<DeviceContext>(), *x,
merged_input);
input = &(merged_input->value());
} else {
PADDLE_THROW("Unexpected branch, input variable type is %s",
in_var->Type().name());
}
PADDLE_ENFORCE_NOT_NULL(input);
auto x = EigenVector<T>::Flatten(*input); auto x = EigenVector<T>::Flatten(*input);
auto out = EigenVector<T>::Flatten(*output); auto out = EigenVector<T>::Flatten(*output);
auto x_norm = x.square().sum().sqrt(); auto x_norm = x.square().sum().sqrt();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册