提交 d323831a 编写于 作者: Q qiaolongfei

add GetAttr to InferShapeContext

上级 848c317a
......@@ -233,6 +233,11 @@ class InferShapeContext {
const Scope& scope() const { return scope_; }
template <typename T>
inline const T& GetAttr(const std::string& name) const {
return op_.GetAttr<T>(name);
}
size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size();
}
......
......@@ -19,13 +19,12 @@ template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op().GetAttr<float>("mean");
float std = context.op().GetAttr<float>("std");
float mean = context.GetAttr<float>("mean");
float std = context.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
......
......@@ -42,14 +42,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
}
T mean = static_cast<T>(context.op().GetAttr<float>("mean"));
T std = static_cast<T>(context.op().GetAttr<float>("std"));
T mean = static_cast<T>(context.GetAttr<float>("mean"));
T std = static_cast<T>(context.GetAttr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
......@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel {
auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.op().GetAttr<AttrType>("scale"));
auto scale = static_cast<T>(context.GetAttr<AttrType>("scale"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
......@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel {
auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.op().GetAttr<float>("learning_rate");
float lr = ctx.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace());
......
......@@ -26,16 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
std::minstd_rand engine;
if (seed == 0) {
seed = std::random_device()();
}
engine.seed(seed);
std::uniform_real_distribution<T> dist(
static_cast<T>(context.op().GetAttr<float>("min")),
static_cast<T>(context.op().GetAttr<float>("max")));
static_cast<T>(context.GetAttr<float>("min")),
static_cast<T>(context.GetAttr<float>("max")));
ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine);
......
......@@ -45,14 +45,13 @@ class GPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed =
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
if (seed == 0) {
std::random_device rd;
seed = rd();
}
T min = static_cast<T>(context.op().GetAttr<float>("min"));
T max = static_cast<T>(context.op().GetAttr<float>("max"));
T min = static_cast<T>(context.GetAttr<float>("min"));
T max = static_cast<T>(context.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册