未验证 提交 a38c1512 编写于 作者: F fengjiayi 提交者: GitHub

Add GetInputsElementDim (#6091)

上级 1238706d
......@@ -22,6 +22,12 @@ std::vector<framework::DDim> InferShapeContext::GetInputsDim(
return GetDims(names);
}
DDim InferShapeContext::GetInputsElementDim(const std::string &name,
int idx) const {
const std::vector<std::string> &names = Inputs(name);
return this->GetDim(names[idx]);
}
void InferShapeContext::SetOutputsDim(
const std::string &name, const std::vector<framework::DDim> &dims) {
auto &names = Outputs(name);
......
......@@ -37,6 +37,7 @@ class InferShapeContext {
virtual framework::DDim GetInputDim(const std::string &name) const = 0;
std::vector<framework::DDim> GetInputsDim(const std::string &name) const;
DDim GetInputsElementDim(const std::string &name, int idx) const;
virtual void SetOutputDim(const std::string &name, const DDim &dim) = 0;
void SetOutputsDim(const std::string &name,
......
......@@ -287,7 +287,6 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
auto p_names = ctx->Inputs(kParameters);
auto pg_names = ctx->Outputs(kParamGrads);
auto dims = ctx->GetInputsDim(kParameters);
auto var_types = ctx->GetInputsVarType(kParameters);
std::vector<std::string> names_to_set;
std::vector<framework::DDim> dims_to_set;
......@@ -295,13 +294,14 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
if (pg_names[i] == framework::kEmptyVarName) {
continue;
}
auto dims = ctx->GetInputsElementDim(kParameters, i);
if (var_types[i] == framework::VarDesc::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims[i]);
dims_to_set.push_back(dims);
} else if (var_types[i] == framework::VarDesc::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]);
dims_to_set.push_back(dims[i]);
dims_to_set.push_back(dims);
}
}
ctx->SetDims(names_to_set, dims_to_set);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册