未验证 提交 9c77b65c 编写于 作者: C chengduo 提交者: GitHub

Fix layers.uniform_random (#13823)

* fix layers.uniform_random

* fix uniform_random
test=develop

* remove var type set
test=develop

* fix similar error
test=develop
上级 5f2e8378
......@@ -23,14 +23,14 @@ namespace operators {
template <typename T>
class CPUUniformRandomKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* tensor = nullptr;
void Compute(const framework::ExecutionContext &ctx) const override {
framework::Tensor *tensor = nullptr;
auto out_var = ctx.OutputVar("Out");
if (out_var->IsType<framework::LoDTensor>()) {
tensor = out_var->GetMutable<framework::LoDTensor>();
} else if (out_var->IsType<framework::SelectedRows>()) {
auto shape = ctx.Attr<std::vector<int>>("shape");
auto* selected_rows = out_var->GetMutable<framework::SelectedRows>();
auto *selected_rows = out_var->GetMutable<framework::SelectedRows>();
tensor = selected_rows->mutable_value();
tensor->Resize(framework::make_ddim(shape));
selected_rows->mutable_rows()->reserve(shape[0]);
......@@ -39,7 +39,7 @@ class CPUUniformRandomKernel : public framework::OpKernel<T> {
"uniform_random_op's output only"
"supports SelectedRows and LoDTensor");
}
T* data = tensor->mutable_data<T>(ctx.GetPlace());
T *data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
......@@ -60,14 +60,14 @@ class UniformRandomOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniformRandomOp should not be null.");
PADDLE_ENFORCE(
ctx->Attrs().Get<float>("min") < ctx->Attrs().Get<float>("max"),
"uniform_random's min must less then max");
auto& shape = ctx->Attrs().Get<std::vector<int>>("shape");
auto &shape = ctx->Attrs().Get<std::vector<int>>("shape");
std::vector<int64_t> temp;
temp.reserve(shape.size());
for (auto dim : shape) {
......@@ -78,7 +78,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
......@@ -112,17 +112,17 @@ uniform distribution. The random result is in set [min, max].
class UniformRandomOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto out_var_name = op_desc.Output("Out").front();
if (block->FindRecursiveOrCreateVar(out_var_name).GetType() ==
framework::proto::VarType::SELECTED_ROWS) {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::SELECTED_ROWS);
} else {
block->FindRecursiveOrCreateVar(out_var_name)
.SetType(framework::proto::VarType::LOD_TENSOR);
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(op_desc.GetAttr("dtype")));
auto out_var = block->FindRecursiveOrCreateVar(out_var_name);
if (out_var.GetType() != framework::proto::VarType::SELECTED_ROWS) {
out_var.SetType(framework::proto::VarType::LOD_TENSOR);
}
out_var.SetDataType(var_data_type);
}
};
......
......@@ -14,6 +14,8 @@
from __future__ import print_function
from .layer_function_generator import generate_layer_fn, generate_layer_fn_noattr
from .. import core
from ..framework import convert_np_dtype_to_dtype_
__activations_noattr__ = [
'sigmoid',
......@@ -58,8 +60,11 @@ _uniform_random_ = generate_layer_fn('uniform_random')
def uniform_random(shape, dtype=None, min=None, max=None, seed=None):
locals_var = locals().keys()
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
......@@ -78,8 +83,9 @@ _hard_shrink_ = generate_layer_fn('hard_shrink')
def hard_shrink(x, threshold=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
......@@ -99,12 +105,12 @@ _cum_sum_ = generate_layer_fn('cumsum')
def cumsum(x, axis=None, exclusive=None, reverse=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
return _cum_sum_(**kwargs)
......@@ -121,8 +127,9 @@ _thresholded_relu_ = generate_layer_fn('thresholded_relu')
def thresholded_relu(x, threshold=None):
locals_var = locals().keys()
kwargs = dict()
for name in locals():
for name in locals_var:
val = locals()[name]
if val is not None:
kwargs[name] = val
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册