提交 45c4dcaa 编写于 作者: Q qijun

add fetch operator

上级 20725f2d
...@@ -75,15 +75,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) { ...@@ -75,15 +75,15 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope) {
device_context->Wait(); device_context->Wait();
} }
// // print tensor value // // print tensor value
for (auto& var : block.vars()) { // for (auto& var : block.vars()) {
std::cout << var.name() << std::endl; // std::cout << var.name() << std::endl;
auto v = scope->FindVar(var.name()); // auto v = scope->FindVar(var.name());
const LoDTensor& t = v->Get<LoDTensor>(); // const LoDTensor& t = v->Get<LoDTensor>();
for (int i = 0; i < t.numel(); ++i) { // for (int i = 0; i < t.numel(); ++i) {
std::cout << t.data<float>()[i] << " "; // std::cout << t.data<float>()[i] << " ";
} // }
std::cout << std::endl; // std::cout << std::endl;
} // }
} }
} // namespace framework } // namespace framework
......
...@@ -25,6 +25,7 @@ limitations under the License. */ ...@@ -25,6 +25,7 @@ limitations under the License. */
USE_OP(elementwise_add); USE_OP(elementwise_add);
USE_OP(gaussian_random); USE_OP(gaussian_random);
USE_OP(feed); USE_OP(feed);
USE_OP(fetch);
using std::string; using std::string;
using namespace paddle::platform; using namespace paddle::platform;
...@@ -94,6 +95,41 @@ void add_feed_op(string var_name, int index, proto_block* block) { ...@@ -94,6 +95,41 @@ void add_feed_op(string var_name, int index, proto_block* block) {
Out->add_arguments(var_name); Out->add_arguments(var_name);
} }
void add_fetch_op(string var_name, int index, proto_block* block) {
  // Test helper: appends a "fetch" op (plus the variable it reads) to the
  // given block so the test can read results back from the global
  // "fetch_value" variable after Executor::Run.
  std::vector<int> dim{3};

  // insert the variable the fetch op will read from
  auto a = block->add_vars();
  a->set_name(var_name);
  auto a_lt = a->mutable_lod_tensor();
  a_lt->set_data_type(paddle::framework::DataType::FP32);
  for (int i : dim) {
    a_lt->add_dims(i);
  }

  // insert the fetch operation itself
  auto op = block->add_ops();
  op->set_type("fetch");

  // "dims" attribute: shape of the fetched tensor
  auto dims = op->add_attrs();
  dims->set_name("dims");
  dims->set_type(paddle::framework::AttrType::INTS);
  for (int i : dim) {
    dims->add_ints(i);
  }

  // "col" attribute: slot index in the global fetch_value vector
  auto col = op->add_attrs();
  col->set_name("col");
  col->set_type(paddle::framework::AttrType::INT);
  col->set_i(index);

  // NOTE: this is an *input* argument of the op; the original local was
  // misleadingly named "Out" (copy-paste from add_feed_op).
  auto input = op->add_inputs();
  input->set_parameter("Input");
  input->add_arguments(var_name);
}
std::once_flag set_variable_flag; std::once_flag set_variable_flag;
template <typename T> template <typename T>
...@@ -119,6 +155,27 @@ void set_feed_variable(const std::vector<std::vector<T>>& inputs) { ...@@ -119,6 +155,27 @@ void set_feed_variable(const std::vector<std::vector<T>>& inputs) {
} }
} }
template <typename T>
std::vector<std::vector<T>> get_fetch_variable() {
  // Reads back every tensor the fetch ops stored in the global
  // "fetch_value" variable (a std::vector<Tensor>) and copies each one
  // into a plain std::vector<T> for easy assertions in tests.
  typedef std::vector<paddle::framework::Tensor> FetchOutputs;
  Variable* g_fetch_value = GetScope()->FindVar("fetch_value");
  FetchOutputs& fetch_outputs = *(g_fetch_value->GetMutable<FetchOutputs>());

  auto size = fetch_outputs.size();
  std::vector<std::vector<T>> result;
  result.reserve(size);
  for (size_t i = 0; i < size; i++) {
    // Construct each vector directly from the tensor's data range.
    // The original code did reserve() + memcpy(tmp.data(), ...), which
    // writes past size()==0 (undefined behavior) and then pushed an
    // *empty* vector — reserve() does not change size().
    const T* src = fetch_outputs[i].data<T>();
    result.emplace_back(src, src + fetch_outputs[i].numel());
  }
  return result;
}
class ExecutorTesterRandom : public ::testing::Test { class ExecutorTesterRandom : public ::testing::Test {
public: public:
virtual void SetUp() override { virtual void SetUp() override {
...@@ -181,6 +238,8 @@ class ExecutorTesterFeed : public ::testing::Test { ...@@ -181,6 +238,8 @@ class ExecutorTesterFeed : public ::testing::Test {
Out->set_parameter("Out"); Out->set_parameter("Out");
Out->add_arguments("c"); Out->add_arguments("c");
add_fetch_op("c", 0, root_block);
std::vector<float> vec1 = {1.0, 2.0, 3.0}; std::vector<float> vec1 = {1.0, 2.0, 3.0};
std::vector<float> vec2 = {4.0, 5.0, 6.0}; std::vector<float> vec2 = {4.0, 5.0, 6.0};
inputs_.push_back(vec1); inputs_.push_back(vec1);
...@@ -213,8 +272,16 @@ TEST_F(ExecutorTesterFeed, CPU) { ...@@ -213,8 +272,16 @@ TEST_F(ExecutorTesterFeed, CPU) {
// 3 mini-batch // 3 mini-batch
for (int i = 0; i < 3; i++) { for (int i = 0; i < 3; i++) {
// need to set feed variable before Executor::Run // need to set feed variable before Executor::Run
std::cout << "start mini-batch " << i << std::endl;
set_feed_variable<float>(inputs_); set_feed_variable<float>(inputs_);
executor->Run(pdesc_, GetScope()); executor->Run(pdesc_, GetScope());
std::vector<std::vector<float>> result = get_fetch_variable<float>();
for (auto& vec : result) {
for (auto& num : vec) {
std::cout << num << " ";
}
std::cout << std::endl;
}
} }
delete executor; delete executor;
......
...@@ -74,7 +74,10 @@ std::unique_ptr<T> make_unique(Args&&... args) { ...@@ -74,7 +74,10 @@ std::unique_ptr<T> make_unique(Args&&... args) {
framework::Scope* GetScope() { framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope = static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>(); make_unique<framework::Scope>();
std::call_once(feed_variable_flag, [&]() { g_scope->NewVar("feed_value"); }); std::call_once(feed_variable_flag, [&]() {
g_scope->NewVar("feed_value");
g_scope->NewVar("fetch_value");
});
return g_scope.get(); return g_scope.get();
} }
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h" #include "paddle/operators/activation_op.h"
......
...@@ -49,9 +49,9 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -49,9 +49,9 @@ class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("data_type", "output data type") AddAttr<int>("data_type", "output data type")
.SetDefault(framework::DataType::FP32); .SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global feed variable").SetDefault(0); AddAttr<int>("col", "The col in global feed variable").SetDefault(0);
AddAttr<std::vector<int>>("dims", "The dimension of random tensor."); AddAttr<std::vector<int>>("dims", "The dimension of feed tensor.");
AddOutput("Out", "The output of dropout op."); AddOutput("Out", "The output of feed op.");
AddComment(R"DOC(Feed data to global feed variable)DOC"); AddComment(R"DOC(Feed data from global feed variable)DOC");
} }
}; };
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fetch_op.h"
namespace paddle {
namespace operators {
class FetchOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  // Validates the "Input" argument and guarantees that slot `col` of the
  // global "fetch_value" vector exists and has the input's shape, so the
  // kernel can write into it unconditionally.
  void InferShape(framework::InferShapeContextBase* ctx) const override {
    typedef std::vector<framework::Tensor> FetchOutputs;
    PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
    int col = ctx->Attrs().Get<int>("col");
    PADDLE_ENFORCE(col >= 0, "The col attribute must be non-negative.");
    framework::Variable* g_fetch_variable =
        framework::GetScope()->FindVar("fetch_value");

    FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
    // Off-by-one fix: the original `size() < col ? resize(col)` never
    // allocated slot `col` itself (e.g. col == 0 on an empty vector), so
    // the indexing below was out of bounds.
    if (tensors->size() < static_cast<size_t>(col) + 1) {
      tensors->resize(col + 1);
    }

    auto input_dim = ctx->GetInputDim("Input");
    (*tensors)[col].Resize(input_dim);

    // TODO: need to handle LoDTensor later
  }

  framework::DataType IndicateDataType(
      const framework::ExecutionContext& ctx) const override {
    return static_cast<framework::DataType>(Attr<int>("data_type"));
  }
};
class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the fetch op's proto: attributes (data_type, col, dims) and
  // its single input, mirroring FeedOpMaker.
  FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddAttr<int>("data_type", "output data type")
        .SetDefault(framework::DataType::FP32);
    AddAttr<int>("col", "The col in global fetch variable").SetDefault(0);
    AddAttr<std::vector<int>>("dims", "The dimension of fetch tensor.");
    // Fixed description: "Input" is an *input*; the original text said
    // "The output of fetch op." (copy-paste from FeedOpMaker).
    AddInput("Input", "The input of fetch op.");
    AddComment(R"DOC(Fetch data to global fetch variable)DOC");
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker);
REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// GPU registration for the fetch op. FetchKernel always copies to CPU
// memory (see fetch_op.h), so the same kernel type is registered here.
// Fixed include: FetchKernel is declared in fetch_op.h, not feed_op.h.
#include "paddle/operators/fetch_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class FetchKernel : public framework::OpKernel<T> {
 public:
  // Copies the op's "Input" tensor into slot `col` of the global
  // "fetch_value" variable (a std::vector<Tensor>), always into CPU
  // memory so the host can read results back after Executor::Run.
  void Compute(const framework::ExecutionContext& ctx) const override {
    typedef std::vector<framework::Tensor> FetchOutputs;
    // "Input" is an input argument; the original wrongly fetched it via
    // ctx.Output<Tensor>("Input").
    const Tensor* input = ctx.Input<Tensor>("Input");
    int col = ctx.template Attr<int>("col");
    framework::Variable* g_fetch_variable =
        framework::GetScope()->FindVar("fetch_value");
    // Mutate the global vector in place. The original copied it by value
    // (Get<> into a local FetchOutputs), so every fetched tensor was
    // written into a temporary and discarded.
    FetchOutputs* tensors = g_fetch_variable->GetMutable<FetchOutputs>();
    (*tensors)[col].mutable_data<T>(platform::CPUPlace());
    (*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace());
  }
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册