提交 15400748 编写于 作者: Q qijun

follow comments and refine codes

上级 932402c1
......@@ -378,7 +378,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
backward_descs[dup_op[i]]->Rename(out_name, new_name);
sum_op_inputs.emplace_back(new_name);
}
LOG(INFO) << "fuck " << sum_op_inputs.size();
LOG(INFO) << "sum_op_inputs size " << sum_op_inputs.size();
std::unique_ptr<OpDescBind> sum_op(new OpDescBind(
"sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {}));
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
......
......@@ -60,15 +60,13 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
op->SetAttrMap(attrs);
}
std::once_flag set_variable_flag;
// Tensors in feed value variable will only be in CPUPlace
// So we can memcpy the data from vector<T> to feed_value
// So we can memcpy the data from vector<T> to feed_value
template <typename T>
void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
typedef std::vector<paddle::framework::Tensor> FeedInputs;
Variable* g_feed_value = GetGlobalScope()->FindVar("feed_value");
FeedInputs& feed_inputs = *(g_feed_value->GetMutable<FeedInputs>());
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
size_t size = inputs.size();
feed_inputs.resize(size);
for (size_t i = 0; i < size; i++) {
......@@ -82,9 +80,9 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs) {
// So we can memcpy the data from fetch_value to vector<T>
template <typename T>
std::vector<std::vector<T>> GetFetchVariable() {
typedef std::vector<paddle::framework::Tensor> FetchOutputs;
Variable* g_fetch_value = GetGlobalScope()->FindVar("fetch_value");
FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());
auto& fetch_outputs =
*(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
size_t size = fetch_outputs.size();
std::vector<std::vector<T>> result;
......@@ -143,22 +141,22 @@ class ExecutorTesterRandom : public ::testing::Test {
{{"Out", {"l2_distance"}}, {"sub_result", {"l2_distance_sub"}}}, {},
root_block);
AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
{{"dims", std::vector<int>{batch_size, 1}}}, root_block);
AppendBackward(program, {});
program.Proto();
for (auto& op : pdesc_.blocks(0).ops()) {
if (op.type() == "sum") {
LOG(INFO) << "Here";
for (auto& var : op.inputs()) {
for (auto& argu : var.arguments()) {
LOG(INFO) << var.parameter() << " " << argu;
}
}
}
}
// AddOp("gaussian_random", {}, {{"Out", {"l2_distance@GRAD"}}},
// {{"dims", std::vector<int>{batch_size, 1}}}, root_block);
// AppendBackward(program, {});
// program.Proto();
// for (auto& op : pdesc_.blocks(0).ops()) {
// if (op.type() == "sum") {
// LOG(INFO) << "Here";
// for (auto& var : op.inputs()) {
// for (auto& argu : var.arguments()) {
// LOG(INFO) << var.parameter() << " " << argu;
// }
// }
// }
// }
AddOp("fetch", {{"Input", {"l2_distance"}}}, {},
{{"dims", std::vector<int>{batch_size}}, {"col", 1}}, root_block);
......
......@@ -23,15 +23,15 @@ class FeedOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value");
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
PADDLE_ENFORCE_GT(tensors.size(), col);
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Out", in_dim);
// TODO(qijun): need to handle LodTensor later
......
......@@ -23,13 +23,13 @@ template <typename T>
class FeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetGlobalScope()->FindVar("feed_value");
int col = ctx.template Attr<int>("col");
const FeedInputs& tensors = g_feed_variable->Get<FeedInputs>();
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
}
};
......
......@@ -23,13 +23,13 @@ class FetchOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
typedef std::vector<framework::Tensor> FetchOutputs;
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value");
FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
if (tensors->size() < static_cast<size_t>(col + 1)) {
tensors->resize(col + 1);
}
......
......@@ -23,12 +23,12 @@ template <typename T>
class FetchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FetchOutputs;
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
int col = ctx.template Attr<int>("col");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope()->FindVar("fetch_value");
FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
(*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册