未验证 提交 d44d1730 编写于 作者: W Wojciech Uss 提交者: GitHub

fix cache key in concat oneDNN kernel (#31820) (#31837)

* fix cache key in concat oneDNN kernel

* key simplified
上级 aa731e63
...@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput( ...@@ -71,6 +71,15 @@ static const std::vector<const Tensor*> ReduceMultiInput(
return reduced; return reduced;
} }
static const std::vector<int> GetDimsForKey(
const std::vector<const Tensor*>& inputs) {
auto dims_key = paddle::framework::vectorize<int>(inputs[0]->dims());
for (auto it = std::next(inputs.begin()); it != inputs.end(); ++it) {
dims_key.push_back((*it)->dims()[0]);
}
return dims_key;
}
template <typename T> template <typename T>
class ConcatPrimitiveFactory { class ConcatPrimitiveFactory {
public: public:
...@@ -134,6 +143,8 @@ template <typename T> ...@@ -134,6 +143,8 @@ template <typename T>
class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const paddle::framework::ExecutionContext& ctx) const override {
// If any of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X")); auto multi_input = ReduceMultiInput(ctx.MultiInput<Tensor>("X"));
EnforceLayouts(multi_input); EnforceLayouts(multi_input);
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
...@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -156,12 +167,9 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
paddle::framework::ToMKLDNNDataType(multi_input[0]->type()); paddle::framework::ToMKLDNNDataType(multi_input[0]->type());
ConcatPrimitiveFactory<T> prim_creator; ConcatPrimitiveFactory<T> prim_creator;
// If one of the multiple inputs of concat has an input size of 0, the std::string key =
// actual size of the multi_input will change platform::CreateKey(dev_ctx, GetDimsForKey(multi_input),
std::string key = platform::CreateKey( multi_input.size(), ctx.OutputName("Out"), dt);
dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr());
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key); key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@concat_p"; const std::string key_prim = key + "@concat_p";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册