提交 623848af 编写于 作者: Q qijun

add feed operator

上级 3014f6a1
...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/framework/scope.h" #include "paddle/framework/scope.h"
#include <memory> // for unique_ptr
#include <mutex> // for call_once
#include "paddle/string/printf.h" #include "paddle/string/printf.h"
namespace paddle { namespace paddle {
...@@ -62,5 +64,19 @@ void Scope::DropKids() { ...@@ -62,5 +64,19 @@ void Scope::DropKids() {
kids_.clear(); kids_.clear();
} }
std::once_flag feed_variable_flag;
template <typename T, typename... Args>
std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
framework::Scope* GetScope() {
static std::unique_ptr<framework::Scope> g_scope =
make_unique<framework::Scope>();
std::call_once(feed_variable_flag, [&]() { g_scope->NewVar("feed_value"); });
return g_scope.get();
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -73,5 +73,7 @@ class Scope { ...@@ -73,5 +73,7 @@ class Scope {
DISABLE_COPY_AND_ASSIGN(Scope); DISABLE_COPY_AND_ASSIGN(Scope);
}; };
framework::Scope* GetScope();
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
You may obtain a copy of the License at You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0 http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/activation_op.h" #include "paddle/operators/activation_op.h"
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
namespace paddle {
namespace operators {
class FeedOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output should be not null.");
int col = ctx->Attrs().Get<int>("col");
framework::Variable* g_feed_variable =
framework::GetScope()->FindVar("feed_value");
FeedInputs tensors = g_feed_variable->Get<FeedInputs>();
auto in_dim = tensors[col].dims();
ctx->SetOutputDim("Y", in_dim);
// need to handle LodTensor later
}
};
class FeedOpMaker : public framework::OpProtoAndCheckerMaker {
public:
FeedOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddAttr<int>("col", "The col in Global Feed Variable");
AddOutput("Out", "The output of dropout op.");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(feed, ops::FeedOp, ops::FeedOpMaker);
REGISTER_OP_CPU_KERNEL(feed, ops::FeedKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/feed_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(feed, ops::FeedKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class FeedKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
typedef std::vector<framework::Tensor> FeedInputs;
Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
framework::Variable* g_feed_variable =
framework::GetScope()->FindVar("feed_value");
int col = ctx.template Attr<int>("col");
FeedInputs tensors = g_feed_variable->Get<FeedInputs>();
out->CopyFrom<T>(tensors[col], ctx.GetPlace());
}
};
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册