提交 972ae6e9 编写于 作者: Y Yancey1989

random selected rows value

上级 31464f34
...@@ -24,7 +24,15 @@ template <typename T> ...@@ -24,7 +24,15 @@ template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel<T> { class CPUUniformRandomKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* tensor = ctx.Output<framework::Tensor>("Out"); framework::Tensor* tensor(nullptr);
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = ctx.Output<framework::LoDTensor>("Out");
} else if (out_var->IsType<framework::SelectedRows>()) {
tensor = ctx.Output<framework::SelectedRows>("Out")->mutable_value();
} else {
PADDLE_THROW("Only support LoDTensor and SelectedRows.");
}
T* data = tensor->mutable_data<T>(ctx.GetPlace()); T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed")); unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
...@@ -36,6 +44,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> { ...@@ -36,6 +44,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
static_cast<T>(ctx.Attr<float>("min")), static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max"))); static_cast<T>(ctx.Attr<float>("max")));
int64_t size = tensor->numel(); int64_t size = tensor->numel();
VLOG(3) << "size = " << size;
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
data[i] = dist(engine); data[i] = dist(engine);
} }
...@@ -55,6 +64,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -55,6 +64,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
"uniform_random's min must less then max"); "uniform_random's min must less then max");
auto& shape = ctx->Attrs().Get<std::vector<int>>("shape"); auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
std::vector<int64_t> temp; std::vector<int64_t> temp;
VLOG(3) << "shape.size() = " << shape.size();
temp.reserve(shape.size()); temp.reserve(shape.size());
for (auto dim : shape) { for (auto dim : shape) {
temp.push_back(static_cast<int64_t>(dim)); temp.push_back(static_cast<int64_t>(dim));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册