提交 9de6a4b3 编写于 作者: Y Yu Yang

Change `Op::GetAttr` to `Op::Attr`

Fix #3902
上级 ba43904a
...@@ -80,7 +80,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -80,7 +80,7 @@ TEST(OpRegistry, CreateOp) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
float scale_get = op->GetAttr<float>("scale"); float scale_get = op->Attr<float>("scale");
ASSERT_EQ(scale_get, scale); ASSERT_EQ(scale_get, scale);
} }
...@@ -121,7 +121,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -121,7 +121,7 @@ TEST(OpRegistry, DefaultValue) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
ASSERT_EQ(op->GetAttr<float>("scale"), 1.0); ASSERT_EQ(op->Attr<float>("scale"), 1.0);
} }
TEST(OpRegistry, CustomChecker) { TEST(OpRegistry, CustomChecker) {
...@@ -172,6 +172,6 @@ TEST(OpRegistry, CustomChecker) { ...@@ -172,6 +172,6 @@ TEST(OpRegistry, CustomChecker) {
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope; paddle::framework::Scope scope;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
int test_attr = op->GetAttr<int>("test_attr"); int test_attr = op->Attr<int>("test_attr");
ASSERT_EQ(test_attr, 4); ASSERT_EQ(test_attr, 4);
} }
\ No newline at end of file
...@@ -69,7 +69,7 @@ class OperatorBase { ...@@ -69,7 +69,7 @@ class OperatorBase {
virtual ~OperatorBase() {} virtual ~OperatorBase() {}
template <typename T> template <typename T>
inline const T& GetAttr(const std::string& name) const { inline const T& Attr(const std::string& name) const {
PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap", PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
name); name);
return boost::get<T>(attrs_.at(name)); return boost::get<T>(attrs_.at(name));
...@@ -238,8 +238,8 @@ class InferShapeContext { ...@@ -238,8 +238,8 @@ class InferShapeContext {
const Scope& scope() const { return scope_; } const Scope& scope() const { return scope_; }
template <typename T> template <typename T>
inline const T& GetAttr(const std::string& name) const { inline const T& Attr(const std::string& name) const {
return op_.GetAttr<T>(name); return op_.Attr<T>(name);
} }
size_t InputSize(const std::string& name) const { size_t InputSize(const std::string& name) const {
......
...@@ -19,12 +19,12 @@ template <typename T> ...@@ -19,12 +19,12 @@ template <typename T>
class CPUGaussianRandomKernel : public framework::OpKernel { class CPUGaussianRandomKernel : public framework::OpKernel {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
float mean = context.GetAttr<float>("mean"); float mean = context.Attr<float>("mean");
float std = context.GetAttr<float>("std"); float std = context.Attr<float>("std");
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed")); unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
...@@ -45,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { ...@@ -45,7 +45,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext& context) const override { void InferShape(const framework::InferShapeContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
auto dims = GetAttr<std::vector<int>>("dims"); auto dims = Attr<std::vector<int>>("dims");
PADDLE_ENFORCE(dims.size() > 0UL, PADDLE_ENFORCE(dims.size() > 0UL,
"dims can be one int or array. dims must be set."); "dims can be one int or array. dims must be set.");
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
......
...@@ -47,8 +47,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel { ...@@ -47,8 +47,8 @@ class GPUGaussianRandomKernel : public framework::OpKernel {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
T mean = static_cast<T>(context.GetAttr<float>("mean")); T mean = static_cast<T>(context.Attr<float>("mean"));
T std = static_cast<T>(context.GetAttr<float>("std")); T std = static_cast<T>(context.Attr<float>("std"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims()); ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N, thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
...@@ -109,7 +109,7 @@ void InitArgument(const ArgumentName& name, Argument* arg, ...@@ -109,7 +109,7 @@ void InitArgument(const ArgumentName& name, Argument* arg,
arg->step_scopes = op.Output(name.step_scopes); arg->step_scopes = op.Output(name.step_scopes);
auto inlinks = op.Inputs(name.inlinks); auto inlinks = op.Inputs(name.inlinks);
auto inlink_alias = op.GetAttr<std::vector<std::string>>(name.inlink_alias); auto inlink_alias = op.Attr<std::vector<std::string>>(name.inlink_alias);
PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(), PADDLE_ENFORCE(inlinks.size() == inlink_alias.size(),
"the size of inlinks and inlink_alias don't match:%d,%d", "the size of inlinks and inlink_alias don't match:%d,%d",
inlinks.size(), inlink_alias.size()); inlinks.size(), inlink_alias.size());
...@@ -121,7 +121,7 @@ void InitArgument(const ArgumentName& name, Argument* arg, ...@@ -121,7 +121,7 @@ void InitArgument(const ArgumentName& name, Argument* arg,
} }
auto outlinks = op.Outputs(name.outlinks); auto outlinks = op.Outputs(name.outlinks);
auto outlink_alias = op.GetAttr<std::vector<std::string>>(name.outlink_alias); auto outlink_alias = op.Attr<std::vector<std::string>>(name.outlink_alias);
PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(), PADDLE_ENFORCE(outlinks.size() == outlink_alias.size(),
"the size of outlinks and outlink_alias don't match:%d,%d", "the size of outlinks and outlink_alias don't match:%d,%d",
outlinks.size(), outlink_alias.size()); outlinks.size(), outlink_alias.size());
...@@ -135,8 +135,8 @@ void InitArgument(const ArgumentName& name, Argument* arg, ...@@ -135,8 +135,8 @@ void InitArgument(const ArgumentName& name, Argument* arg,
auto boot_memories = op.Inputs(name.boot_memories); auto boot_memories = op.Inputs(name.boot_memories);
// attributes // attributes
auto memories = op.GetAttr<std::vector<std::string>>(name.memories); auto memories = op.Attr<std::vector<std::string>>(name.memories);
auto pre_memories = op.GetAttr<std::vector<std::string>>(name.pre_memories); auto pre_memories = op.Attr<std::vector<std::string>>(name.pre_memories);
PADDLE_ENFORCE(memories.size() == boot_memories.size(), PADDLE_ENFORCE(memories.size() == boot_memories.size(),
"the size of memories, boot_memories don't match:%d,%d", "the size of memories, boot_memories don't match:%d,%d",
......
...@@ -60,7 +60,7 @@ class ScaleGradOp : public NetOp { ...@@ -60,7 +60,7 @@ class ScaleGradOp : public NetOp {
AppendOp(framework::OpRegistry::CreateOp( AppendOp(framework::OpRegistry::CreateOp(
"scale", {{"X", {Input(framework::GradVarName("Out"))}}}, "scale", {{"X", {Input(framework::GradVarName("Out"))}}},
{{"Out", {Output(framework::GradVarName("X"))}}}, {{"Out", {Output(framework::GradVarName("X"))}}},
{{"scale", GetAttr<AttrType>("scale")}})); {{"scale", Attr<AttrType>("scale")}}));
CompleteAddOp(false); CompleteAddOp(false);
} }
}; };
......
...@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel { ...@@ -27,7 +27,7 @@ class ScaleKernel : public framework::OpKernel {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
tensor->mutable_data<T>(in->place()); tensor->mutable_data<T>(in->place());
auto scale = static_cast<T>(context.GetAttr<AttrType>("scale")); auto scale = static_cast<T>(context.Attr<AttrType>("scale"));
auto eigen_out = framework::EigenVector<T>::Flatten(*tensor); auto eigen_out = framework::EigenVector<T>::Flatten(*tensor);
auto eigen_in = framework::EigenVector<T>::Flatten(*in); auto eigen_in = framework::EigenVector<T>::Flatten(*in);
......
...@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel { ...@@ -31,7 +31,7 @@ class SGDOpKernel : public framework::OpKernel {
auto param = ctx.Input<Tensor>("param"); auto param = ctx.Input<Tensor>("param");
auto grad = ctx.Input<Tensor>("grad"); auto grad = ctx.Input<Tensor>("grad");
auto param_out = ctx.Output<Tensor>("param_out"); auto param_out = ctx.Output<Tensor>("param_out");
float lr = ctx.GetAttr<float>("learning_rate"); float lr = ctx.Attr<float>("learning_rate");
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
......
...@@ -26,15 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel { ...@@ -26,15 +26,15 @@ class CPUUniformRandomKernel : public framework::OpKernel {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* tensor = context.Output<framework::Tensor>("Out"); auto* tensor = context.Output<framework::Tensor>("Out");
T* data = tensor->mutable_data<T>(context.GetPlace()); T* data = tensor->mutable_data<T>(context.GetPlace());
unsigned int seed = static_cast<unsigned int>(context.GetAttr<int>("seed")); unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
std::minstd_rand engine; std::minstd_rand engine;
if (seed == 0) { if (seed == 0) {
seed = std::random_device()(); seed = std::random_device()();
} }
engine.seed(seed); engine.seed(seed);
std::uniform_real_distribution<T> dist( std::uniform_real_distribution<T> dist(
static_cast<T>(context.GetAttr<float>("min")), static_cast<T>(context.Attr<float>("min")),
static_cast<T>(context.GetAttr<float>("max"))); static_cast<T>(context.Attr<float>("max")));
ssize_t size = framework::product(tensor->dims()); ssize_t size = framework::product(tensor->dims());
for (ssize_t i = 0; i < size; ++i) { for (ssize_t i = 0; i < size; ++i) {
data[i] = dist(engine); data[i] = dist(engine);
...@@ -48,10 +48,10 @@ class UniformRandomOp : public framework::OperatorWithKernel { ...@@ -48,10 +48,10 @@ class UniformRandomOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext& ctx) const override { void InferShape(const framework::InferShapeContext& ctx) const override {
PADDLE_ENFORCE(GetAttr<float>("min") < GetAttr<float>("max"), PADDLE_ENFORCE(Attr<float>("min") < Attr<float>("max"),
"uniform_random's min must less then max"); "uniform_random's min must less then max");
auto* tensor = ctx.Output<framework::Tensor>("Out"); auto* tensor = ctx.Output<framework::Tensor>("Out");
auto dims = GetAttr<std::vector<int>>("dims"); auto dims = Attr<std::vector<int>>("dims");
tensor->Resize(framework::make_ddim(dims)); tensor->Resize(framework::make_ddim(dims));
} }
}; };
......
...@@ -50,8 +50,8 @@ class GPUUniformRandomKernel : public framework::OpKernel { ...@@ -50,8 +50,8 @@ class GPUUniformRandomKernel : public framework::OpKernel {
std::random_device rd; std::random_device rd;
seed = rd(); seed = rd();
} }
T min = static_cast<T>(context.GetAttr<float>("min")); T min = static_cast<T>(context.Attr<float>("min"));
T max = static_cast<T>(context.GetAttr<float>("max")); T max = static_cast<T>(context.Attr<float>("max"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0); thrust::counting_iterator<unsigned int> index_sequence_begin(0);
ssize_t N = framework::product(tensor->dims()); ssize_t N = framework::product(tensor->dims());
thrust::transform(index_sequence_begin, index_sequence_begin + N, thrust::transform(index_sequence_begin, index_sequence_begin + N,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册