You need to sign in or sign up before continuing.
提交 7e830116 编写于 作者: Y Yu Yang

Try make pass

上级 72e3ba50
...@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() { ...@@ -44,7 +44,7 @@ AttrType AttrTypeID<std::vector<std::string>>() {
return STRINGS; return STRINGS;
} }
Attribute GetAttrValue(const AttrDesc& attr_desc) { Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case paddle::framework::AttrType::INT: { case paddle::framework::AttrType::INT: {
return attr_desc.i(); return attr_desc.i();
......
...@@ -21,8 +21,7 @@ limitations under the License. */ ...@@ -21,8 +21,7 @@ limitations under the License. */
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/framework/attribute.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
namespace paddle { namespace paddle {
...@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap; ...@@ -37,7 +36,7 @@ typedef std::unordered_map<std::string, Attribute> AttributeMap;
template <typename T> template <typename T>
AttrType AttrTypeID(); AttrType AttrTypeID();
Attribute GetAttrValue(const AttrDesc& attr_desc); Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
// check whether a value(attribute) fit a certain limit // check whether a value(attribute) fit a certain limit
template <typename T> template <typename T>
......
...@@ -284,5 +284,11 @@ DDim::DDim(std::initializer_list<int> init_list) { ...@@ -284,5 +284,11 @@ DDim::DDim(std::initializer_list<int> init_list) {
*this = make_ddim(init_list); *this = make_ddim(init_list);
} }
std::string DDim::DebugString() const {
std::ostringstream ss;
ss << *this;
return ss.str();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -73,6 +73,8 @@ struct DDim { ...@@ -73,6 +73,8 @@ struct DDim {
DDim operator*(DDim d) const; DDim operator*(DDim d) const;
ssize_t size() const; ssize_t size() const;
std::string DebugString() const;
}; };
/** /**
......
...@@ -40,8 +40,8 @@ message OpDesc { ...@@ -40,8 +40,8 @@ message OpDesc {
}; };
message Var { message Var {
required string name; // e.g. "X" required string op_proto_name = 1;
optional int dup = 2 [ default = 0 ]; // e.g., "1" repeated string var_names = 2;
}; };
required string type = 3; required string type = 3;
...@@ -57,7 +57,7 @@ message OpProto { ...@@ -57,7 +57,7 @@ message OpProto {
message Var { message Var {
required string name = 1; required string name = 1;
required string comment = 2; required string comment = 2;
// OpDesc::Var::dup indices the duplica.
optional bool duplicable = 3 [ default = false ]; optional bool duplicable = 3 [ default = false ];
optional bool intermediate = 4 [ default = false ]; optional bool intermediate = 4 [ default = false ];
optional bool no_gradient = 5 [ default = false ]; optional bool no_gradient = 5 [ default = false ];
......
...@@ -13,12 +13,12 @@ express or implied. See the License for the specific language governing ...@@ -13,12 +13,12 @@ express or implied. See the License for the specific language governing
permissions and limitations under the License. */ permissions and limitations under the License. */
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_proto.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/**
class OpRegistry; class OpRegistry;
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
...@@ -98,6 +98,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) { ...@@ -98,6 +98,7 @@ OperatorBase* BuildGradOp(const OperatorBase* op) {
TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG TransOpArg(op, grad_op, OpArgType::IN, OpArgType::OUT, out_idx, true); // IG
return grad_op; return grad_op;
} }
**/
OperatorBase* BuildGradOp(const OperatorBase* op) { return nullptr; }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -20,8 +20,8 @@ limitations under the License. */ ...@@ -20,8 +20,8 @@ limitations under the License. */
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/grad_op_builder.h" #include "paddle/framework/grad_op_builder.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
namespace paddle { namespace paddle {
...@@ -44,25 +44,20 @@ class OpProtoAndCheckerMaker { ...@@ -44,25 +44,20 @@ class OpProtoAndCheckerMaker {
protected: protected:
struct VariableBuilder { struct VariableBuilder {
VarProto* var_; OpProto::Var* var_;
std::function<void()> on_multiple_;
std::function<void()> on_temporary_;
VariableBuilder& SetMultiple() { VariableBuilder& SetMultiple() {
var_->set_multiple(true); var_->set_duplicable(true);
on_multiple_();
return *this; return *this;
} }
VariableBuilder& SetTemporary() { VariableBuilder& SetTemporary() {
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary"); var_->set_intermediate(true);
var_->set_temporary(true);
on_temporary_();
return *this; return *this;
} }
VariableBuilder& IgnoreGradient() { VariableBuilder& IgnoreGradient() {
var_->set_ignore_gradient(true); var_->set_no_gradient(true);
return *this; return *this;
} }
}; };
...@@ -72,8 +67,7 @@ class OpProtoAndCheckerMaker { ...@@ -72,8 +67,7 @@ class OpProtoAndCheckerMaker {
auto input = proto_->mutable_inputs()->Add(); auto input = proto_->mutable_inputs()->Add();
*input->mutable_name() = name; *input->mutable_name() = name;
*input->mutable_comment() = comment; *input->mutable_comment() = comment;
return VariableBuilder{input, [=] { this->SetHasMultipleInput(); }, return VariableBuilder{input};
nullptr};
} }
VariableBuilder AddOutput(const std::string& name, VariableBuilder AddOutput(const std::string& name,
...@@ -81,8 +75,7 @@ class OpProtoAndCheckerMaker { ...@@ -81,8 +75,7 @@ class OpProtoAndCheckerMaker {
auto output = proto_->mutable_outputs()->Add(); auto output = proto_->mutable_outputs()->Add();
*output->mutable_name() = name; *output->mutable_name() = name;
*output->mutable_comment() = comment; *output->mutable_comment() = comment;
return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); }, return VariableBuilder{output};
[=] { this->SetHasTemporaryOutput(); }};
} }
template <typename T> template <typename T>
...@@ -102,53 +95,6 @@ class OpProtoAndCheckerMaker { ...@@ -102,53 +95,6 @@ class OpProtoAndCheckerMaker {
} }
private: private:
void SetHasMultiple(const std::string& in_out, bool* flag) {
if (!*flag) {
AddAttr<std::vector<int>>(in_out + "_format",
"The multiple index of " + in_out +
"\n"
R"DOC(
This attribute is used by Paddle core framework. Paddle's Op support each input
or output could be a list of variable. This attribute is used to show how that
list organized.
e.g.
input = ["a", "b", "c", "d", "e", "f"]
input_format = [0, 4, 5, 6]
means
The number of all input variables this op is six, and they are segmented into
three inputs.
The first input is input[0:4], second is input[4:5], third is input[5:6].
)DOC",
/*generated*/ true);
*flag = true;
}
}
void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
void SetHasMultipleOutput() {
SetHasMultiple("output", &has_multiple_output_);
}
void SetHasTemporaryOutput() {
if (!has_temporary_output_) {
AddAttr<std::vector<int>>("temporary_index",
R"DOC(The temporary index of output.
Not all output of Paddle Op is used by user. For faster computation, each op
could output some its internal state to other op, other op could take that
output to make compute faster.
Add a mark to which output is temporary is helpful for future optimization.
)DOC",
/*generated*/ true)
.SetDefault(std::vector<int>());
has_temporary_output_ = true;
}
}
void CheckNoDuplicatedInOutAttrs() { void CheckNoDuplicatedInOutAttrs() {
std::unordered_set<std::string> names; std::unordered_set<std::string> names;
auto checker = [&](const std::string& name) { auto checker = [&](const std::string& name) {
...@@ -169,15 +115,12 @@ Add a mark to which output is temporary is helpful for future optimization. ...@@ -169,15 +115,12 @@ Add a mark to which output is temporary is helpful for future optimization.
OpProto* proto_; OpProto* proto_;
OpAttrChecker* op_checker_; OpAttrChecker* op_checker_;
bool validated_{false}; bool validated_{false};
bool has_multiple_input_{false};
bool has_multiple_output_{false};
bool has_temporary_output_{false};
}; };
class OpRegistry { class OpRegistry {
using OpCreator = std::function<OperatorBase*()>; using OpCreator = std::function<OperatorBase*()>;
using VarIndexMap = std::unordered_map<std::string, int>; using VarIndexMap = std::unordered_map<std::string, int>;
using VarNameList = std::vector<std::string>; using VarNameMap = std::unordered_map<std::string, std::vector<std::string>>;
public: public:
template <typename OpType, typename ProtoMakerType> template <typename OpType, typename ProtoMakerType>
...@@ -213,8 +156,8 @@ class OpRegistry { ...@@ -213,8 +156,8 @@ class OpRegistry {
} }
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type, static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
const VarNameList& inputs, const VarNameMap& inputs,
const VarNameList& outputs, const VarNameMap& outputs,
const AttributeMap& attrs) { const AttributeMap& attrs) {
auto op_create_it = op_creators().find(type); auto op_create_it = op_creators().find(type);
PADDLE_ENFORCE(op_create_it != op_creators().end(), PADDLE_ENFORCE(op_create_it != op_creators().end(),
...@@ -230,27 +173,28 @@ class OpRegistry { ...@@ -230,27 +173,28 @@ class OpRegistry {
GenerateTempVariableName(op); GenerateTempVariableName(op);
{
auto var_index_it = VarIndexMaps().find(type);
if (var_index_it != VarIndexMaps().end()) {
op->in_out_idxs_ = var_index_it->second;
}
}
op->Init(); op->Init();
return std::shared_ptr<OperatorBase>(op); return std::shared_ptr<OperatorBase>(op);
} }
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) { static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
std::vector<std::string> inputs; VarNameMap inputs;
inputs.reserve((size_t)op_desc.inputs_size()); for (auto& input : op_desc.inputs()) {
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(), auto& var_names = inputs[input.op_proto_name()];
std::back_inserter(inputs)); auto& var_names_in_proto = input.var_names();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
std::vector<std::string> outputs; VarNameMap outputs;
outputs.reserve((size_t)op_desc.outputs_size()); for (auto& output : op_desc.outputs()) {
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(), auto& var_names = outputs[output.op_proto_name()];
std::back_inserter(outputs)); auto& var_names_in_proto = output.var_names();
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
std::back_inserter(var_names));
}
AttributeMap attrs; AttributeMap attrs;
for (auto& attr : op_desc.attrs()) { for (auto& attr : op_desc.attrs()) {
...@@ -303,11 +247,13 @@ class OpRegistry { ...@@ -303,11 +247,13 @@ class OpRegistry {
static void GenerateTempVariableName(OperatorBase* op) { static void GenerateTempVariableName(OperatorBase* op) {
static std::atomic<size_t> gUniqId(0UL); static std::atomic<size_t> gUniqId(0UL);
for (auto& outname : op->outputs_) { for (auto& output : op->outputs_) {
if (outname == kTempVarName) { for (auto& output_name : output.second) {
outname += op->type_; if (output_name == kTempVarName) {
outname += "@"; output_name += op->type_;
outname += std::to_string(gUniqId.fetch_add(1)); output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
} }
} }
} }
......
...@@ -34,83 +34,72 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const { ...@@ -34,83 +34,72 @@ ExecutionContext::GetEigenDevice<platform::GPUPlace, Eigen::GpuDevice>() const {
#endif #endif
const std::string& OperatorBase::Input(const std::string& name) const { const std::string& OperatorBase::Input(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, auto it = inputs_.find(name);
"Input Output Indices could not be nullptr"); PADDLE_ENFORCE(it != inputs_.end(), "Op %s does not have output %s", type_,
auto it = in_out_idxs_->find(name);
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
if (attrs_.count("input_format") == 0) { PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
return inputs_.at((size_t)it->second); "Op %s input %s should contain only one variable", type_,
} else { name);
const auto& input_format = GetAttr<std::vector<int>>("input_format"); return it->second[0];
int idx = input_format[it->second];
return inputs_.at((size_t)idx);
}
} }
std::vector<std::string> OperatorBase::Inputs(const std::string& name) const { const std::vector<std::string>& OperatorBase::Inputs(
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "IO Idx could not be nullptr"); const std::string& name) const {
auto input_format = GetAttr<std::vector<int>>("input_format"); return inputs_.at(name);
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(input_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(inputs_.size()),
"Input Out Of Range");
return std::vector<std::string>{
inputs_.begin() + input_format.at(offset),
inputs_.begin() + input_format.at(offset + 1)};
} }
const std::string& OperatorBase::Output(const std::string& name) const { const std::string& OperatorBase::Output(const std::string& name) const {
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); auto it = outputs_.find(name);
auto it = in_out_idxs_->find(name); PADDLE_ENFORCE(it != outputs_.end(), "Op %s does not have output %s", type_,
PADDLE_ENFORCE(it != in_out_idxs_->end(), "no key [%s] in in_out_idxs_",
name); name);
if (attrs_.count("output_format") == 0) { PADDLE_ENFORCE_EQ(it->second.size(), 1UL,
return outputs_.at((size_t)it->second); "Op %s input %s should contain only one variable", type_,
} else { name);
const auto& output_format = GetAttr<std::vector<int>>("output_format"); return it->second[0];
int idx = output_format[it->second];
return outputs_.at((size_t)idx);
}
} }
std::vector<std::string> OperatorBase::Outputs(const std::string& name) const { const std::vector<std::string>& OperatorBase::Outputs(
PADDLE_ENFORCE(in_out_idxs_ != nullptr, "InOut Indice could not be nullptr"); const std::string& name) const {
auto output_format = GetAttr<std::vector<int>>("output_format"); return outputs_.at(name);
auto offset = in_out_idxs_->at(name);
PADDLE_ENFORCE(output_format.at(static_cast<size_t>(offset) + 1) <=
static_cast<int>(outputs_.size()),
"Output Out of Range");
return std::vector<std::string>{
outputs_.begin() + output_format.at(offset),
outputs_.begin() + output_format.at(offset + 1)};
} }
std::string OperatorBase::DebugString() const { std::string OperatorBase::DebugString() const {
std::stringstream ss; std::stringstream ss;
ss << "Op(" << type_ << "), inputs:("; ss << "Op(" << type_ << "), inputs:{";
for (size_t i = 0; i < inputs_.size(); ++i) { for (auto& input : inputs_) {
ss << inputs_[i]; ss << input.first << "[";
if (i != inputs_.size() - 1) { for (size_t i = 0; i < input.second.size(); ++i) {
ss << input.second[i];
if (i != input.second.size() - 1) {
ss << ", "; ss << ", ";
} }
} }
ss << "), outputs:("; ss << "]";
for (size_t i = 0; i < outputs_.size(); ++i) { }
ss << outputs_[i]; ss << "}, outputs:{";
if (i != outputs_.size() - 1) { for (auto& output : outputs_) {
ss << output.first << "[";
for (size_t i = 0; i < output.second.size(); ++i) {
ss << output.second[i];
if (i != output.second.size() - 1) {
ss << ", "; ss << ", ";
} }
} }
ss << ")."; ss << "]";
}
ss << "}.";
return ss.str(); return ss.str();
} }
void OperatorBase::Rename(const std::string& old_name, void OperatorBase::Rename(const std::string& old_name,
const std::string& new_name) { const std::string& new_name) {
std::replace(inputs_.begin(), inputs_.end(), old_name, new_name); for (auto& input : inputs_) {
std::replace(outputs_.begin(), outputs_.end(), old_name, new_name); std::replace(input.second.begin(), input.second.end(), old_name, new_name);
}
for (auto& output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name,
new_name);
}
} }
} // namespace framework } // namespace framework
......
...@@ -21,8 +21,7 @@ limitations under the License. */ ...@@ -21,8 +21,7 @@ limitations under the License. */
#include <vector> #include <vector>
#include "paddle/framework/attribute.h" #include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
...@@ -95,13 +94,12 @@ class OperatorBase { ...@@ -95,13 +94,12 @@ class OperatorBase {
const std::string& Input(const std::string& name) const; const std::string& Input(const std::string& name) const;
//! Get a input which has multiple variables. //! Get a input which has multiple variables.
//! TODO add a vector_view to prevent memory copy. const std::vector<std::string>& Inputs(const std::string& name) const;
std::vector<std::string> Inputs(const std::string& name) const;
//! Get a output with argument's name described in `op_proto` //! Get a output with argument's name described in `op_proto`
const std::string& Output(const std::string& name) const; const std::string& Output(const std::string& name) const;
//! Get an output which has multiple variables. //! Get an output which has multiple variables.
//! TODO add a vector_view to prevent memory copy. //! TODO add a vector_view to prevent memory copy.
std::vector<std::string> Outputs(const std::string& name) const; const std::vector<std::string>& Outputs(const std::string& name) const;
public: public:
std::string type_; std::string type_;
...@@ -109,13 +107,12 @@ class OperatorBase { ...@@ -109,13 +107,12 @@ class OperatorBase {
// I (Inputs) // I (Inputs)
// O (Outputs) // O (Outputs)
// OG (Output Gradients) // OG (Output Gradients)
std::vector<std::string> inputs_; std::unordered_map<std::string, std::vector<std::string>> inputs_;
// NOTE: in case of OpGrad, outputs_ contains // NOTE: in case of OpGrad, outputs_ contains
// IG (Inputs Gradients) // IG (Inputs Gradients)
std::vector<std::string> outputs_; std::unordered_map<std::string, std::vector<std::string>> outputs_;
AttributeMap attrs_; AttributeMap attrs_;
// store the arguments' offset described in op_desc.
std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
}; };
class OperatorContext { class OperatorContext {
...@@ -123,16 +120,12 @@ class OperatorContext { ...@@ -123,16 +120,12 @@ class OperatorContext {
OperatorContext(const OperatorBase* op, const Scope& scope) OperatorContext(const OperatorBase* op, const Scope& scope)
: op_(*op), scope_(scope) {} : op_(*op), scope_(scope) {}
size_t InputSize() const { return op_.inputs_.size(); } size_t InputSize(const std::string& name) const {
return op_.inputs_.at(name).size();
size_t OutputSize() const { return op_.outputs_.size(); }
const Variable* InputVar(const size_t index) const {
return scope_.FindVar(op_.inputs_.at(index));
} }
Variable* OutputVar(const size_t index) const { size_t OutputSize(const std::string& name) const {
return scope_.FindVar(op_.outputs_.at(index)); return op_.outputs_.at(name).size();
} }
const Variable* InputVar(const std::string& name) const { const Variable* InputVar(const std::string& name) const {
...@@ -164,24 +157,6 @@ class OperatorContext { ...@@ -164,24 +157,6 @@ class OperatorContext {
return res; return res;
} }
template <typename T>
const T* Input(const size_t index) const {
auto var = InputVar(index);
PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
return &var->Get<T>();
}
template <typename T>
T* Output(const size_t index) const {
auto var = OutputVar(index);
PADDLE_ENFORCE(
var != nullptr,
"Output(%d) not be nullptr, which means variable [%s] does not "
"exist in scope",
index, op_.outputs_[index]);
return var->GetMutable<T>();
}
template <typename T> template <typename T>
const T* Input(const std::string& name) const { const T* Input(const std::string& name) const {
auto var = InputVar(name); auto var = InputVar(name);
......
...@@ -20,15 +20,10 @@ namespace operators { ...@@ -20,15 +20,10 @@ namespace operators {
class AddOp : public OperatorWithKernel { class AddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE_EQ(ctx.InputSize(), 2); PADDLE_ENFORCE_EQ(ctx.Input<Tensor>("X")->dims(),
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1); ctx.Input<Tensor>("Y")->dims(),
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr,
"Inputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Outputs of AddOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of Add Op's dimension must be same."); "Two input of Add Op's dimension must be same.");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
......
...@@ -22,9 +22,9 @@ template <typename Place, typename T> ...@@ -22,9 +22,9 @@ template <typename Place, typename T>
class AddKernel : public OpKernel { class AddKernel : public OpKernel {
public: public:
void Compute(const ExecutionContext& context) const override { void Compute(const ExecutionContext& context) const override {
auto input0 = context.Input<Tensor>(0); auto* input0 = context.Input<Tensor>("X");
auto input1 = context.Input<Tensor>(1); auto* input1 = context.Input<Tensor>("Y");
auto output = context.Output<Tensor>(0); auto* output = context.Output<Tensor>("Out");
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
......
...@@ -20,19 +20,13 @@ namespace operators { ...@@ -20,19 +20,13 @@ namespace operators {
class OnehotCrossEntropyOp : public OperatorWithKernel { class OnehotCrossEntropyOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, auto *X = ctx.Input<Tensor>("X");
"Input size of OnehotCrossEntropyOp must be two"); auto *label = ctx.Input<Tensor>("label");
PADDLE_ENFORCE(ctx.OutputSize() == 1,
"Output size of OnehotCrossEntropyOp must be one"); PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.InputVar(1) != nullptr, PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
"Inputs of OnehotCrossEntropyOp must all be set"); PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, ctx.Output<Tensor>("Y")->Resize({X->dims()[0]});
"Outputs of OnehotCrossEntropyOp must all be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims().size() == 2,
"X's dimension must be 2.");
PADDLE_ENFORCE(ctx.Output<Tensor>(0)->dims().size() == 1,
"label's dimension must be 1.");
ctx.Output<Tensor>(0)->Resize({ctx.Input<Tensor>(0)->dims()[0]});
} }
}; };
......
...@@ -43,7 +43,7 @@ class OnehotCrossEntropyOpKernel : public OpKernel { ...@@ -43,7 +43,7 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
void Compute(const ExecutionContext& ctx) const override { void Compute(const ExecutionContext& ctx) const override {
auto X = ctx.Input<Tensor>("X"); auto X = ctx.Input<Tensor>("X");
const T* Xdata = X->data<T>(); const T* Xdata = X->data<T>();
const int* label_data = ctx.Input<Tensor>(1)->data<int>(); const int* label_data = ctx.Input<Tensor>("label")->data<int>();
auto Y = ctx.Output<Tensor>("Y"); auto Y = ctx.Output<Tensor>("Y");
Y->mutable_data<T>(ctx.GetPlace()); Y->mutable_data<T>(ctx.GetPlace());
......
...@@ -20,16 +20,8 @@ namespace operators { ...@@ -20,16 +20,8 @@ namespace operators {
class FillZerosLikeOp : public framework::OperatorWithKernel { class FillZerosLikeOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL, ctx.Output<framework::Tensor>("Dst")->Resize(
"Input size of FillZerosLikeOp must be one."); ctx.Input<framework::Tensor>("Src")->dims());
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output size of AddOp must be one.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr,
"Input of FillZerosLikeOp must be set.");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr,
"Output of FillZerosLikeOp must be set.");
ctx.Output<framework::Tensor>(0)->Resize(
ctx.Input<framework::Tensor>(0)->dims());
} }
}; };
......
...@@ -20,11 +20,9 @@ namespace operators { ...@@ -20,11 +20,9 @@ namespace operators {
class MeanOp : public OperatorWithKernel { class MeanOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Input size of AddOp must be one"); PADDLE_ENFORCE(ctx.InputVar("X") != nullptr,
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of AddOp must be one"); "Input of MeanOp must be initialized.");
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr && ctx.OutputVar(0) != nullptr, ctx.Output<Tensor>("Out")->Resize({1});
"Input/Output of MeanOp must be initialized.");
ctx.Output<Tensor>(0)->Resize(framework::make_ddim({1}));
} }
}; };
......
...@@ -20,9 +20,8 @@ namespace operators { ...@@ -20,9 +20,8 @@ namespace operators {
class MulOp : public OperatorWithKernel { class MulOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "The mul op must take two inputs"); auto dim0 = ctx.Input<Tensor>("X")->dims();
auto dim0 = ctx.Input<Tensor>(0)->dims(); auto dim1 = ctx.Input<Tensor>("Y")->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE_EQ(dim0.size(), 2, PADDLE_ENFORCE_EQ(dim0.size(), 2,
"input X(%s) should be a tensor with 2 dims, a matrix", "input X(%s) should be a tensor with 2 dims, a matrix",
ctx.op_.Input("X")); ctx.op_.Input("X"));
...@@ -32,8 +31,7 @@ class MulOp : public OperatorWithKernel { ...@@ -32,8 +31,7 @@ class MulOp : public OperatorWithKernel {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
dim0[1], dim1[0], dim0[1], dim1[0],
"First matrix's width must be equal with second matrix's height."); "First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE_EQ(ctx.OutputSize(), 1, "The mul op takes only one output"); ctx.Output<Tensor>("Out")->Resize({dim0[0], dim1[1]});
ctx.Output<Tensor>(0)->Resize({dim0[0], dim1[1]});
} }
}; };
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
*/ */
#include "paddle/operators/net_op.h" #include "paddle/operators/net_op.h"
#include <set>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
namespace paddle { namespace paddle {
...@@ -23,36 +24,39 @@ namespace operators { ...@@ -23,36 +24,39 @@ namespace operators {
void NetOp::CompleteAddOp(bool calc) { void NetOp::CompleteAddOp(bool calc) {
add_op_done_ = true; add_op_done_ = true;
if (!calc) return; if (!calc) return;
std::unordered_set<std::string> input_set; std::set<std::string> input_set;
std::unordered_set<std::string> output_set; std::set<std::string> output_set;
std::unordered_set<std::string> temp_output; std::set<std::string> temp_output;
for (auto& op : ops_) { for (auto& op : ops_) {
for (auto& ipt : op->inputs_) { for (auto& ipt : op->inputs_) {
if (!Contains(output_set, ipt)) { // Not other op's output for (auto& var_name : ipt.second) {
input_set.insert(ipt); if (!Contains(output_set, var_name)) { // Not other op's output
input_set.insert(var_name);
} else { } else {
temp_output.insert(ipt); temp_output.insert(var_name);
}
} }
} }
for (auto& opt : op->outputs_) { for (auto& opt : op->outputs_) {
output_set.insert(opt); for (auto& var_name : opt.second) {
output_set.insert(var_name);
} }
} }
}
auto& inputs = inputs_["all"];
inputs.reserve(input_set.size());
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs));
auto& outputs = outputs_["all"];
outputs.reserve(output_set.size());
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs));
inputs_.reserve(input_set.size()); //! TODO figure out how to generate temporary_index in Network.
std::copy(input_set.begin(), input_set.end(), std::back_inserter(inputs_));
std::sort(inputs_.begin(), inputs_.end());
outputs_.reserve(output_set.size());
std::copy(output_set.begin(), output_set.end(), std::back_inserter(outputs_));
std::sort(outputs_.begin(), outputs_.end());
std::vector<int> tmp_index; std::vector<int> tmp_index;
tmp_index.reserve(temp_output.size()); tmp_index.reserve(temp_output.size());
int output_len = static_cast<int>(outputs_.size()); int output_len = static_cast<int>(outputs.size());
for (int i = 0; i < output_len; ++i) { for (int i = 0; i < output_len; ++i) {
if (Contains(temp_output, outputs_[i])) { if (Contains(temp_output, outputs[i])) {
tmp_index.push_back(i); tmp_index.push_back(i);
} }
} }
......
...@@ -14,8 +14,7 @@ limitations under the License. */ ...@@ -14,8 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/framework/op_desc.pb.h" #include "paddle/framework/framework.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h" #include "paddle/framework/operator.h"
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
......
...@@ -89,12 +89,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const { ...@@ -89,12 +89,17 @@ void RecurrentAlgorithm::CreateScopes(const Scope& scope) const {
// create step net's temp inputs // create step net's temp inputs
for (auto& input : net_op->inputs_) { for (auto& input : net_op->inputs_) {
// the weight are located in parent scope // the weight are located in parent scope
if (!step_scope.FindVar(input)) for (auto& var_name : input.second) {
step_scope.NewVar(input)->GetMutable<Tensor>(); if (!step_scope.FindVar(var_name)) {
step_scope.NewVar(var_name)->GetMutable<Tensor>();
}
}
} }
// create stepnet's outputs // create stepnet's outputs
for (const auto& output : net_op->outputs_) { for (const auto& output : net_op->outputs_) {
step_scope.NewVar(output); for (auto& var_name : output.second) {
step_scope.NewVar(var_name);
}
} }
step_scopes->emplace_back(&step_scope); step_scopes->emplace_back(&step_scope);
} }
......
...@@ -19,16 +19,14 @@ namespace operators { ...@@ -19,16 +19,14 @@ namespace operators {
class RowWiseAddOp : public OperatorWithKernel { class RowWiseAddOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2UL, auto dim0 = ctx.Input<Tensor>("X")->dims();
"Two inputs is needed by rowwise add"); auto dim1 = ctx.Input<Tensor>("b")->dims();
auto dim0 = ctx.Input<Tensor>(0)->dims();
auto dim1 = ctx.Input<Tensor>(1)->dims();
PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix"); PADDLE_ENFORCE(dim0.size() == 2, "Input 0 must be matrix");
PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector"); PADDLE_ENFORCE(dim1.size() == 1, "The second input must be vector");
PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same"); PADDLE_ENFORCE(dim0[1] == dim1[0], "The width of two input must be same");
PADDLE_ENFORCE(ctx.OutputSize() == 1, "The output size must be 1"); PADDLE_ENFORCE(ctx.OutputSize("Out") == 1, "The output size must be 1");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>("Out")->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
......
...@@ -25,8 +25,8 @@ class RowWiseAddKernel : public OpKernel { ...@@ -25,8 +25,8 @@ class RowWiseAddKernel : public OpKernel {
auto out = context.Output<Tensor>(0); auto out = context.Output<Tensor>(0);
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
auto input = EigenMatrix<T>::From(*context.Input<Tensor>(0)); auto input = EigenMatrix<T>::From(*context.Input<Tensor>("X"));
auto bias = EigenVector<T>::From(*context.Input<Tensor>(1)); auto bias = EigenVector<T>::From(*context.Input<Tensor>("b"));
auto output = EigenMatrix<T>::From(*out); auto output = EigenMatrix<T>::From(*out);
const int bias_size = bias.dimension(0); const int bias_size = bias.dimension(0);
......
...@@ -20,14 +20,10 @@ namespace operators { ...@@ -20,14 +20,10 @@ namespace operators {
class SGDOp : public OperatorWithKernel { class SGDOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 2, "Input size of SGDOp must be two"); PADDLE_ENFORCE(
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Output size of SGDOp must be one"); ctx.Input<Tensor>("param")->dims() == ctx.Input<Tensor>("grad")->dims(),
PADDLE_ENFORCE(ctx.InputVar(0) != nullptr, "inputs[0] mast be set");
PADDLE_ENFORCE(ctx.InputVar(1) != nullptr, "inputs[1] mast be set");
PADDLE_ENFORCE(ctx.OutputVar(0) != nullptr, "outputs[0] mast be set");
PADDLE_ENFORCE(ctx.Input<Tensor>(0)->dims() == ctx.Input<Tensor>(1)->dims(),
"Two input of SGD Op's dimension must be same."); "Two input of SGD Op's dimension must be same.");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims()); ctx.Output<Tensor>("param_out")->Resize(ctx.Input<Tensor>("param")->dims());
} }
}; };
......
...@@ -19,9 +19,7 @@ namespace operators { ...@@ -19,9 +19,7 @@ namespace operators {
class SigmoidOp : public OperatorWithKernel { class SigmoidOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1, "Sigmoid Op only have one input"); ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
PADDLE_ENFORCE(ctx.OutputSize() == 1, "Sigmoid Op only have one output");
ctx.Output<Tensor>(0)->Resize(ctx.Input<Tensor>(0)->dims());
} }
}; };
......
...@@ -20,12 +20,8 @@ namespace operators { ...@@ -20,12 +20,8 @@ namespace operators {
class SoftmaxOp : public OperatorWithKernel { class SoftmaxOp : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 1UL,
"Only one input is need for softmax");
PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL, PADDLE_ENFORCE(ctx.Input<Tensor>("X")->dims().size() == 2UL,
"The input of softmax op must be matrix"); "The input of softmax op must be matrix");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Only one output is need for softmax");
ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims()); ctx.Output<Tensor>("Y")->Resize(ctx.Input<Tensor>("X")->dims());
} }
}; };
...@@ -43,10 +39,6 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker { ...@@ -43,10 +39,6 @@ class SoftmaxOpMaker : public OpProtoAndCheckerMaker {
class SoftmaxOpGrad : public OperatorWithKernel { class SoftmaxOpGrad : public OperatorWithKernel {
protected: protected:
void InferShape(const InferShapeContext &ctx) const override { void InferShape(const InferShapeContext &ctx) const override {
PADDLE_ENFORCE(ctx.InputSize() == 3UL,
"Input of SoftmaxOpGrad should be 3, X, Y, YG");
PADDLE_ENFORCE(ctx.OutputSize() == 1UL,
"Output of SoftmaxOpGrad should be 1");
PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null"); PADDLE_ENFORCE(ctx.InputVar("Y") != nullptr, "Input(Y) should not be null");
PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr, PADDLE_ENFORCE(ctx.InputVar(framework::GradVarName("Y")) != nullptr,
"Input(Y@GRAD) should not be null"); "Input(Y@GRAD) should not be null");
......
...@@ -195,12 +195,28 @@ struct CompatibleType { ...@@ -195,12 +195,28 @@ struct CompatibleType {
typedef typename std::conditional<t1_to_t2, T2, T1>::type type; typedef typename std::conditional<t1_to_t2, T2, T1>::type type;
}; };
template <typename T>
inline std::string enforce_to_string(const T& val) {
std::ostringstream sout;
sout << val;
return sout.str();
}
template <>
inline std::string enforce_to_string(const std::string& val) {
return val;
}
template <>
inline std::string enforce_to_string(const char* const& val) {
return std::string(val);
}
#define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \ #define __PADDLE_BINARY_COMPARE(__VAL0, __VAL1, __CMP, __INV_CMP, ...) \
PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \ PADDLE_ENFORCE(__COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL0) \
__CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \ __CMP __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL1), \
"enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \ "enforce %s " #__CMP " %s failed, %s " #__INV_CMP " %s\n%s", \
#__VAL0, #__VAL1, std::to_string(__VAL0), \ #__VAL0, #__VAL1, \
std::to_string(__VAL1), \ paddle::platform::enforce_to_string(__VAL0), \
paddle::platform::enforce_to_string(__VAL1), \
paddle::string::Sprintf("" __VA_ARGS__)); paddle::string::Sprintf("" __VA_ARGS__));
#define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \ #define __COMPATIBLE_TYPE(__VAL0, __VAL1, __VAL) \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册