Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
1226cc01
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1226cc01
编写于
6月 16, 2021
作者:
W
wanghuancoder
提交者:
GitHub
6月 16, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
runtimecontext (#33608)
* runtimecontext * ExecutionContextV2 * refine * refine * pass test
上级
bc8a8042
变更
4
展开全部
显示空白变更内容
内联
并排
Showing
4 changed file
with
648 addition
and
567 deletion
+648
-567
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+1
-1
paddle/fluid/framework/new_exec.h
paddle/fluid/framework/new_exec.h
+586
-461
paddle/fluid/framework/new_exec_test.cc
paddle/fluid/framework/new_exec_test.cc
+60
-57
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+1
-48
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
1226cc01
...
@@ -409,7 +409,7 @@ cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framewo
...
@@ -409,7 +409,7 @@ cc_library(custom_operator SRCS custom_operator.cc DEPS tensor attribute framewo
cc_test
(
custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog
)
cc_test
(
custom_tensor_test SRCS custom_tensor_test.cc DEPS custom_tensor glog
)
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
#cc_binary(test_executor SRCS test_executor.cc DEPS executor op_registry ${GLOB_OP_LIB} ${GLOB_OPERATOR_DEPS} )
cc_binary
(
new_executor SRCS new_exec_test.cc DEPS operator op_registry executor
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
)
cc_binary
(
new_executor SRCS new_exec_test.cc DEPS operator op_registry executor
${
GLOB_OP_LIB
}
${
GLOB_OPERATOR_DEPS
}
profiler
place
)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator
)
set
(
FLUID_FRAMEWORK_MODULES proto_desc memory lod_tensor executor data_feed_proto layer dynamic_loader custom_operator
)
...
...
paddle/fluid/framework/new_exec.h
浏览文件 @
1226cc01
此差异已折叠。
点击以展开。
paddle/fluid/framework/new_exec_test.cc
浏览文件 @
1226cc01
#include <iostream>
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#include <string>
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gperftools/profiler.h>
#include <chrono>
#include <iostream>
#include <map>
#include <map>
#include <memory>
#include <memory>
#include <string>
#include <string>
...
@@ -9,69 +23,58 @@
...
@@ -9,69 +23,58 @@
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/new_exec.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pybind/pybind.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/framework/new_exec.h"
#include "paddle/fluid/pybind/pybind.h"
#include <chrono>
#include <gperftools/profiler.h>
int
main
()
int
main
()
{
{
paddle
::
framework
::
InitDevices
();
paddle
::
framework
::
InitDevices
();
paddle
::
framework
::
VariableScope
global_scope
;
paddle
::
framework
::
VariableScope
global_scope
;
auto
place
=
paddle
::
platform
::
CUDAPlace
(
0
);
{
{
auto
test_prog
=
paddle
::
framework
::
load_from_file
(
"lm_startup_program"
);
auto
test_prog
=
paddle
::
framework
::
load_from_file
(
"lm_startup_program"
);
paddle
::
framework
::
build_variable_scope
(
test_prog
,
&
global_scope
);
paddle
::
framework
::
build_variable_scope
(
test_prog
,
&
global_scope
);
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_func_list
;
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>
op_list
;
std
::
vector
<
paddle
::
framework
::
OperatorBase
*>
op_list
;
paddle
::
framework
::
build_op_func_list
(
test_prog
,
op_list
,
vec_func_list
,
global_scope
);
paddle
::
framework
::
build_op_func_list
(
test_prog
,
op_list
,
vec_func_list
,
&
global_scope
,
place
);
paddle
::
framework
::
exec_op_func_list
(
vec_func_list
,
op_list
,
global_scope
);
paddle
::
framework
::
exec_op_func_list
(
vec_func_list
,
op_list
,
global_scope
,
place
);
}
}
cerr
<<
"run main"
<<
endl
;
std
::
cerr
<<
"run main"
<<
std
::
endl
;
auto
main_prog
=
paddle
::
framework
::
load_from_file
(
"lm_main_program"
);
auto
main_prog
=
paddle
::
framework
::
load_from_file
(
"lm_main_program"
);
paddle
::
framework
::
build_variable_scope
(
main_prog
,
&
global_scope
);
paddle
::
framework
::
build_variable_scope
(
main_prog
,
&
global_scope
);
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_main_func_list
;
std
::
vector
<
paddle
::
framework
::
OpFuncNode
>
vec_main_func_list
;
std
::
vector
<
std
::
unique_ptr
<
paddle
::
framework
::
OperatorBase
>>
op_main_list
;
std
::
vector
<
paddle
::
framework
::
OperatorBase
*>
op_main_list
;
paddle
::
framework
::
build_op_func_list
(
main_prog
,
op_main_list
,
vec_main_func_list
,
global_scope
);
paddle
::
framework
::
build_op_func_list
(
main_prog
,
op_main_list
,
vec_main_func_list
,
&
global_scope
,
place
);
auto
start
=
std
::
chrono
::
steady_clock
::
now
();
auto
start
=
std
::
chrono
::
steady_clock
::
now
();
ProfilerStart
(
"new_executor.prof"
);
// ProfilerStart("new_executor.prof");
for
(
size_t
i
=
0
;
i
<
2320
;
++
i
)
for
(
size_t
i
=
0
;
i
<
2320
;
++
i
)
{
{
if
(
i
%
200
==
0
)
{
if
(
i
%
200
==
0
)
std
::
cerr
<<
i
<<
std
::
endl
;
{
cerr
<<
i
<<
endl
;
}
}
paddle
::
framework
::
exec_op_func_list
(
vec_main_func_list
,
op_main_list
,
global_scope
);
paddle
::
framework
::
exec_op_func_list
(
vec_main_func_list
,
op_main_list
,
global_scope
,
place
);
33
}
}
ProfilerStop
();
//
ProfilerStop();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
auto
end
=
std
::
chrono
::
steady_clock
::
now
();
std
::
chrono
::
duration
<
double
>
diff
=
end
-
start
;
std
::
chrono
::
duration
<
double
>
diff
=
end
-
start
;
cerr
<<
"time cost "
<<
diff
.
count
()
<<
endl
;
std
::
cerr
<<
"time cost "
<<
diff
.
count
()
<<
std
::
endl
;
return
1
;
return
1
;
}
}
paddle/fluid/framework/operator.h
浏览文件 @
1226cc01
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -586,7 +583,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -586,7 +583,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
public:
public:
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
: op_(op), ctx_(ctx) {}
: op_(op), ctx_(ctx) {}
bool HasInput(const std::string& name) const override {
bool HasInput(const std::string& name) const override {
// has only one input
// has only one input
const auto& ins = ctx_.inputs;
const auto& ins = ctx_.inputs;
...
@@ -602,7 +598,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -602,7 +598,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Input %s should not contain more than one inputs.", name));
"Input %s should not contain more than one inputs.", name));
return in[0] != nullptr;
return in[0] != nullptr;
}
}
bool HasOutput(const std::string& name) const override {
bool HasOutput(const std::string& name) const override {
// has only one output
// has only one output
const auto& outs = ctx_.outputs;
const auto& outs = ctx_.outputs;
...
@@ -620,7 +615,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -620,7 +615,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Output %s should not contain more than one outputs.", name));
"Output %s should not contain more than one outputs.", name));
return out[0] != nullptr;
return out[0] != nullptr;
}
}
bool HasInputs(const std::string& name) const override {
bool HasInputs(const std::string& name) const override {
const auto& ins = ctx_.inputs;
const auto& ins = ctx_.inputs;
auto it = ins.find(name);
auto it = ins.find(name);
...
@@ -634,7 +628,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -634,7 +628,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
return true;
return true;
}
}
bool HasOutputs(const std::string& name) const override {
bool HasOutputs(const std::string& name) const override {
const auto& outs = ctx_.outputs;
const auto& outs = ctx_.outputs;
auto it = outs.find(name);
auto it = outs.find(name);
...
@@ -648,17 +641,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -648,17 +641,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
}
}
return true;
return true;
}
}
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
std::vector<std::string> Inputs(const std::string& name) const override {
std::vector<std::string> Inputs(const std::string& name) const override {
return op_.Inputs(name);
return op_.Inputs(name);
}
}
std::vector<std::string> Outputs(const std::string& name) const override {
std::vector<std::string> Outputs(const std::string& name) const override {
return op_.Outputs(name);
return op_.Outputs(name);
}
}
std::string GetInputNameByIdx(size_t idx) const override {
std::string GetInputNameByIdx(size_t idx) const override {
auto& op_proto =
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
...
@@ -669,7 +658,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -669,7 +658,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
op_.Type(), idx, op_proto->inputs().size()));
op_.Type(), idx, op_proto->inputs().size()));
return op_proto->inputs()[idx].name();
return op_proto->inputs()[idx].name();
}
}
std::string GetOutputNameByIdx(size_t idx) const override {
std::string GetOutputNameByIdx(size_t idx) const override {
auto& op_proto =
auto& op_proto =
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
...
@@ -681,7 +669,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -681,7 +669,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
op_.Type(), idx, op_proto->outputs().size()));
op_.Type(), idx, op_proto->outputs().size()));
return op_proto->outputs()[idx].name();
return op_proto->outputs()[idx].name();
}
}
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) override {
size_t j = 0) override {
auto in_it = ctx_.inputs.find(in);
auto in_it = ctx_.inputs.find(in);
...
@@ -702,16 +689,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -702,16 +689,13 @@ class RuntimeInferShapeContext : public InferShapeContext {
"The index of output dimension is out of range, "
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
"excepted index less than %zu, but received %zu.",
out_it->second.size(), j));
out_it->second.size(), j));
Variable* in_var = in_it->second[i];
Variable* in_var = in_it->second[i];
Variable* out_var = out_it->second[j];
Variable* out_var = out_it->second[j];
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(
in_var->Type(), out_var->Type(),
in_var->Type(), out_var->Type(),
platform::errors::InvalidArgument(
platform::errors::InvalidArgument(
"The type of input (%s) and output (%s) are inconsistent.", in,
"The type of input (%s) and output (%s) are inconsistent.", in,
out));
out));
if (in_var->IsType<framework::SelectedRows>()) {
if (in_var->IsType<framework::SelectedRows>()) {
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
...
@@ -728,7 +712,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -728,7 +712,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"or SelectedRows."));
"or SelectedRows."));
}
}
}
}
void ShareAllLoD(const std::string& in,
void ShareAllLoD(const std::string& in,
const std::string& out) const override {
const std::string& out) const override {
auto in_it = ctx_.inputs.find(in);
auto in_it = ctx_.inputs.find(in);
...
@@ -740,23 +723,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -740,23 +723,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_it, ctx_.outputs.end(),
out_it, ctx_.outputs.end(),
platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
op_.Type()));
op_.Type()));
auto& in_var_list = in_it->second;
auto& in_var_list = in_it->second;
auto& out_var_list = out_it->second;
auto& out_var_list = out_it->second;
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(
in_var_list.size(), out_var_list.size(),
in_var_list.size(), out_var_list.size(),
platform::errors::PreconditionNotMet(
platform::errors::PreconditionNotMet(
"Op [%s]: Input var size should be equal with output var size",
"Op [%s]: Input var size should be equal with output var size",
op_.Type()));
op_.Type()));
auto& out_var_names = op_.Outputs(out);
auto& out_var_names = op_.Outputs(out);
for (size_t i = 0; i < in_var_list.size(); ++i) {
for (size_t i = 0; i < in_var_list.size(); ++i) {
if (out_var_names[i] == framework::kEmptyVarName) {
if (out_var_names[i] == framework::kEmptyVarName) {
continue;
continue;
}
}
Variable* in_var = in_var_list[i];
Variable* in_var = in_var_list[i];
if (!in_var->IsType<LoDTensor>()) return;
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_var_list[i];
Variable* out_var = out_var_list[i];
...
@@ -773,7 +751,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -773,7 +751,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
out_tensor->set_layout(in_tensor.layout());
out_tensor->set_layout(in_tensor.layout());
}
}
}
}
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
size_t j = 0) const override {
size_t j = 0) const override {
auto in_it = ctx_.inputs.find(in);
auto in_it = ctx_.inputs.find(in);
...
@@ -794,7 +771,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -794,7 +771,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"The index of output dimension is out of range, "
"The index of output dimension is out of range, "
"excepted index less than %zu, but received %zu.",
"excepted index less than %zu, but received %zu.",
out_it->second.size(), j));
out_it->second.size(), j));
Variable* in_var = in_it->second.at(i);
Variable* in_var = in_it->second.at(i);
if (!in_var->IsType<LoDTensor>()) return;
if (!in_var->IsType<LoDTensor>()) return;
Variable* out_var = out_it->second.at(j);
Variable* out_var = out_it->second.at(j);
...
@@ -805,7 +781,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -805,7 +781,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
auto& in_tensor = in_var->Get<LoDTensor>();
auto& in_tensor = in_var->Get<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
auto* out_tensor = out_var->GetMutable<LoDTensor>();
out_tensor->set_lod(in_tensor.lod());
out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
// Shall we have a better method to shared info between in/out Tensor?
...
@@ -826,14 +801,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -826,14 +801,12 @@ class RuntimeInferShapeContext : public InferShapeContext {
#endif
#endif
out_tensor->set_layout(in_tensor.layout());
out_tensor->set_layout(in_tensor.layout());
}
}
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetLoDLevel is only used in compile time. The calculation of "
"GetLoDLevel is only used in compile time. The calculation of "
"output's actual lod is different among operators so that should be "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
"set in the runtime kernel."));
}
}
void SetLoDLevel(const std::string& out, int32_t lod_level,
void SetLoDLevel(const std::string& out, int32_t lod_level,
size_t j = 0) const override {
size_t j = 0) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
PADDLE_THROW(platform::errors::PreconditionNotMet(
...
@@ -841,9 +814,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -841,9 +814,7 @@ class RuntimeInferShapeContext : public InferShapeContext {
"output's actual lod is different among operators so that should be "
"output's actual lod is different among operators so that should be "
"set in the runtime kernel."));
"set in the runtime kernel."));
}
}
bool IsRuntime() const override { return true; }
bool IsRuntime() const override { return true; }
// TODO(paddle-dev): Can this be template?
// TODO(paddle-dev): Can this be template?
std::vector<InferShapeVarPtr> GetInputVarPtrs(
std::vector<InferShapeVarPtr> GetInputVarPtrs(
const std::string& name) override {
const std::string& name) override {
...
@@ -853,7 +824,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -853,7 +824,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
res.insert(res.begin(), vars.begin(), vars.end());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
return res;
}
}
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
std::vector<InferShapeVarPtr> GetOutputVarPtrs(
const std::string& name) override {
const std::string& name) override {
const std::vector<Variable*>& vars = OutputVars(name);
const std::vector<Variable*>& vars = OutputVars(name);
...
@@ -862,7 +832,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -862,7 +832,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
res.insert(res.begin(), vars.begin(), vars.end());
res.insert(res.begin(), vars.begin(), vars.end());
return res;
return res;
}
}
DDim GetInputDim(const std::string& name) const override {
DDim GetInputDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
const std::vector<Variable*>& vars = InputVars(name);
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(
...
@@ -872,22 +841,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -872,22 +841,18 @@ class RuntimeInferShapeContext : public InferShapeContext {
name, vars.size()));
name, vars.size()));
return this->GetDim(vars[0]);
return this->GetDim(vars[0]);
}
}
std::vector<DDim> GetInputsDim(const std::string& name) const override {
std::vector<DDim> GetInputsDim(const std::string& name) const override {
const std::vector<Variable*>& vars = InputVars(name);
const std::vector<Variable*>& vars = InputVars(name);
return GetDims(vars);
return GetDims(vars);
}
}
std::vector<proto::VarType::Type> GetInputsVarType(
std::vector<proto::VarType::Type> GetInputsVarType(
const std::string& name) const override {
const std::string& name) const override {
return GetVarTypes(InputVars(name));
return GetVarTypes(InputVars(name));
}
}
std::vector<proto::VarType::Type> GetOutputsVarType(
std::vector<proto::VarType::Type> GetOutputsVarType(
const std::string& name) const override {
const std::string& name) const override {
return GetVarTypes(OutputVars(name));
return GetVarTypes(OutputVars(name));
}
}
void SetOutputDim(const std::string& name, const DDim& dim) override {
void SetOutputDim(const std::string& name, const DDim& dim) override {
auto& vars = OutputVars(name);
auto& vars = OutputVars(name);
PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(
...
@@ -897,13 +862,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -897,13 +862,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
name, vars.size()));
name, vars.size()));
SetDim(vars[0], dim);
SetDim(vars[0], dim);
}
}
void SetOutputsDim(const std::string& name,
void SetOutputsDim(const std::string& name,
const std::vector<DDim>& dims) override {
const std::vector<DDim>& dims) override {
auto& vars = OutputVars(name);
auto& vars = OutputVars(name);
SetDims(vars, dims);
SetDims(vars, dims);
}
}
protected:
protected:
DDim GetDim(Variable* var) const {
DDim GetDim(Variable* var) const {
PADDLE_ENFORCE_NOT_NULL(
PADDLE_ENFORCE_NOT_NULL(
...
@@ -919,7 +882,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -919,7 +882,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
ToTypeName(var->Type())));
ToTypeName(var->Type())));
}
}
}
}
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
std::vector<DDim> ret;
std::vector<DDim> ret;
ret.reserve(vars.size());
ret.reserve(vars.size());
...
@@ -927,12 +889,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -927,12 +889,10 @@ class RuntimeInferShapeContext : public InferShapeContext {
[this](Variable* var) { return this->GetDim(var); });
[this](Variable* var) { return this->GetDim(var); });
return ret;
return ret;
}
}
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
PADDLE_THROW(platform::errors::PreconditionNotMet(
"GetRepeatedDims method only ban be used in compile time."));
"GetRepeatedDims method only ban be used in compile time."));
}
}
void SetDim(Variable* var, const DDim& dim) {
void SetDim(Variable* var, const DDim& dim) {
if (var->IsType<LoDTensor>()) {
if (var->IsType<LoDTensor>()) {
var->GetMutable<LoDTensor>()->Resize(dim);
var->GetMutable<LoDTensor>()->Resize(dim);
...
@@ -945,7 +905,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -945,7 +905,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
ToTypeName(var->Type())));
ToTypeName(var->Type())));
}
}
}
}
void SetDims(const std::vector<Variable*>& vars,
void SetDims(const std::vector<Variable*>& vars,
const std::vector<DDim>& dims) {
const std::vector<DDim>& dims) {
size_t length = vars.size();
size_t length = vars.size();
...
@@ -962,13 +921,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -962,13 +921,11 @@ class RuntimeInferShapeContext : public InferShapeContext {
SetDim(vars[i], dims[i]);
SetDim(vars[i], dims[i]);
}
}
}
}
void SetRepeatedDims(const std::string& name,
void SetRepeatedDims(const std::string& name,
const std::vector<DDim>& dims) override {
const std::vector<DDim>& dims) override {
PADDLE_THROW(platform::errors::PreconditionNotMet(
PADDLE_THROW(platform::errors::PreconditionNotMet(
"SetRepeatedDims method only can be used in compile time."));
"SetRepeatedDims method only can be used in compile time."));
}
}
std::vector<proto::VarType::Type> GetVarTypes(
std::vector<proto::VarType::Type> GetVarTypes(
const std::vector<Variable*>& vars) const {
const std::vector<Variable*>& vars) const {
std::vector<proto::VarType::Type> retv;
std::vector<proto::VarType::Type> retv;
...
@@ -978,11 +935,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -978,11 +935,9 @@ class RuntimeInferShapeContext : public InferShapeContext {
this, std::placeholders::_1));
this, std::placeholders::_1));
return retv;
return retv;
}
}
proto::VarType::Type GetVarType(Variable* var) const {
proto::VarType::Type GetVarType(Variable* var) const {
return ToVarType(var->Type());
return ToVarType(var->Type());
}
}
private:
private:
const std::vector<Variable*>& InputVars(const std::string& name) const {
const std::vector<Variable*>& InputVars(const std::string& name) const {
auto it = ctx_.inputs.find(name);
auto it = ctx_.inputs.find(name);
...
@@ -992,7 +947,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -992,7 +947,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Operator (%s) does not have the input (%s).", op_.Type(), name));
"Operator (%s) does not have the input (%s).", op_.Type(), name));
return it->second;
return it->second;
}
}
const std::vector<Variable*>& OutputVars(const std::string& name) const {
const std::vector<Variable*>& OutputVars(const std::string& name) const {
auto it = ctx_.outputs.find(name);
auto it = ctx_.outputs.find(name);
PADDLE_ENFORCE_NE(
PADDLE_ENFORCE_NE(
...
@@ -1001,7 +955,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
...
@@ -1001,7 +955,6 @@ class RuntimeInferShapeContext : public InferShapeContext {
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
"Operator (%s) does not have the outputs (%s).", op_.Type(), name));
return it->second;
return it->second;
}
}
const OperatorBase& op_;
const OperatorBase& op_;
const RuntimeContext& ctx_;
const RuntimeContext& ctx_;
};
};
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录