未验证 提交 e0b5691e 编写于 作者: G gongweibao 提交者: GitHub

Add drop_out_op unit test (#9364)

上级 fe142d92
......@@ -264,3 +264,4 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor paddle_memor
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(nccl_op_test SRCS nccl_op_test.cu.cc DEPS nccl_op gpu_info device_context)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
......@@ -55,9 +55,6 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
y->mutable_data<T>(context.GetPlace());
float dropout_prob = context.Attr<float>("dropout_prob");
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
auto& place = *context.template device_context<Place>().eigen_device();
if (!context.Attr<bool>("is_test")) {
auto* mask = context.Output<Tensor>("Mask");
......@@ -76,6 +73,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
T><<<grid, threads, 0, context.cuda_device_context().stream()>>>(
size, seed, dropout_prob, x_data, mask_data, y_data);
} else {
auto X = EigenMatrix<T>::Reshape(*x, 1);
auto Y = EigenMatrix<T>::Reshape(*y, 1);
Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
}
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <unistd.h>
#include <string>
#include <thread>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(dropout);
void Compare(f::Scope& scope, p::DeviceContext& ctx) {
// init
auto var = scope.Var("X");
auto tensor = var->GetMutable<f::LoDTensor>();
tensor->Resize({10, 10});
std::vector<float> init;
for (int64_t i = 0; i < 10 * 10; ++i) {
init.push_back(1.0);
}
TensorFromVector(init, ctx, tensor);
auto place = ctx.GetPlace();
auto out_var = scope.Var("Out");
auto out_tensor = out_var->GetMutable<f::LoDTensor>();
out_tensor->Resize({10, 10});
out_tensor->mutable_data<float>(place); // allocate
auto mask_var = scope.Var("Mask");
auto mask_tensor = mask_var->GetMutable<f::LoDTensor>();
mask_tensor->Resize({10, 10});
mask_tensor->mutable_data<float>(place); // allocate
// run
f::AttributeMap attrs;
float dropout_prob = 0.5;
attrs.insert({"fix_seed", 1});
attrs.insert({"seed", 3});
attrs.insert({"dropout_prob", dropout_prob});
auto dropout_op = f::OpRegistry::CreateOp(
"dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
dropout_op->Run(scope, place);
std::vector<float> out_vec;
TensorToVector(*out_tensor, ctx, &out_vec);
std::vector<float> std_out = {
0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1,
1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1,
1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0,
1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1};
EXPECT_EQ(out_vec.size(), std_out.size());
for (uint32_t i = 0; i < out_vec.size(); i++) {
EXPECT_EQ(out_vec[i], std_out[i]);
}
}
TEST(Dropout, CPUDense) {
f::Scope scope;
p::CPUPlace place;
p::CPUDeviceContext ctx(place);
Compare(scope, ctx);
}
TEST(Dropout, GPUDense) {
f::Scope scope;
p::CUDAPlace place;
p::CUDADeviceContext ctx(place);
Compare(scope, ctx);
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册