提交 15400748 编写于 作者: Q qijun

follow comments and refine codes

上级 932402c1
...@@ -378,7 +378,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -378,7 +378,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
backward_descs[dup_op[i]]->Rename(out_name, new_name); backward_descs[dup_op[i]]->Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name); sum_op_inputs.emplace_back(new_name);
} }
LOG(INFO) << "fuck " << sum_op_inputs.size(); LOG(INFO) << "sum_op_inputs size " << sum_op_inputs.size();
std::unique_ptr<OpDescBind> sum_op(new OpDescBind( std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
"sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {})); "sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)}); pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
......
...@@ -60,15 +60,13 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, ...@@ -60,15 +60,13 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
op->SetAttrMap(attrs); op->SetAttrMap(attrs);
} }
std::once_flag set_variable_flag;
// Tensors in feed value variable will only be in CPUPlace // Tensors in feed value variable will only be in CPUPlace
// So we can memcpy the data from vector<T> to feed_value // So we can memcpy the data from vector<T> to feed_value
template <typename T> template <typename T>
void SetFeedVariable(const std::vector<std::vector<T>>& inputs) { void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
typedef std::vector<paddle::framework::Tensor> FeedInputs;
Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value"); Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>()); auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
size_t size = inputs.size(); size_t size = inputs.size();
feed_inputs.resize(size); feed_inputs.resize(size);
for (size_t i = 0; i < size; i++) { for (size_t i = 0; i < size; i++) {
...@@ -82,9 +80,9 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs) { ...@@ -82,9 +80,9 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
// So we can memcpy the data from fetch_value to vector<T> // So we can memcpy the data from fetch_value to vector<T>
template <typename T> template <typename T>
std::vector<std::vector<T>> GetFetchVariable() { std::vector<std::vector<T>> GetFetchVariable() {
typedef std::vector<paddle::framework::Tensor> FetchOutputs;
Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value"); Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>()); auto& fetch_outputs =
*(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
size_t size = fetch_outputs.size(); size_t size = fetch_outputs.size();
std::vector<std::vector<T>> result; std::vector<std::vector<T>> result;
...@@ -143,22 +141,22 @@ class ExecutorTesterRandom : public ::testing::Test { ...@@ -143,22 +141,22 @@ class ExecutorTesterRandom : public ::testing::Test {
{{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {}, {{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {},
root_block); root_block);
AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}}, // AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
{{"dims", std::vector<int>{batch_size, 1}}}, root_block); // {{"dims", std::vector<int>{batch_size, 1}}}, root_block);
AppendBackward(program, {}); // AppendBackward(program, {});
program.Proto(); // program.Proto();
for (auto& op : pdesc_.blocks(0).ops()) { // for (auto& op : pdesc_.blocks(0).ops()) {
if (op.type() == "sum") { // if (op.type() == "sum") {
LOG(INFO) << "Here"; // LOG(INFO) << "Here";
for (auto& var : op.inputs()) { // for (auto& var : op.inputs()) {
for (auto& argu : var.arguments()) { // for (auto& argu : var.arguments()) {
LOG(INFO) << var.parameter() << " " << argu; // LOG(INFO) << var.parameter() << " " << argu;
} // }
} // }
} // }
} // }
AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, AddOp("fetch", {{"Input", {"l2_distance"}}}, {},
{{"dims", std::vector<int>{batch_size}}, {"col", 1}}, root_block); {{"dims", std::vector<int>{batch_size}}, {"col", 1}}, root_block);
......
...@@ -23,15 +23,15 @@ class FeedOp : public framework::OperatorWithKernel { ...@@ -23,15 +23,15 @@ class FeedOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
int col = ctx->Attrs().Get<int>("col"); int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_feed_variable = framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value"); framework::GetGlobalScope()->FindVar("feed_value");
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>(); const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
PADDLE_ENFORCE_GT(tensors.size(), col); PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
auto in_dim = tensors[col].dims(); auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Out", in_dim); ctx->SetOutputDim("Out", in_dim);
// TODO(qijun): need to handle LodTensor later // TODO(qijun): need to handle LodTensor later
......
...@@ -23,13 +23,13 @@ template <typename T> ...@@ -23,13 +23,13 @@ template <typename T>
class FeedKernel : public framework::OpKernel<T> { class FeedKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
framework::Tensor* out = ctx.Output<framework::Tensor>("Out"); framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable = framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value"); framework::GetGlobalScope()->FindVar("feed_value");
int col = ctx.template Attr<int>("col"); int col = ctx.template Attr<int>("col");
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>(); const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
out->CopyFrom<T>(tensors[col], ctx.GetPlace()); out->CopyFrom<T>(tensors[col], ctx.GetPlace());
} }
}; };
......
...@@ -23,13 +23,13 @@ class FetchOp : public framework::OperatorWithKernel { ...@@ -23,13 +23,13 @@ class FetchOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
typedef std::vector<framework::Tensor> FetchOutputs;
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null."); PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
int col = ctx->Attrs().Get<int>("col"); int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_fetch_variable = framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value"); framework::GetGlobalScope()->FindVar("fetch_value");
FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>(); auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
if (tensors->size() < static_cast<size_t>(col + 1)) { if (tensors->size() < static_cast<size_t>(col + 1)) {
tensors->resize(col + 1); tensors->resize(col + 1);
} }
......
...@@ -23,12 +23,12 @@ template <typename T> ...@@ -23,12 +23,12 @@ template <typename T>
class FetchKernel : public framework::OpKernel<T> { class FetchKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FetchOutputs;
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input"); const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
int col = ctx.template Attr<int>("col"); int col = ctx.template Attr<int>("col");
framework::Variable* g_fetch_variable = framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value"); framework::GetGlobalScope()->FindVar("fetch_value");
FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>(); auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
(*tensors)[col].mutable_data<T>(platform::CPUPlace()); (*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace()); (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册