提交 8acc0106 编写于 作者: D dzhwinter 提交者: GitHub

Merge branch 'develop' into macro

......@@ -206,7 +206,7 @@ MulOp(const std::string &type, const framework::VariableNameMap &inputs,
- `REGISTER_OP` : 注册`ops::MulOp`类,类型名为`mul`,该类的`ProtoMaker`为`ops::MulOpMaker`,注册`ops::MulOpGrad`,类型名为`mul_grad`。
- `REGISTER_OP_WITHOUT_GRADIENT` : 用于注册没有反向的Op。
- `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulKernel`类。
- `REGISTER_OP_CPU_KERNEL` :注册`ops::MulKernel`类,并特化模板参数为`paddle::platform::CPUPlace`和`float`类型,同理,注册`ops::MulGradKernel`类。
-`.cu`文件中注册GPU Kernel。
......
......@@ -205,7 +205,7 @@ The definition of its corresponding backward operator, if applicable, is similar
- `REGISTER_OP` registers the `ops::MulOp` class, type named `mul`, its type `ProtoMaker` is `ops::MulOpMaker`, registering `ops::MulOpGrad` as `mul_grad`.
- `REGISTER_OP_WITHOUT_GRADIENT` registers an operator without gradient.
- `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulKernel`.
- `REGISTER_OP_CPU_KERNEL` registers `ops::MulKernel` class and specialized template types `paddle::platform::CPUPlace` and `float`, which also registers `ops::MulGradKernel`.
- Registering GPU Kernel in `.cu` files
......
......@@ -21,20 +21,12 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/type_defs.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/variant.h"
namespace paddle {
namespace framework {
// The order should be as same as framework.proto
typedef boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>
Attribute;
typedef std::unordered_map<std::string, Attribute> AttributeMap;
ProgramDesc& GetProgramDesc();
template <typename T>
......
......@@ -112,6 +112,30 @@ const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
return attrs_;
}
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const { attr_->set_block_idx(desc->idx()); }
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void OpDescBind::Sync() {
if (need_update_) {
this->op_desc_.mutable_inputs()->Clear();
......@@ -134,7 +158,8 @@ void OpDescBind::Sync() {
attr_desc->set_name(attr.first);
attr_desc->set_type(
static_cast<framework::AttrType>(attr.second.which() - 1));
boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second);
SetAttrDescVisitor visitor(attr_desc);
boost::apply_visitor(visitor, attr.second);
}
need_update_ = false;
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/framework/attribute.h"
#include "paddle/framework/type_defs.h"
#include "paddle/framework/var_desc.h"
namespace paddle {
......@@ -61,48 +62,22 @@ class OpDescBind {
void SetBlockAttr(const std::string &name, BlockDescBind &block);
// Only be used in C++
void SetAttrMap(const std::unordered_map<std::string, Attribute> &attr_map);
void SetAttrMap(const AttributeMap &attr_map);
Attribute GetAttr(const std::string &name) const;
int GetBlockAttr(const std::string &name) const;
// Only be used in C++
const std::unordered_map<std::string, Attribute> &GetAttrMap() const;
const AttributeMap &GetAttrMap() const;
private:
struct SetAttrDescVisitor : public boost::static_visitor<void> {
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
mutable OpDesc::Attr *attr_;
void operator()(int v) const { attr_->set_i(v); }
void operator()(float v) const { attr_->set_f(v); }
void operator()(const std::string &v) const { attr_->set_s(v); }
void operator()(bool b) const { attr_->set_b(b); }
void operator()(const std::vector<int> &v) const {
VectorToRepeated(v, attr_->mutable_ints());
}
void operator()(const std::vector<float> &v) const {
VectorToRepeated(v, attr_->mutable_floats());
}
void operator()(const std::vector<std::string> &v) const {
VectorToRepeated(v, attr_->mutable_strings());
}
void operator()(const std::vector<bool> &v) const {
VectorToRepeated(v, attr_->mutable_bools());
}
void operator()(BlockDesc *desc) const {
attr_->set_block_idx(desc->idx());
}
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
};
void Sync();
OpDesc op_desc_;
std::unordered_map<std::string, std::vector<std::string>> inputs_;
std::unordered_map<std::string, std::vector<std::string>> outputs_;
std::unordered_map<std::string, Attribute> attrs_;
VariableNameMap inputs_;
VariableNameMap outputs_;
AttributeMap attrs_;
// need_update_ indicate there some local changes not be synchronized. If
// local changes should be synchronized, need_update_ should be set to true.
......
......@@ -19,16 +19,11 @@
#include <unordered_map>
#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.h"
#include "paddle/platform/macros.h"
#include "paddle/framework/type_defs.h"
namespace paddle {
namespace framework {
class OperatorBase;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
class GradOpDescMakerBase {
public:
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <functional>
#include <map>
#include "paddle/platform/variant.h"
namespace paddle {
namespace framework {
class OperatorBase;
using VariableNameMap = std::map<std::string, std::vector<std::string>>;
// The order should be as same as framework.proto
using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDesc*>;
using AttributeMap = std::unordered_map<std::string, Attribute>;
using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
} // namespace framework
} // namespace paddle
......@@ -185,7 +185,7 @@ inline void throw_on_error(T e) {
std::make_exception_ptr( \
std::runtime_error(paddle::string::Sprintf(__VA_ARGS__))), \
__FILE__, __LINE__); \
} while (0)
} while (false)
#define PADDLE_ENFORCE(...) \
do { \
......@@ -195,7 +195,7 @@ inline void throw_on_error(T e) {
throw ::paddle::platform::EnforceNotMet(std::current_exception(), \
__FILE__, __LINE__); \
} \
} while (0)
} while (false)
/*
* Some enforce helpers here, usage:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册