Unverified commit 281644cd authored by Chen Weihang, committed by GitHub

Fix mkldnn invalid infershape impl (#38837)

* fix mkldnn invalid infershape

* add unittest for mkldnn in new executor

* add import os
Parent 5e515781
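The gist of the change: MKL-DNN-aware InferShape functions used to call `OperatorWithKernel::IsMKLDNNType()` through `this`, which breaks on execution paths (dygraph, eager, the new standalone executor) where shape inference runs through `OpInfo::infer_shape_` and the surrounding object is not an `OperatorWithKernel`. The commit moves the check into the `InferShapeContext` interface as `IsRunMKLDNNKernel()`, which every context implements. A minimal sketch of the resulting call pattern, assuming a hypothetical `FooOp` that is not part of this diff:

```cpp
// Hedged sketch, not part of this commit: FooOp, its "X"/"Out" names, and
// the output rule are illustrative assumptions. The point is that the
// MKL-DNN layout check now goes through the context, so the same body
// serves static graph, dygraph, eager, and the new executor.
void FooOp::InferShape(framework::InferShapeContext* ctx) const {
  const auto x_dims = ctx->GetInputDim("X");
  const DataLayout data_layout = framework::StringToDataLayout(
      ctx->Attrs().Get<std::string>("data_layout"));
  // MKL-DNN kernels always describe dims in NCHW order, so only a genuine
  // NHWC run reads the channel count from the last dimension.
  const int64_t C =
      (ctx->IsRunMKLDNNKernel() || data_layout == DataLayout::kNCHW)
          ? x_dims[1]
          : x_dims[x_dims.size() - 1];
  ctx->SetOutputDim("Out", framework::make_ddim({x_dims[0], C}));
}
```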
......@@ -31,15 +31,18 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
   using DDim = paddle::framework::DDim;

  public:
-  EagerInferShapeContext(const NameTensorMap* in, const NameTensorMap* out,
-                         const paddle::framework::AttributeMap* attr,
-                         const paddle::framework::AttributeMap* default_attr,
-                         const std::string op_type)
+  EagerInferShapeContext(
+      const NameTensorMap* in, const NameTensorMap* out,
+      const paddle::framework::AttributeMap* attr,
+      const paddle::framework::AttributeMap* default_attr,
+      const std::string op_type,
+      const paddle::framework::OpKernelType* op_kernel_type = nullptr)
       : tensor_in_(in),
         tensor_out_(out),
         attrs_(attr),
         default_attrs_(default_attr),
-        op_type_(op_type) {}
+        op_type_(op_type),
+        op_kernel_type_(op_kernel_type) {}

   bool HasInput(const std::string& name) const override {
     // has only one input
......@@ -214,6 +217,11 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
   bool IsRuntime() const override { return true; }

+  bool IsRunMKLDNNKernel() const override {
+    return (op_kernel_type_ && (op_kernel_type_->data_layout_ ==
+                                paddle::framework::DataLayout::kMKLDNN));
+  }
+
   // TODO(paddle-dev): Can this be template?
   std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
......@@ -400,6 +408,7 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
   const paddle::framework::AttributeMap* attrs_;
   const paddle::framework::AttributeMap* default_attrs_;
   const std::string op_type_;
+  const paddle::framework::OpKernelType* op_kernel_type_;
 };

 }  // namespace legacy
......
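Because `op_kernel_type` defaults to `nullptr`, call sites that have not yet selected a kernel can keep the old five-argument form, and `IsRunMKLDNNKernel()` then simply reports `false`. A hedged usage sketch (the surrounding variables are assumed, matching the call site in the next hunk):

```cpp
// Without a chosen kernel: IsRunMKLDNNKernel() will return false.
EagerInferShapeContext ctx_a(&ins, &outs, &attrs, &default_attrs, op.Type());

// With the selected kernel type, as PreparedOpRunImpl now does below.
EagerInferShapeContext ctx_b(&ins, &outs, &attrs, &default_attrs, op.Type(),
                             &kernel_type);
```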
......@@ -173,7 +173,7 @@ static void PreparedOpRunImpl(
   paddle::framework::Scope scope;
   EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
-                                         op.Type());
+                                         op.Type(), &kernel_type);
   op.Info().infer_shape_(&infer_shape_ctx);

   func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
......
......@@ -307,6 +307,17 @@ void InterpretercoreInferShapeContext::SetLoDLevel(const std::string& out,

 bool InterpretercoreInferShapeContext::IsRuntime() const { return true; }

+bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
+  try {
+    auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
+    return ((op_with_kernel.kernel_type()) &&
+            (op_with_kernel.kernel_type()->data_layout_ ==
+             framework::DataLayout::kMKLDNN));
+  } catch (const std::bad_cast&) {
+    return false;
+  }
+}
+
 // TODO(paddle-dev): Can this be template?
 std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
     const std::string& name) const {
......
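The `try`/`catch` above leans on a C++ rule worth spelling out: `dynamic_cast` applied to a reference has no null value to return, so a failed downcast throws `std::bad_cast`. The context only stores an `OperatorBase&`, and operators without kernels are not `OperatorWithKernel`, so the catch clause turns that case into a plain `false`. A self-contained sketch of the idiom, with illustrative type names:

```cpp
#include <typeinfo>  // std::bad_cast

struct Base { virtual ~Base() = default; };
struct WithKernel : Base { bool mkldnn() const { return true; } };
struct WithoutKernel : Base {};

// Mirrors IsRunMKLDNNKernel(): a reference dynamic_cast throws on failure,
// so "wrong dynamic type" degrades gracefully to false.
bool ProbeMkldnn(const Base& op) {
  try {
    return dynamic_cast<const WithKernel&>(op).mkldnn();
  } catch (const std::bad_cast&) {
    return false;  // op is not an operator with kernels
  }
}
```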
......@@ -84,6 +84,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {
   bool IsRuntime() const override;

+  bool IsRunMKLDNNKernel() const override;
+
   // TODO(paddle-dev): Can this be template?
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override;
......
......@@ -240,6 +240,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   bool IsRuntime() const override;

+  bool IsRunMKLDNNKernel() const override;
+
   std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const override {
     return GetVarTypes(Inputs(name));
......@@ -930,6 +932,8 @@ void CompileTimeInferShapeContext::SetRepeatedDims(

 bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

+bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const { return false; }
+
 proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
     const std::string &name) const {
   return block_.FindVarRecursive(name)->GetType();
......
......@@ -884,6 +884,17 @@ class RuntimeInferShapeContext : public InferShapeContext {
   bool IsRuntime() const override { return true; }

+  bool IsRunMKLDNNKernel() const override {
+    try {
+      auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
+      return ((op_with_kernel.kernel_type()) &&
+              (op_with_kernel.kernel_type()->data_layout_ ==
+               framework::DataLayout::kMKLDNN));
+    } catch (const std::bad_cast&) {
+      return false;
+    }
+  }
+
   // TODO(paddle-dev): Can this be template?
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
......@@ -1178,9 +1189,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("infer_shape",
                                        platform::EventRole::kInnerOp);
     RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
-    // TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
-    // in some mkldnn infershape functions, such conv2d infershape
-    this->InferShape(&infer_shape_ctx);
+    this->Info().infer_shape_(&infer_shape_ctx);
   }

   if (FLAGS_enable_unused_var_check) {
......
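Dispatching through `Info().infer_shape_` instead of the virtual `this->InferShape(...)` is what makes the old TODO removable: `OpInfo::infer_shape_` is the shape-inference functor captured at operator registration time (roughly a `std::function<void(InferShapeContext*)>`), so every executor now reaches the same registered function rather than an operator-specific override. A hedged sketch of the dispatch; the guard is an assumption, not shown in this hunk:

```cpp
// Roughly what the call site amounts to: infer_shape_ is the functor
// registered in OpInfo, here guarded against ops that never registered one.
auto& info = this->Info();
if (info.infer_shape_) {
  info.infer_shape_(&infer_shape_ctx);
}
```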
......@@ -528,11 +528,6 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }

-  bool IsMKLDNNType() const {
-    return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
-                                     framework::DataLayout::kMKLDNN));
-  }
-
   bool SupportGPU() const override {
     auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
     return std::any_of(op_kernels.begin(), op_kernels.end(),
......@@ -609,6 +604,8 @@ class OperatorWithKernel : public OperatorBase {
     return pt_kernel_context_.get();
   }

+  const OpKernelType* kernel_type() const { return kernel_type_.get(); }
+
  private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
   void RunImpl(const Scope& scope, const platform::Place& place,
......
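The new `kernel_type()` accessor replaces the removed `IsMKLDNNType()` as the public surface: instead of the operator answering the MKL-DNN question itself, it exposes the chosen `OpKernelType` (null until kernel selection has run) and lets each InferShapeContext decide. A hedged caller-side sketch of the null check, matching the pattern used in the contexts above:

```cpp
// `op` is assumed to be an OperatorWithKernel whose kernel may or may not
// have been selected yet; kernel_type() returns nullptr in the latter case.
const framework::OpKernelType* kt = op.kernel_type();
const bool runs_mkldnn =
    kt != nullptr && kt->data_layout_ == framework::DataLayout::kMKLDNN;
```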
......@@ -102,6 +102,8 @@ class InferShapeContext {
   virtual bool IsRuntime() const = 0;

+  virtual bool IsRunMKLDNNKernel() const = 0;
+
   virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string &name) const = 0;
   virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
......
......@@ -32,16 +32,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   using DDim = framework::DDim;

  public:
-  DygraphInferShapeContext(const NameVarMap<VarType>* in,
-                           const NameVarMap<VarType>* out,
-                           const framework::AttributeMap* attr,
-                           const framework::AttributeMap* default_attr,
-                           const std::string op_type)
+  DygraphInferShapeContext(
+      const NameVarMap<VarType>* in, const NameVarMap<VarType>* out,
+      const framework::AttributeMap* attr,
+      const framework::AttributeMap* default_attr, const std::string op_type,
+      const framework::OpKernelType* op_kernel_type = nullptr)
       : var_base_map_in_(in),
         var_base_map_out_(out),
         attrs_(attr),
         default_attrs_(default_attr),
-        op_type_(op_type) {}
+        op_type_(op_type),
+        op_kernel_type_(op_kernel_type) {}

   bool HasInput(const std::string& name) const override {
     // has only one input
......@@ -214,6 +215,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   bool IsRuntime() const override { return true; }

+  bool IsRunMKLDNNKernel() const override {
+    return (op_kernel_type_ &&
+            (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
+  }
+
   // TODO(paddle-dev): Can this be template?
   std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
......@@ -399,6 +405,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   const framework::AttributeMap* attrs_;
   const framework::AttributeMap* default_attrs_;
   const std::string op_type_;
+  const framework::OpKernelType* op_kernel_type_;
 };

 }  // namespace imperative
......
......@@ -514,8 +514,8 @@ static void PreparedOpRunImpl(
   // TODO(zjl): remove scope in dygraph
   framework::Scope scope;

-  DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
-                                                    &default_attrs, op.Type());
+  DygraphInferShapeContext<VarType> infer_shape_ctx(
+      &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
   op.Info().infer_shape_(&infer_shape_ctx);

   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
......@@ -560,8 +560,8 @@ static void PreparedOpRunPtImpl(
     platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
     const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
     const framework::AttributeMap& default_attrs) {
-  DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
-                                                    &default_attrs, op.Type());
+  DygraphInferShapeContext<VarType> infer_shape_ctx(
+      &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
   op.Info().infer_shape_(&infer_shape_ctx);

   BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
......
......@@ -93,7 +93,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
                         x_dims, x_dims.size()));

   const int64_t C =
-      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+      ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
            ? x_dims[1]
            : x_dims[x_dims.size() - 1]);
......@@ -508,7 +508,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
       ctx->Attrs().Get<std::string>("data_layout"));

   const int C =
-      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+      ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
           ? x_dims[1]
           : x_dims[x_dims.size() - 1]);
......@@ -911,7 +911,7 @@ void BatchNormDoubleGradOp::InferShape(
   const DataLayout data_layout = framework::StringToDataLayout(
       ctx->Attrs().Get<std::string>("data_layout"));
   const int C =
-      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+      ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
           ? x_dims[1]
           : x_dims[x_dims.size() - 1]);
......
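All three BatchNorm variants apply the same channel-index rule, which is easy to sanity-check with concrete dims: an MKL-DNN kernel always sees NCHW, so `C` is `x_dims[1]`, and only a genuine NHWC run takes the last dimension. A worked example with illustrative values:

```cpp
// NCHW (or any MKL-DNN run): batch=8, C=3, H=32, W=32 -> channels at index 1.
framework::DDim nchw = framework::make_ddim({8, 3, 32, 32});
int64_t c_nchw = nchw[1];                  // 3

// Genuine NHWC run without MKL-DNN: channels move to the last index.
framework::DDim nhwc = framework::make_ddim({8, 32, 32, 3});
int64_t c_nhwc = nhwc[nhwc.size() - 1];    // 3
```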
......@@ -57,7 +57,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape(
   // MKL-DNN Kernels are using NCHW order of dims description
   // so we ignore data_format consideration for MKL-DNN kernel
-  const bool channel_last = (this->IsMKLDNNType() == false) &&
+  const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                             (data_format == "NHWC" || data_format == "NDHWC");

   PADDLE_ENFORCE_EQ(
......
......@@ -49,7 +49,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   const std::string data_layout_str =
       ctx->Attrs().Get<std::string>("data_format");
   const DataLayout data_layout =
-      this->IsMKLDNNType() ? DataLayout::kNCHW
+      ctx->IsRunMKLDNNKernel() ? DataLayout::kNCHW
                            : framework::StringToDataLayout(data_layout_str);
   PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
......
......@@ -100,8 +100,8 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
     const DataLayout data_layout = framework::StringToDataLayout(
         ctx->Attrs().Get<std::string>("data_layout"));
-    const int C =
-        ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+    const int C = ((ctx->IsRunMKLDNNKernel() == true) ||
+                   (data_layout == DataLayout::kNCHW)
                        ? y_dims[1]
                        : y_dims[y_dims.size() - 1]);
......
......@@ -97,7 +97,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {
   // MKL-DNN Kernels are using NCHW order of dims description
   // so we ignore data_format consideration for MKL-DNN kernel
-  const bool channel_last = (this->IsMKLDNNType() == false) &&
+  const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                             (data_format == "NHWC" || data_format == "NDHWC");

   // update paddings if "SAME" or global_pooling
......
......@@ -14,6 +14,7 @@

 from __future__ import print_function

+import os
 import unittest
 import numpy as np
......@@ -232,6 +233,15 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp):
         self.groups = 3


+# TODO(chenweihang): this unittest is added to solve the coverage problem;
+# remove it after the new executor becomes the default executor
+class TestConv2dMKLDNNByNewExecutor(TestConv2DMKLDNNOp):
+    def test_check_output_by_new_executor(self):
+        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
+        self.test_check_output()
+        del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
+
+
 if __name__ == '__main__':
     from paddle import enable_static
     enable_static()
......