提交 d323831a 编写于 作者: Q qiaolongfei

add GetAttr to InferShapeContext

上级 848c317a
...@@ -233,6 +233,11 @@ class InferShapeContext { ...@@ -233,6 +233,11 @@ class InferShapeContext {
const Scope& scope() const { return scope_; } const Scope& scope() const { return scope_; }
template <typename T>
inline const T& GetAttr(const std::string& name) const {
return op_.GetAttr<T>(name);
}
size_t InputSize(const std::string& name) const { size_t InputSize(const std::string& name) const {
return op_.Inputs(name).size(); return op_.Inputs(name).size();
} }
......
...@@ -19,13 +19,12 @@ template <typename T> ...@@ -19,13 +19,12 @@ template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.op().GetAttr<float>("mean"); float mean = context.GetAttr<float>("mean");
float std = context.op().GetAttr<float>("std"); float std = context.GetAttr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
......
...@@ -42,14 +42,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel { ...@@ -42,14 +42,13 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
T mean = static_cast<T>(context.op().GetAttr<float>("mean")); T mean = static_cast<T>(context.GetAttr<float>("mean"));
T std = static_cast<T>(context.op().GetAttr<float>("std")); T std = static_cast<T>(context.GetAttr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims()); ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N, thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
...@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel { ...@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place()); tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.op().GetAttr<AttrType>("scale")); auto scale = static_cast<T>(context.GetAttr<AttrType>("scale"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor); auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
...@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel { ...@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out"); auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.op().GetAttr<float>("learning_rate"); float lr = ctx.GetAttr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
...@@ -26,16 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel { ...@@ -26,16 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(context.op().GetAttr<float>("min")), static_cast<T>(context.GetAttr<float>("min")),
static_cast<T>(context.op().GetAttr<float>("max"))); static_cast<T>(context.GetAttr<float>("max")));
ssize_t size = framework::product(tensor->dims()); ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) { for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine); data[i] = dist(engine);
......
...@@ -45,14 +45,13 @@ class GPUUniformRandomKernel : public framework::OpKernel { ...@@ -45,14 +45,13 @@ class GPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed"));
static_cast<unsigned int>(context.op().GetAttr<int>("seed"));
if (seed == 0) { if (seed == 0) {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
T min = static_cast<T>(context.op().GetAttr<float>("min")); T min = static_cast<T>(context.GetAttr<float>("min"));
T max = static_cast<T>(context.op().GetAttr<float>("max")); T max = static_cast<T>(context.GetAttr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims()); ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N, thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册