Unverified commit d44d1730 authored by Wojciech Uss, committed by GitHub

fix cache key in concat oneDNN kernel (#31820) (#31837)

* fix cache key in concat oneDNN kernel

* key simplified
Parent aa731e63
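
The old cache key was built only from the first input's dims, the number of inputs, the output name, the data type, and ThreadIDasStr(), so two concat calls whose inputs differed only in the later tensors' shapes (for example after ReduceMultiInput drops zero-sized inputs) could collide on the same cached primitive. The new key adds dim 0 of every remaining input via GetDimsForKey, and the explicit thread-id suffix is replaced by ExtendKeyWithThreadInfoIfNeeded, which appends thread information only when it is needed. Below is a minimal standalone sketch of the key-collision idea, not Paddle code; the helpers MakeKeyDims and MakeKey are hypothetical stand-ins for GetDimsForKey and platform::CreateKey.

// Standalone illustration of why the key must include dims of all inputs.
#include <iostream>
#include <string>
#include <vector>

// Mirrors GetDimsForKey: full dims of the first input, then dim 0 of the rest.
std::vector<int> MakeKeyDims(const std::vector<std::vector<int>>& input_dims) {
  std::vector<int> key_dims = input_dims[0];
  for (size_t i = 1; i < input_dims.size(); ++i) {
    key_dims.push_back(input_dims[i][0]);
  }
  return key_dims;
}

// Hypothetical stand-in for platform::CreateKey: serialize dims + input count.
std::string MakeKey(const std::vector<int>& dims, size_t num_inputs) {
  std::string key;
  for (int d : dims) key += std::to_string(d) + "x";
  return key + "#" + std::to_string(num_inputs);
}

int main() {
  // Two input sets with the same first-input shape and the same input count.
  // A key built only from the first input's dims and the count (the old
  // scheme) cannot tell them apart; the new scheme produces distinct keys.
  std::vector<std::vector<int>> a = {{4, 8}, {2, 8}};
  std::vector<std::vector<int>> b = {{4, 8}, {6, 8}};

  std::cout << MakeKey(MakeKeyDims(a), a.size()) << "\n";  // 4x8x2x#2
  std::cout << MakeKey(MakeKeyDims(b), b.size()) << "\n";  // 4x8x6x#2
  return 0;
}
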
@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
   return reduced;
 }
 
+static const std::vector<int> GetDimsForKey(
+    const std::vector<const Tensor*>& inputs) {
+  auto dims_key = paddle::framework::vectorize<int>(inputs[0]->dims());
+  for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) {
+    dims_key.push_back((*it)->dims()[0]);
+  }
+  return dims_key;
+}
+
 template <typename T>
 class ConcatPrimitiveFactory {
  public:
@@ -134,6 +143,8 @@ template <typename T>
 class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    // If any of the multiple inputs of concat has an input size of 0, the
+    // actual size of the multi_input will change
     auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
     EnforceLayouts(multi_input);
     Tensor* output = ctx.Output<Tensor>("Out");
@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
     ConcatPrimitiveFactory<T> prim_creator;
-    // If one of the multiple inputs of concat has an input size of 0, the
-    // actual size of the multi_input will change
-    std::string key = platform::CreateKey(
-        dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
-        multi_input.size(), ctx.OutputName("Out"), dt,
-        platform::ThreadIDasStr());
+    std::string key =
+        platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
+                            multi_input.size(), ctx.OutputName("Out"), dt);
+    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
     const std::string key_prim = key + "@concat_p";