提交 bf7a2789 编写于 作者: N nhzlx

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into add_trt_pad_op

test=develop
......@@ -268,6 +268,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif()
op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows)
op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor)
op_library(print_op DEPS lod_tensor)
......
......@@ -16,12 +16,15 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
......@@ -31,9 +34,40 @@ class ClipByNormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto max_norm = context.Attr<T>("max_norm");
auto* input = context.Input<Tensor>("X");
auto* output = context.Output<Tensor>("Out");
auto in_var = context.InputVar("X");
Tensor* output = nullptr;
const Tensor* input = nullptr;
if (in_var->IsType<framework::LoDTensor>()) {
input = context.Input<Tensor>("X");
output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace());
} else if (in_var->IsType<SelectedRows>()) {
auto* x = context.Input<SelectedRows>("X");
// merge ids in selected rows first
math::scatter::MergeAdd<DeviceContext, T> merge_func;
SelectedRows* merged_input =
const_cast<framework::Scope&>(context.scope())
.Var()
->GetMutable<SelectedRows>();
merge_func(context.template device_context<DeviceContext>(), *x,
merged_input);
input = &(merged_input->value());
SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
output_selected_rows->set_rows(merged_input->rows());
output_selected_rows->set_height(merged_input->height());
output = output_selected_rows->mutable_value();
output->Resize(merged_input->value().dims());
output->mutable_data<T>(context.GetPlace());
} else {
PADDLE_THROW("Unexpected branch, input variable type is %s",
in_var->Type().name());
}
PADDLE_ENFORCE_NOT_NULL(input);
auto x = EigenVector<T>::Flatten(*input);
auto out = EigenVector<T>::Flatten(*output);
......
......@@ -18,6 +18,9 @@ import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid as fluid
import paddle.fluid.core as core
class TestClipByNormOp(OpTest):
def setUp(self):
......@@ -62,5 +65,59 @@ class TestCase3(TestClipByNormOp):
self.max_norm = 1.0
class TestClipByNormOpWithSelectedRows(OpTest):
def check_with_place(self, place):
self.config_test_case()
scope = core.Scope()
# set input
x_selected_rows = scope.var('X').get_selected_rows()
x_selected_rows.set_rows(self.grad_rows)
x_tensor = x_selected_rows.get_tensor()
x_np = np.random.random(self.grad_shape).astype("float32")
x_np[np.abs(x_np) < self.max_relative_error] = 0.5
x_tensor.set(x_np, place)
# set output
out_selected_rows = scope.var('Out').get_selected_rows()
# run clip_by_norm_op
clip_by_norm_op = fluid.op.Operator(
"clip_by_norm", max_norm=self.max_norm, X='X', Out='Out')
clip_by_norm_op.run(scope, place)
# check output
self.assertEqual(out_selected_rows.rows(), self.grad_clipped_rows)
out_tensor = out_selected_rows.get_tensor()
y_np = np.zeros(self.grad_clipped_shape)
y_np[0] = np.sum(x_np[0:2])
y_np[1] = x_np[2]
y_np[2] = x_np[3]
norm = np.sqrt(np.sum(np.square(y_np)))
if norm > self.max_norm:
output = self.max_norm * y_np / norm
else:
output = y_np
self.assertTrue(
np.allclose(
np.array(out_tensor), output, atol=1e-5, equal_nan=False))
def test_clip_by_norm_with_selected_ros(self):
places = [core.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(core.CUDAPlace(0))
for place in places:
self.check_with_place(place)
def config_test_case(self):
self.max_norm = 1.0
self.max_relative_error = 0.006
self.grad_shape = (4, 1)
self.grad_clipped_shape = (3, 1)
self.grad_rows = [0, 0, 1, 2]
self.grad_clipped_rows = [0, 1, 2]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册