Unverified · Commit 281644cd authored by Chen Weihang, committed by GitHub

Fix mkldnn invalid infershape impl (#38837)

* fix mkldnn invalid infershape

* add unittest for mkldnn in new executor

* add import os
Parent 5e515781
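A plain-language summary of the fix before the diff: several MKL-DNN-aware `InferShape` implementations (batch_norm, conv2d, pool2d, ...) decided dim order by calling `OperatorWithKernel::IsMKLDNNType()` on the operator itself, which is invalid for execution paths that run shape inference through a bare `InferShapeContext` (the new standalone executor, eager and dygraph modes). The commit moves the query onto the context as a virtual `IsRunMKLDNNKernel()`. Below is a minimal self-contained sketch of that pattern with simplified stand-in types, not the real Paddle classes (the actual `DataLayout`, `OpKernelType`, and context types carry many more members):

```cpp
// Stand-ins for the real Paddle types (simplified by assumption).
enum class DataLayout { kNCHW, kNHWC, kMKLDNN };
struct OpKernelType {
  DataLayout data_layout_;
};

class InferShapeContext {
 public:
  virtual ~InferShapeContext() = default;
  virtual bool IsRuntime() const = 0;
  // The hook this commit adds: "will an MKL-DNN kernel run for this op?"
  virtual bool IsRunMKLDNNKernel() const = 0;
};

// Runtime contexts (eager/dygraph) receive the kernel type chosen during
// op preparation; a null pointer means no kernel has been selected yet.
class RuntimeShapeContext : public InferShapeContext {
 public:
  explicit RuntimeShapeContext(const OpKernelType* kt) : kernel_type_(kt) {}
  bool IsRuntime() const override { return true; }
  bool IsRunMKLDNNKernel() const override {
    return kernel_type_ != nullptr &&
           kernel_type_->data_layout_ == DataLayout::kMKLDNN;
  }

 private:
  const OpKernelType* kernel_type_;
};

// At compile time no kernel is ever selected, so the answer is fixed.
class CompileTimeShapeContext : public InferShapeContext {
 public:
  bool IsRuntime() const override { return false; }
  bool IsRunMKLDNNKernel() const override { return false; }
};
```

With the check on the context, op `InferShape` functions stay backend-aware without downcasting to `OperatorWithKernel`, and `RunImpl` can dispatch through the registered `infer_shape_` functor instead of the virtual `this->InferShape`.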
@@ -31,15 +31,18 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
   using DDim = paddle::framework::DDim;

  public:
-  EagerInferShapeContext(const NameTensorMap* in, const NameTensorMap* out,
-                         const paddle::framework::AttributeMap* attr,
-                         const paddle::framework::AttributeMap* default_attr,
-                         const std::string op_type)
+  EagerInferShapeContext(
+      const NameTensorMap* in, const NameTensorMap* out,
+      const paddle::framework::AttributeMap* attr,
+      const paddle::framework::AttributeMap* default_attr,
+      const std::string op_type,
+      const paddle::framework::OpKernelType* op_kernel_type = nullptr)
       : tensor_in_(in),
         tensor_out_(out),
         attrs_(attr),
         default_attrs_(default_attr),
-        op_type_(op_type) {}
+        op_type_(op_type),
+        op_kernel_type_(op_kernel_type) {}

   bool HasInput(const std::string& name) const override {
     // has only one input
@@ -214,6 +217,11 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {

   bool IsRuntime() const override { return true; }

+  bool IsRunMKLDNNKernel() const override {
+    return (op_kernel_type_ && (op_kernel_type_->data_layout_ ==
+                                paddle::framework::DataLayout::kMKLDNN));
+  }
+
   // TODO(paddle-dev): Can this be template?
   std::vector<paddle::framework::InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
@@ -400,6 +408,7 @@ class EagerInferShapeContext : public paddle::framework::InferShapeContext {
   const paddle::framework::AttributeMap* attrs_;
   const paddle::framework::AttributeMap* default_attrs_;
   const std::string op_type_;
+  const paddle::framework::OpKernelType* op_kernel_type_;
 };

 }  // namespace legacy
......
...@@ -173,7 +173,7 @@ static void PreparedOpRunImpl( ...@@ -173,7 +173,7 @@ static void PreparedOpRunImpl(
paddle::framework::Scope scope; paddle::framework::Scope scope;
EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs, EagerInferShapeContext infer_shape_ctx(&ins, &outs, &attrs, &default_attrs,
op.Type()); op.Type(), &kernel_type);
op.Info().infer_shape_(&infer_shape_ctx); op.Info().infer_shape_(&infer_shape_ctx);
func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs, func(EagerExecutionContext(op, scope, *dev_ctx, ctx, ins, outs, attrs,
......
@@ -307,6 +307,17 @@ void InterpretercoreInferShapeContext::SetLoDLevel(const std::string& out,

 bool InterpretercoreInferShapeContext::IsRuntime() const { return true; }

+bool InterpretercoreInferShapeContext::IsRunMKLDNNKernel() const {
+  try {
+    auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
+    return ((op_with_kernel.kernel_type()) &&
+            (op_with_kernel.kernel_type()->data_layout_ ==
+             framework::DataLayout::kMKLDNN));
+  } catch (std::bad_cast exp) {
+    return false;
+  }
+}
+
 // TODO(paddle-dev): Can this be template?
 std::vector<InferShapeVarPtr> InterpretercoreInferShapeContext::GetInputVarPtrs(
     const std::string& name) const {
......
@@ -84,6 +84,8 @@ class InterpretercoreInferShapeContext : public InferShapeContext {

   bool IsRuntime() const override;

+  bool IsRunMKLDNNKernel() const override;
+
   // TODO(paddle-dev): Can this be template?
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override;
......
@@ -240,6 +240,8 @@ class CompileTimeInferShapeContext : public InferShapeContext {

   bool IsRuntime() const override;

+  bool IsRunMKLDNNKernel() const override;
+
   std::vector<proto::VarType::Type> GetInputsVarType(
       const std::string &name) const override {
     return GetVarTypes(Inputs(name));
@@ -930,6 +932,8 @@ void CompileTimeInferShapeContext::SetRepeatedDims(

 bool CompileTimeInferShapeContext::IsRuntime() const { return false; }

+bool CompileTimeInferShapeContext::IsRunMKLDNNKernel() const { return false; }
+
 proto::VarType::Type CompileTimeInferShapeContext::GetVarType(
     const std::string &name) const {
   return block_.FindVarRecursive(name)->GetType();
......
@@ -884,6 +884,17 @@ class RuntimeInferShapeContext : public InferShapeContext {

   bool IsRuntime() const override { return true; }

+  bool IsRunMKLDNNKernel() const override {
+    try {
+      auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op_);
+      return ((op_with_kernel.kernel_type()) &&
+              (op_with_kernel.kernel_type()->data_layout_ ==
+               framework::DataLayout::kMKLDNN));
+    } catch (std::bad_cast exp) {
+      return false;
+    }
+  }
+
   // TODO(paddle-dev): Can this be template?
   std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
@@ -1178,9 +1189,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
     platform::RecordEvent record_event("infer_shape",
                                        platform::EventRole::kInnerOp);
     RuntimeInferShapeContext infer_shape_ctx(*this, *runtime_ctx);
-    // TODO(chenweihang): replace this after removing `this->IsMKLDNNType()`
-    // in some mkldnn infershape functions, such conv2d infershape
-    this->InferShape(&infer_shape_ctx);
+    this->Info().infer_shape_(&infer_shape_ctx);
   }

   if (FLAGS_enable_unused_var_check) {
......
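The two contexts above only hold an `OperatorBase&`, so they probe for a kernel type with a reference `dynamic_cast`: a failed cast throws `std::bad_cast`, which the code treats as "this op has no kernel, hence no MKL-DNN kernel". A reduced sketch of that probe, reusing the `OpKernelType`/`DataLayout` stand-ins from the first sketch (class shapes simplified by assumption; note the sketch catches `std::bad_cast` by const reference, the more common idiom than the by-value catch in the diff):

```cpp
#include <typeinfo>  // std::bad_cast

class OperatorBase {
 public:
  virtual ~OperatorBase() = default;
};

// Only kernel-bearing operators carry a selected OpKernelType; the
// commit exposes it through the new kernel_type() accessor.
class OperatorWithKernel : public OperatorBase {
 public:
  const OpKernelType* kernel_type() const { return kernel_type_; }
  const OpKernelType* kernel_type_ = nullptr;
};

bool IsRunMKLDNNKernel(const OperatorBase& op) {
  try {
    // The reference form of dynamic_cast throws on failure, unlike the
    // pointer form, which returns nullptr.
    auto& op_with_kernel = dynamic_cast<const OperatorWithKernel&>(op);
    return op_with_kernel.kernel_type() != nullptr &&
           op_with_kernel.kernel_type()->data_layout_ == DataLayout::kMKLDNN;
  } catch (const std::bad_cast&) {
    return false;  // kernel-less operator: never MKL-DNN
  }
}
```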
@@ -528,11 +528,6 @@ class OperatorWithKernel : public OperatorBase {
     return g_all_op_kernels;
   }

-  bool IsMKLDNNType() const {
-    return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
-                                     framework::DataLayout::kMKLDNN));
-  }
-
   bool SupportGPU() const override {
     auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
     return std::any_of(op_kernels.begin(), op_kernels.end(),
@@ -609,6 +604,8 @@ class OperatorWithKernel : public OperatorBase {
     return pt_kernel_context_.get();
   }

+  const OpKernelType* kernel_type() const { return kernel_type_.get(); }
+
  private:
   void RunImpl(const Scope& scope, const platform::Place& place) const final;
   void RunImpl(const Scope& scope, const platform::Place& place,
......
@@ -102,6 +102,8 @@ class InferShapeContext {

   virtual bool IsRuntime() const = 0;

+  virtual bool IsRunMKLDNNKernel() const = 0;
+
   virtual std::vector<InferShapeVarPtr> GetInputVarPtrs(
       const std::string &name) const = 0;
   virtual std::vector<InferShapeVarPtr> GetOutputVarPtrs(
......
@@ -32,16 +32,17 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   using DDim = framework::DDim;

  public:
-  DygraphInferShapeContext(const NameVarMap<VarType>* in,
-                           const NameVarMap<VarType>* out,
-                           const framework::AttributeMap* attr,
-                           const framework::AttributeMap* default_attr,
-                           const std::string op_type)
+  DygraphInferShapeContext(
+      const NameVarMap<VarType>* in, const NameVarMap<VarType>* out,
+      const framework::AttributeMap* attr,
+      const framework::AttributeMap* default_attr, const std::string op_type,
+      const framework::OpKernelType* op_kernel_type = nullptr)
       : var_base_map_in_(in),
         var_base_map_out_(out),
         attrs_(attr),
         default_attrs_(default_attr),
-        op_type_(op_type) {}
+        op_type_(op_type),
+        op_kernel_type_(op_kernel_type) {}

   bool HasInput(const std::string& name) const override {
     // has only one input
@@ -214,6 +215,11 @@ class DygraphInferShapeContext : public framework::InferShapeContext {

   bool IsRuntime() const override { return true; }

+  bool IsRunMKLDNNKernel() const override {
+    return (op_kernel_type_ &&
+            (op_kernel_type_->data_layout_ == framework::DataLayout::kMKLDNN));
+  }
+
   // TODO(paddle-dev): Can this be template?
   std::vector<framework::InferShapeVarPtr> GetInputVarPtrs(
       const std::string& name) const override {
@@ -399,6 +405,7 @@ class DygraphInferShapeContext : public framework::InferShapeContext {
   const framework::AttributeMap* attrs_;
   const framework::AttributeMap* default_attrs_;
   const std::string op_type_;
+  const framework::OpKernelType* op_kernel_type_;
 };

 }  // namespace imperative
......
@@ -514,8 +514,8 @@ static void PreparedOpRunImpl(
   // TODO(zjl): remove scope in dygraph
   framework::Scope scope;

-  DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
-                                                    &default_attrs, op.Type());
+  DygraphInferShapeContext<VarType> infer_shape_ctx(
+      &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
   op.Info().infer_shape_(&infer_shape_ctx);

   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
@@ -560,8 +560,8 @@ static void PreparedOpRunPtImpl(
     platform::DeviceContext* dev_ctx, const NameVarMap<VarType>& ins,
     const NameVarMap<VarType>& outs, const framework::AttributeMap& attrs,
     const framework::AttributeMap& default_attrs) {
-  DygraphInferShapeContext<VarType> infer_shape_ctx(&ins, &outs, &attrs,
-                                                    &default_attrs, op.Type());
+  DygraphInferShapeContext<VarType> infer_shape_ctx(
+      &ins, &outs, &attrs, &default_attrs, op.Type(), &kernel_type);
   op.Info().infer_shape_(&infer_shape_ctx);

   BuildDygraphPtenKernelContext<VarType>(pt_kernel_signature, pt_kernel, ins,
......
@@ -93,7 +93,7 @@ void BatchNormOp::InferShape(framework::InferShapeContext *ctx) const {
                             x_dims, x_dims.size()));

   const int64_t C =
-      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+      ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
            ? x_dims[1]
            : x_dims[x_dims.size() - 1]);
@@ -508,7 +508,7 @@ void BatchNormGradOp::InferShape(framework::InferShapeContext *ctx) const {
       ctx->Attrs().Get<std::string>("data_layout"));

   const int C =
-      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+      ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
           ? x_dims[1]
           : x_dims[x_dims.size() - 1]);
@@ -911,7 +911,7 @@ void BatchNormDoubleGradOp::InferShape(
   const DataLayout data_layout = framework::StringToDataLayout(
       ctx->Attrs().Get<std::string>("data_layout"));
   const int C =
-      ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+      ((ctx->IsRunMKLDNNKernel() == true) || (data_layout == DataLayout::kNCHW)
           ? x_dims[1]
           : x_dims[x_dims.size() - 1]);
......
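The batch-norm hunks above all reduce to the same channel-axis rule: MKL-DNN kernels always use NCHW dim order, so the declared layout attribute is ignored when one will run. A hedged miniature of that rule, again reusing the stand-ins from the first sketch (`ChannelCount` is a hypothetical helper, not a Paddle function):

```cpp
#include <cstdint>
#include <vector>

// Channels sit at index 1 for NCHW (and always under MKL-DNN); otherwise
// they sit at the last index (NHWC/NDHWC).
int64_t ChannelCount(const InferShapeContext& ctx,
                     const std::vector<int64_t>& x_dims,
                     DataLayout data_layout) {
  const bool nchw =
      ctx.IsRunMKLDNNKernel() || data_layout == DataLayout::kNCHW;
  return nchw ? x_dims[1] : x_dims[x_dims.size() - 1];
}
```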
@@ -57,7 +57,7 @@ std::vector<int64_t> ConvOp::ComputeOutputShape(

   // MKL-DNN Kernels are using NCHW order of dims description
   // so we ignore data_format consideration for MKL-DNN kernel
-  const bool channel_last = (this->IsMKLDNNType() == false) &&
+  const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                             (data_format == "NHWC" || data_format == "NDHWC");

   PADDLE_ENFORCE_EQ(
......
@@ -49,7 +49,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
   const std::string data_layout_str =
       ctx->Attrs().Get<std::string>("data_format");
   const DataLayout data_layout =
-      this->IsMKLDNNType() ? DataLayout::kNCHW
-                           : framework::StringToDataLayout(data_layout_str);
+      ctx->IsRunMKLDNNKernel() ? DataLayout::kNCHW
+                               : framework::StringToDataLayout(data_layout_str);

   PADDLE_ENFORCE_EQ(in_dims.size() == 4 || in_dims.size() == 5, true,
......
@@ -100,8 +100,8 @@ class InplaceABNGradOp : public paddle::operators::BatchNormGradOp {
     const DataLayout data_layout = framework::StringToDataLayout(
         ctx->Attrs().Get<std::string>("data_layout"));

-    const int C =
-        ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW)
+    const int C = ((ctx->IsRunMKLDNNKernel() == true) ||
+                   (data_layout == DataLayout::kNCHW)
                        ? y_dims[1]
                        : y_dims[y_dims.size() - 1]);
......
@@ -97,7 +97,7 @@ void PoolOp::InferShape(framework::InferShapeContext* ctx) const {

   // MKL-DNN Kernels are using NCHW order of dims description
   // so we ignore data_format consideration for MKL-DNN kernel
-  const bool channel_last = (this->IsMKLDNNType() == false) &&
+  const bool channel_last = (ctx->IsRunMKLDNNKernel() == false) &&
                             (data_format == "NHWC" || data_format == "NDHWC");

   // update paddings if "SAME" or global_pooling
......
@@ -14,6 +14,7 @@

 from __future__ import print_function

+import os
 import unittest
 import numpy as np

@@ -232,6 +233,15 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp):
         self.groups = 3


+# TODO(chenweihang): To solve the coverage problem, add this unittest,
+# remove this unittest after new executor set to default executor
+class TestConv2dMKLDNNByNewExecutor(TestConv2DMKLDNNOp):
+    def test_check_output_by_new_executor(self):
+        os.environ['FLAGS_USE_STANDALONE_EXECUTOR'] = '1'
+        self.test_check_output()
+        del os.environ['FLAGS_USE_STANDALONE_EXECUTOR']
+
+
 if __name__ == '__main__':
     from paddle import enable_static
     enable_static()
......