Commit 86e2e686 authored by Qiao Longfei

fix bug

Parent 333fd152
@@ -242,7 +242,7 @@ TEST(selected_rows_functor, gpu_add_to) {
   EXPECT_EQ(tensor1_cpu_data[9 * row_numel + 6], 5.0);
 }
 
-TEST(selected_rows_functor, cpu_merge_add) {
+TEST(selected_rows_functor, gpu_merge_add) {
   paddle::platform::CUDAPlace gpu_place(0);
   paddle::platform::CPUPlace cpu_place;
   paddle::platform::CUDADeviceContext& ctx =
@@ -250,7 +250,7 @@ TEST(selected_rows_functor, cpu_merge_add) {
           paddle::platform::DeviceContextPool::Instance().Get(gpu_place));
   paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
                                        float>
-      functor;
+      set_const;
 
   int64_t height = 10;
   int64_t row_numel = 8;
@@ -262,7 +262,7 @@ TEST(selected_rows_functor, cpu_merge_add) {
   in1_value->mutable_data<float>(
       paddle::framework::make_ddim(
           {static_cast<int64_t>(rows1.size()), row_numel}),
-      cpu_place);
+      gpu_place);
   set_const(ctx, in1_value, 1.0);
 
   std::vector<int64_t> rows2{2, 5, 3, 5, 3};
@@ -272,7 +272,7 @@ TEST(selected_rows_functor, cpu_merge_add) {
   in2_value->mutable_data<float>(
       paddle::framework::make_ddim(
           {static_cast<int64_t>(rows2.size()), row_numel}),
-      cpu_place);
+      gpu_place);
   set_const(ctx, in2_value, 1.0);
 
   std::unique_ptr<paddle::framework::SelectedRows> output{
@@ -288,7 +288,7 @@ TEST(selected_rows_functor, cpu_merge_add) {
   merge_add_functor(ctx, inputs, output.get());
 
   paddle::framework::Tensor output_cpu;
-  paddle::framework::TensorCopy(*output, cpu_place, ctx, &output_cpu);
+  paddle::framework::TensorCopy(output.value(), cpu_place, ctx, &output_cpu);
   ctx.Wait();
 
   EXPECT_EQ(output->height(), height);
...
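For context, the change above renames the copied-over CPU test to gpu_merge_add, moves its value-tensor allocations from cpu_place to gpu_place, and copies the result's value tensor back to the host before checking it. Below is a minimal, self-contained sketch of that device-to-host verification pattern; it is assembled only from calls visible in this diff, and the include paths, test name, and tensor shape are assumptions, not part of the commit.

```cpp
// Sketch only: header paths assumed from the fluid-era layout of this test file.
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"

TEST(selected_rows_functor_sketch, gpu_fill_and_copy_back) {
  paddle::platform::CUDAPlace gpu_place(0);
  paddle::platform::CPUPlace cpu_place;
  auto& ctx = *reinterpret_cast<paddle::platform::CUDADeviceContext*>(
      paddle::platform::DeviceContextPool::Instance().Get(gpu_place));

  // Allocate on the GPU place; allocating on cpu_place while running with a
  // CUDA device context is the mistake this commit fixes in gpu_merge_add.
  paddle::framework::Tensor gpu_tensor;
  gpu_tensor.mutable_data<float>(paddle::framework::make_ddim({10, 8}),
                                 gpu_place);

  paddle::operators::math::SetConstant<paddle::platform::CUDADeviceContext,
                                       float>
      set_const;
  set_const(ctx, &gpu_tensor, 1.0);

  // Copy the device tensor back to a host tensor, then synchronize before
  // touching the host memory.
  paddle::framework::Tensor cpu_tensor;
  paddle::framework::TensorCopy(gpu_tensor, cpu_place, ctx, &cpu_tensor);
  ctx.Wait();

  const float* data = cpu_tensor.data<float>();
  EXPECT_EQ(data[0], 1.0);
}
```

The ctx.Wait() before reading the host buffer matters because TensorCopy issued with a CUDA device context runs on that context's stream, so the copy must complete before the EXPECT_EQ checks the data.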