提交 4df6cf4d 编写于 作者: Y Yu Yang 提交者: GitHub

Rewrite feed/fetch op (#4815)

* Feed/Fetch op just plain operator, not a OpWithKernel
* Do not register OpInfoMaker since Feed/Fetch will never be
  configured by users
* Feed/Fetch op has empty gradient
* Feed/Fetch op do not hard code `feed_variable`, `fetch_variable` as
  its input and output, make it as a plain Operator input/output
上级 440ad999
......@@ -28,8 +28,8 @@ limitations under the License. */
USE_OP(elementwise_add);
USE_OP(gaussian_random);
USE_OP(feed);
USE_OP(fetch);
USE_NO_KERNEL_OP(feed);
USE_NO_KERNEL_OP(fetch);
USE_OP(mul);
USE_OP(sum);
USE_OP(squared_l2_distance);
......@@ -37,6 +37,9 @@ USE_OP(fill_constant);
USE_OP(mean);
USE_OP(sgd);
constexpr auto kFeedValueName = "feed_value";
constexpr auto kFetchValueName = "fetch_value";
using namespace paddle::platform;
using namespace paddle::framework;
......@@ -77,9 +80,9 @@ void AddOp(const std::string& type, const VariableNameMap& inputs,
template <typename T>
void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
const std::vector<std::vector<int64_t>>& dims) {
Variable* g_feed_value = GetGlobalScope().FindVar("feed_value");
Variable* g_feed_value = GetGlobalScope().FindVar(kFeedValueName);
auto& feed_inputs =
*(g_feed_value->GetMutable<std::vector<paddle::framework::Tensor>>());
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
size_t size = inputs.size();
feed_inputs.resize(size);
for (size_t i = 0; i < size; i++) {
......@@ -92,9 +95,9 @@ void SetFeedVariable(const std::vector<std::vector<T>>& inputs,
// So we can memcpy the data from fetch_value to vector<T>
template <typename T>
std::vector<std::vector<T>> GetFetchVariable() {
Variable* g_fetch_value = GetGlobalScope().FindVar("fetch_value");
Variable* g_fetch_value = GetGlobalScope().FindVar(kFetchValueName);
auto& fetch_outputs =
*(g_fetch_value->GetMutable<std::vector<paddle::framework::Tensor>>());
*(g_fetch_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
size_t size = fetch_outputs.size();
std::vector<std::vector<T>> result;
......@@ -126,8 +129,10 @@ class ExecutorTesterRandom : public ::testing::Test {
{{"dims", std::vector<int>{input_dim, embed_dim}}}, init_root_block);
AddOp("gaussian_random", {}, {{"Out", {"w2"}}},
{{"dims", std::vector<int>{embed_dim, input_dim}}}, init_root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, init_root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, init_root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {{"Out", {kFetchValueName}}},
{{"col", 0}}, init_root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {{"Out", {kFetchValueName}}},
{{"col", 1}}, init_root_block);
// flush
init_program.Proto();
......@@ -143,7 +148,7 @@ class ExecutorTesterRandom : public ::testing::Test {
// feed data
inputs_.push_back({1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
dims_.push_back({batch_size, input_dim});
AddOp("feed", {}, {{"Out", {"a"}}},
AddOp("feed", {{"Input", {kFeedValueName}}}, {{"Out", {"a"}}},
{{"dims", std::vector<int>{batch_size, input_dim}}, {"col", 0}},
root_block);
......@@ -175,9 +180,12 @@ class ExecutorTesterRandom : public ::testing::Test {
{"Grad", {"w2@GRAD"}}},
{{"ParamOut", {"w2"}}}, {}, root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {}, {{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {}, {{"col", 1}}, root_block);
AddOp("fetch", {{"Input", {"l2_distance"}}}, {}, {{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"w1"}}}, {{"Out", {kFetchValueName}}},
{{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"w2"}}}, {{"Out", {kFetchValueName}}},
{{"col", 1}}, root_block);
AddOp("fetch", {{"Input", {"l2_distance"}}}, {{"Out", {kFetchValueName}}},
{{"col", 0}}, root_block);
// flush
program.Proto();
......@@ -204,12 +212,14 @@ class ExecutorTesterFeedAndFetch : public ::testing::Test {
std::vector<int> dim{6};
AddOp("feed", {}, {{"Out", {"a"}}}, {{"dims", dim}, {"col", 0}},
root_block);
AddOp("feed", {}, {{"Out", {"b"}}}, {{"dims", dim}, {"col", 1}},
root_block);
AddOp("fetch", {{"Input", {"a"}}}, {}, {{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"b"}}}, {}, {{"col", 1}}, root_block);
AddOp("feed", {{"Input", {kFeedValueName}}}, {{"Out", {"a"}}},
{{"dims", dim}, {"col", 0}}, root_block);
AddOp("feed", {{"Input", {kFeedValueName}}}, {{"Out", {"b"}}},
{{"dims", dim}, {"col", 1}}, root_block);
AddOp("fetch", {{"Input", {"a"}}}, {{"Out", {kFetchValueName}}},
{{"col", 0}}, root_block);
AddOp("fetch", {{"Input", {"b"}}}, {{"Out", {kFetchValueName}}},
{{"col", 1}}, root_block);
// flush
program.Proto();
......@@ -241,7 +251,6 @@ TEST_F(ExecutorTesterRandom, CPU) {
paddle::memory::Used(cpu_place);
std::unique_ptr<Executor> executor(new Executor(places));
executor->Run(init_pdesc_, &GetGlobalScope(), 0);
SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, &GetGlobalScope(), 0);
......@@ -251,7 +260,7 @@ TEST_F(ExecutorTesterRandom, CPU) {
TEST_F(ExecutorTesterFeedAndFetch, CPU) {
std::vector<Place> places;
CPUPlace cpu_place;
places.push_back(cpu_place);
places.emplace_back(cpu_place);
// We have a global Scope and BuddyAllocator, and we must ensure
// global BuddyAllocator is initialized before global Scope. Thus,
......@@ -265,11 +274,11 @@ TEST_F(ExecutorTesterFeedAndFetch, CPU) {
SetFeedVariable<float>(inputs_, dims_);
executor->Run(pdesc_, &GetGlobalScope(), 0);
std::vector<std::vector<float>> result = GetFetchVariable<float>();
PADDLE_ENFORCE_EQ(result.size(), inputs_.size());
ASSERT_EQ(result.size(), inputs_.size());
for (size_t i = 0; i < result.size(); ++i) {
PADDLE_ENFORCE_EQ(result[i].size(), inputs_[i].size());
ASSERT_EQ(result[i].size(), inputs_[i].size());
for (size_t j = 0; j < result[i].size(); ++j) {
PADDLE_ENFORCE_EQ(result[i][j], inputs_[i][j]);
ASSERT_EQ(result[i][j], inputs_[i][j]);
}
}
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "paddle/framework/lod_tensor.h"
namespace paddle {
namespace framework {
using FeedFetchType = LoDTensor;
using FeedFetchList = std::vector<FeedFetchType>;
} // namespace framework
} // namespace paddle
......@@ -149,5 +149,13 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
}
};
class EmptyGradOpMaker : public GradOpDescMakerBase {
public:
using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
return {};
}
};
} // namespace framework
} // namespace paddle
......@@ -220,9 +220,12 @@ static InferShapeFuncMap &InferShapeFuncs() {
void OpDescBind::CheckAttrs() {
PADDLE_ENFORCE(!Type().empty(),
"CheckAttr() can not be called before type is setted.");
const auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
PADDLE_ENFORCE_NOT_NULL(checker, "Operator \"%s\" has no registered checker.",
Type());
auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
if (checker == nullptr) {
// checker is not configured. That operator could be generated by Paddle,
// not by users.
return;
}
checker->Check(attrs_);
}
......
......@@ -100,6 +100,22 @@ class Tensor {
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place,
const platform::DeviceContext& ctx);
// FIXME(yuyang18): CopyFrom should without template T, use the replace
// `CopyFrom` with `CopyFromTensor`
inline void CopyFromTensor(const Tensor& src,
const platform::Place& dst_place,
const platform::DeviceContext& ctx) {
// NOLINTNEXTLINES_8 cpplint.py will recognize below lines as functions.
// That is a bug of cpplint.py. Just ignore lint these lines.
if (src.type() == std::type_index(typeid(double))) {
CopyFrom<double>(src, dst_place, ctx);
} else if (src.type() == std::type_index(typeid(float))) {
CopyFrom<float>(src, dst_place, ctx);
} else if (src.type() == std::type_index(typeid(int))) {
CopyFrom<int>(src, dst_place, ctx);
}
}
/**
* @brief Copy the content of an external vector to a tensor.
*
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
class FeedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
auto& shape = ctx->Attrs().Get<std::vector<int>>("dims");
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
[](int a) { return static_cast<int64_t>(a); });
ctx->SetOutputDim("Out", framework::make_ddim(shape_int64));
// TODO(qijun): need to handle LodTensor later
}
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("dataType"));
}
};
class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
class FeedOp : public framework::OperatorBase {
public:
FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dataType", "output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global feed variable").SetDefault(0);
AddAttr<std::vector<int>>("dims", "The dimension of feed tensor.");
AddOutput("Out", "The output of feed op.");
AddComment(R"DOC(Feed data from global feed variable)DOC");
FeedOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto feed_var_name = Input("Input");
auto *feed_var = scope.FindVar(feed_var_name);
PADDLE_ENFORCE(feed_var != nullptr,
"Cannot find feed_var in scope, feed_var_name is %s",
feed_var_name);
auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE(out_var != nullptr,
"Cannot find out_var in scope, out_var_name is %s",
out_name);
auto col = Attr<int>("col");
auto &feed_list = feed_var->Get<framework::FeedFetchList>();
auto &feed_item = feed_list.at(static_cast<size_t>(col));
auto *out_item = out_var->GetMutable<framework::FeedFetchType>();
out_item->CopyFromTensor(feed_item, dev_ctx.GetPlace(), dev_ctx);
out_item->set_lod(feed_item.lod());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker);
REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel<float>);
// We do not need to register OpInfoMaker,
// since feed operator will not be used by end users directly
REGISTER_OPERATOR(feed, paddle::operators::FeedOp,
paddle::framework::EmptyGradOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class FeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetGlobalScope().FindVar("feed_value");
const auto& tensors =
g_feed_variable->Get<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
PADDLE_ENFORCE_GT(tensors.size(), static_cast<size_t>(col));
// TODO(qijun):
// check tensors[col].dims() with attribute,
// except the first dimenson.
out->CopyFrom<T>(tensors[col], ctx.GetPlace(), ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fetch_op.h"
#include "paddle/framework/feed_fetch_type.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
class FetchOp : public framework::OperatorWithKernel {
class FetchOp : public framework::OperatorBase {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
FetchOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should be not null.");
}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto fetch_var_name = Input("Input");
auto *fetch_var = scope.FindVar(fetch_var_name);
PADDLE_ENFORCE(fetch_var != nullptr,
"Cannot find fetch variable in scope, fetch_var_name is %s",
fetch_var_name);
auto out_name = this->Output("Out");
auto *out_var = scope.FindVar(out_name);
PADDLE_ENFORCE(out_var != nullptr,
"Cannot find out_var in scope, out_var_name is %s",
out_name);
auto col = static_cast<size_t>(Attr<int>("col"));
framework::DataType IndicateDataType(
const framework::ExecutionContext& ctx) const override {
return static_cast<framework::DataType>(Attr<int>("dataType"));
auto *fetch_list = out_var->GetMutable<framework::FeedFetchList>();
auto &src_item = fetch_var->Get<framework::FeedFetchType>();
if (col >= fetch_list->size()) {
fetch_list->resize(col + 1);
}
};
auto &dst_item = fetch_list->at(col);
class FetchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FetchOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("dataType", "output data type")
.SetDefault(framework::DataType::FP32);
AddAttr<int>("col", "The col in global fetch variable").SetDefault(0);
AddInput("Input", "The output of fetch op.");
AddComment(R"DOC(Fetch data to global fetch variable)DOC");
// FIXME(yuyang18): Should we assume the fetch operator always generate
// CPU outputs?
dst_item.CopyFromTensor(src_item, platform::CPUPlace(), dev_ctx);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(fetch, ops::FetchOp, ops::FetchOpMaker);
REGISTER_OP_CPU_KERNEL(fetch, ops::FetchKernel<float>);
// We do not need to register OpInfoMaker,
// since fetch operator will not be used by end users directly
REGISTER_OPERATOR(fetch, paddle::operators::FetchOp,
paddle::framework::EmptyGradOpMaker);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/fetch_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(fetch, ops::FetchKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
template <typename T>
class FetchKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor* input = ctx.Input<framework::Tensor>("Input");
framework::Variable* g_fetch_variable =
framework::GetGlobalScope().FindVar("fetch_value");
auto* tensors =
g_fetch_variable->GetMutable<std::vector<framework::Tensor>>();
int col = ctx.template Attr<int>("col");
if (tensors->size() < static_cast<size_t>(col + 1)) {
tensors->resize(col + 1);
}
PADDLE_ENFORCE_GT(tensors->size(), static_cast<size_t>(col));
(*tensors)[col].Resize(input->dims());
(*tensors)[col].mutable_data<T>(platform::CPUPlace());
(*tensors)[col].CopyFrom<T>(*input, platform::CPUPlace(),
ctx.device_context());
// TODO(qijun): need to handle LodTensor later
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册