提交 bcd8c2cc 编写于 作者: M minqiyang

Add unit test

上级 f20fc955
...@@ -267,7 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND) ...@@ -267,7 +267,7 @@ if (WITH_GPU AND TENSORRT_FOUND)
else() else()
set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op) set(DEPS_OPS ${DEPS_OPS} tensorrt_engine_op)
endif() endif()
op_library(clip_by_norm_op DEPS selected_rows_functor) op_library(clip_by_norm_op DEPS selected_rows_functor selected_rows)
op_library(sum_op DEPS selected_rows_functor) op_library(sum_op DEPS selected_rows_functor)
op_library(sgd_op DEPS selected_rows_functor) op_library(sgd_op DEPS selected_rows_functor)
op_library(print_op DEPS lod_tensor) op_library(print_op DEPS lod_tensor)
......
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h" #include "paddle/fluid/operators/math/selected_rows_functor.h"
#include "paddle/fluid/platform/transform.h" #include "paddle/fluid/platform/transform.h"
...@@ -23,6 +24,7 @@ namespace paddle { ...@@ -23,6 +24,7 @@ namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
using SelectedRows = framework::SelectedRows;
template <typename T, int MajorType = Eigen::RowMajor, template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>; using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
...@@ -41,22 +43,24 @@ class ClipByNormKernel : public framework::OpKernel<T> { ...@@ -41,22 +43,24 @@ class ClipByNormKernel : public framework::OpKernel<T> {
output = context.Output<Tensor>("Out"); output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
} else if (in_var->IsType<framework::SelectedRows>()) { } else if (in_var->IsType<SelectedRows>()) {
auto* x = context.Input<framework::SelectedRows>("X"); auto* x = context.Input<SelectedRows>("X");
// merge ids in selected rows first // merge ids in selected rows first
math::scatter::MergeAdd<DeviceContext, T> merge_func; math::scatter::MergeAdd<DeviceContext, T> merge_func;
auto* merged_input = const_cast<framework::Scope&>(context.scope()) SelectedRows* merged_input =
.Var() const_cast<framework::Scope&>(context.scope())
->GetMutable<framework::SelectedRows>(); .Var()
->GetMutable<SelectedRows>();
merge_func(context.template device_context<DeviceContext>(), *x, merge_func(context.template device_context<DeviceContext>(), *x,
merged_input); merged_input);
input = &(merged_input->value()); input = &(merged_input->value());
auto* output_selected_rows = context.Output<SelectedRows>("Out"); SelectedRows* output_selected_rows = context.Output<SelectedRows>("Out");
output_selected_rows->set_rows(merged_input.rows()); output_selected_rows->set_rows(merged_input->rows());
output = output_selected_rows->mutable_data(); output_selected_rows->set_height(merged_input->height());
output->Resize(framework::make_ddim(merged_input.value().dims())); output = output_selected_rows->mutable_value();
output->Resize(merged_input->value().dims());
} else { } else {
PADDLE_THROW("Unexpected branch, input variable type is %s", PADDLE_THROW("Unexpected branch, input variable type is %s",
in_var->Type().name()); in_var->Type().name());
......
...@@ -18,6 +18,8 @@ import unittest ...@@ -18,6 +18,8 @@ import unittest
import numpy as np import numpy as np
from op_test import OpTest from op_test import OpTest
import paddle.fluid.core as core
class TestClipByNormOp(OpTest): class TestClipByNormOp(OpTest):
def setUp(self): def setUp(self):
...@@ -62,5 +64,41 @@ class TestCase3(TestClipByNormOp): ...@@ -62,5 +64,41 @@ class TestCase3(TestClipByNormOp):
self.max_norm = 1.0 self.max_norm = 1.0
class TestClipByNormOpWithSelectedRows(OpTest):
def setUp(self):
self.initTestCase()
self.max_relative_error = 0.006
scope = core.Scope()
x_selected_rows = scope.var('X').get_selected_rows()
x_selected_rows.set_rows([1, 1, 2, 0])
x_tensor = x_selected_rows.get_tensor()
x_tensor = np.random.random((4, 1)).astype("float32")
x_tensor[np.abs(x_tensor) < self.max_relative_error] = 0.5
self.op_type = "clip_by_norm"
self.inputs = {'X': x_selected_rows, }
self.attrs = {}
self.attrs['max_norm'] = self.max_norm
y_tensor = np.zeros((3, 1))
y_tensor[0::1] = np.sum(x_tensor[0::1], x_tensor[1::1])
y_tensor[1::1] = x_tensor[2::1]
y_tensor[2::1] = x_tensor[3::1]
norm = np.sqrt(np.sum(np.square(y_tensor)))
if norm > self.max_norm:
output = self.max_norm * y_tensor / norm
else:
output = y_tensor
self.outputs = {'Out': output}
def test_check_output(self):
self.check_output()
def initTestCase(self):
self.shape = (100, )
self.max_norm = 1.0
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册