Unverified · Commit dde9cec0 authored by Jacek Czaja, committed by GitHub

oneDNN NHWC fixes (#40049)

* - Prototype of third solution

- fix

- compilation fixes

- fix

- fix

- fix

- fix

- compilation fix

- comment fix

- lint

update mkldnn conv_elementwise_add_fuse_pass ut

- NHWC changes to prelu

- alpha dims

- UT fix

- fix to UT

- lint

- Some fixes

- added NHWC support to BWD of prelu

- reverted removal of resetting cu_layout in clearing of caching

* - Small changes

* - compilation fix

* - fix

* - fix

* lint

* - fixes after internal review

* - compilation fix

* - lint
Parent 813f61d2
@@ -174,10 +174,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                     bool force_disable_gc, bool keep_kid_scopes) {
   platform::RecordBlock b(block_id);
   if (FLAGS_use_mkldnn) EnableMKLDNN(pdesc);
+  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
 #ifdef PADDLE_WITH_MKLDNN
   platform::AttachPointerHashToMKLDNNKey(this, place_);
+  platform::RegisterModelLayout(ctx->ops_, place_);
 #endif
-  auto ctx = Prepare(pdesc, block_id, skip_ref_cnt_vars, force_disable_gc);
   RunPreparedContext(ctx.get(), scope, create_local_scope, create_vars,
                      keep_kid_scopes);
 }
......
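Prepare now runs before the PADDLE_WITH_MKLDNN block so that the prepared operator list (ctx->ops_) already exists when RegisterModelLayout is called. A minimal sketch of how the layout cached there can be read back later; this is assumed consumer-side code, not part of the commit, using the TLS getter that appears elsewhere in this diff:

// Sketch only: illustrative kernel-side code, not part of this commit.
const auto model_layout =
    paddle::platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout();
if (model_layout == paddle::framework::DataLayout::kNHWC) {
  // For NHWC models the dims handed out by the framework still follow the
  // NCHW convention, so oneDNN-related code may need to rotate shapes
  // (the prelu changes below do this via GetKernelTypeForVar).
}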
@@ -118,7 +118,7 @@ ResidualConnectionMKLDNNFusePass::ResidualConnectionMKLDNNFusePass() {
       .IsType<std::vector<int>>()
       .End()
       .AddAttr("data_format")
-      .IsStringIn({"NCHW", "AnyLayout"})
+      .IsStringIn({"NHWC", "NCHW", "AnyLayout"})
       .End();
   AddOpCompat(OpCompat("elementwise_add"))
......
@@ -41,6 +41,7 @@ void NaiveExecutor::Prepare(Scope *scope, const ProgramDesc &program_desc,
 void NaiveExecutor::Run() {
 #ifdef PADDLE_WITH_MKLDNN
   platform::AttachPointerHashToMKLDNNKey(this, place_);
+  platform::RegisterModelLayout(ops_, place_);
 #endif
   platform::ScopedFlushDenormal flush;
   for (auto &op : ops_) {
......
@@ -221,7 +221,7 @@ class LRNOp : public framework::OperatorWithKernel {
       auto ar = paddle::framework::AttrReader(attrs);
       const std::string data_format = ar.Get<std::string>("data_format");
       auto dl = framework::StringToDataLayout(data_format);
-      // Some models may have intentionally set "AnyLayout" for pool
+      // Some models may have intentionally set "AnyLayout" for lrn
       // op. Treat this as NCHW (default data_format value)
       if (dl != framework::DataLayout::kAnyLayout) {
         return framework::OpKernelType(expected_kernel_type.data_type_,
......
@@ -50,13 +50,8 @@ class PReluMKLDNNHandler
     if (weights->dims().size() != x->dims().size()) {
       auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
       if (mode == "channel") {
-        if (data_format == "NHWC") {
-          new_weights_dims[x->dims().size() - 1] =
-              *std::max_element(weights_dims.begin(), weights_dims.end());
-        } else {
-          new_weights_dims[1] =
-              *std::max_element(weights_dims.begin(), weights_dims.end());
-        }
+        new_weights_dims[1] =
+            *std::max_element(weights_dims.begin(), weights_dims.end());
       }
       weights_dims = std::move(new_weights_dims);
     }
......
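With the layout decision moved into GetKernelTypeForVar (see the prelu_op.cc hunk below), the dims reaching the oneDNN handler are always in NCHW order, so the "channel" alpha is broadcast at index 1 unconditionally. A worked example of the simplified branch, assuming a 4-D input and a per-channel alpha of length C:

// Worked example (shapes are illustrative):
//   x->dims()    = {N, C, H, W}  -> new_weights_dims starts as {1, 1, 1, 1}
//   weights_dims = {C}           -> its max element is C
auto new_weights_dims = std::vector<int64_t>(x->dims().size(), 1);
new_weights_dims[1] =
    *std::max_element(weights_dims.begin(), weights_dims.end());
// result: {1, C, 1, 1}, i.e. alpha broadcast along the channel axis
weights_dims = std::move(new_weights_dims);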
@@ -98,7 +98,7 @@ TEST(test_pool2d_transpose_nhwc, cpu_place) {
 TEST(test_pool2d_relu_relu_nhwc, cpu_place) {
   framework::DDim dims({1, 4, 8, 512});           // NHWC shape
-  framework::DDim expected_dims({1, 512, 3, 7});  // NHWC expected shape
+  framework::DDim expected_dims({1, 512, 3, 7});  // NCHW expected shape
   platform::CPUPlace p;
   framework::Scope scope;
......
@@ -17,6 +17,26 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+framework::OpKernelType innerGetKernelTypeForVar(
+    const Tensor &tensor, const framework::OpKernelType &expected_kernel_type) {
+#ifdef PADDLE_WITH_MKLDNN
+  auto isOneDNNKernelChosen =
+      (expected_kernel_type.data_layout_ == framework::DataLayout::kMKLDNN);
+  auto isNotOneDNNTensor = (tensor.layout() != framework::DataLayout::kMKLDNN);
+  auto isModelNHWC =
+      (paddle::platform::MKLDNNDeviceContext::tls()
+           .get_cur_paddle_data_layout() == framework::DataLayout::kNHWC);
+  // All inputs (including alpha) need shape rotating
+  if (isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC) {
+    return framework::OpKernelType(expected_kernel_type.data_type_,
+                                   tensor.place(),
+                                   framework::DataLayout::kNHWC);
+  }
+#endif
+  return framework::OpKernelType(expected_kernel_type.data_type_,
+                                 tensor.place(), tensor.layout());
+}
+
 class PReluOp : public framework::OperatorWithKernel {
  public:
   PReluOp(const std::string &type, const framework::VariableNameMap &inputs,
@@ -53,7 +73,7 @@ class PReluOp : public framework::OperatorWithKernel {
                             "For mode 'channel', data_format must be one of "
                             "NCHW and NHWC. But recevied data_format: %s",
                             data_format_str));
-    if (data_format_str == "NCHW") {
+    if (data_format_str == "NCHW" || ctx->IsRunMKLDNNKernel()) {
       PADDLE_ENFORCE_EQ(
           product(ctx->GetInputDim("Alpha")) == x_dim[1], true,
           platform::errors::InvalidArgument(
@@ -128,6 +148,12 @@ class PReluOp : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const {
+    return innerGetKernelTypeForVar(tensor, expected_kernel_type);
+  }
 };
 
 class PReluOpMaker : public framework::OpProtoAndCheckerMaker {
@@ -212,6 +238,12 @@ class PReluGradOp : public framework::OperatorWithKernel {
 #endif
     return framework::OpKernelType(input_data_type, ctx.GetPlace());
   }
+
+  framework::OpKernelType GetKernelTypeForVar(
+      const std::string &var_name, const Tensor &tensor,
+      const framework::OpKernelType &expected_kernel_type) const {
+    return innerGetKernelTypeForVar(tensor, expected_kernel_type);
+  }
 };
 
 template <typename T>
......
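The new GetKernelTypeForVar overrides route all of prelu's inputs (including Alpha, and the gradient tensors in the grad op) through innerGetKernelTypeForVar. A condensed sketch of the decision that function makes, assuming PADDLE_WITH_MKLDNN is defined:

// Condensed restatement of innerGetKernelTypeForVar's logic (sketch only):
// report the variable as kNHWC (so its shape gets rotated) only when
//   (1) an oneDNN kernel was chosen for this op,
//   (2) the tensor is still a plain, non-oneDNN tensor, and
//   (3) the model was registered as NHWC via RegisterModelLayout.
bool rotate = isOneDNNKernelChosen && isNotOneDNNTensor && isModelNHWC;
return framework::OpKernelType(expected_kernel_type.data_type_, tensor.place(),
                               rotate ? framework::DataLayout::kNHWC
                                      : tensor.layout());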
@@ -559,6 +559,34 @@ inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
   }
 }
 
+inline void RegisterModelLayout(
+    std::vector<std::unique_ptr<framework::OperatorBase>>& ops,
+    const platform::Place& place) {
+  if (platform::is_cpu_place(place)) {
+    auto check_attrib = [](std::unique_ptr<framework::OperatorBase>& op,
+                           const std::string& attrib_name) -> bool {
+      if (op->HasAttr(attrib_name)) {
+        auto data_format = op->Attr<std::string>(attrib_name);
+        platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
+            data_format.compare("NHWC") == 0 ? framework::DataLayout::kNHWC
+                                             : framework::DataLayout::kNCHW);
+        return true;
+      } else {
+        return false;
+      }
+    };
+
+    for (auto& op : ops) {
+      if (check_attrib(op, std::string("data_format"))) {
+        return;
+      }
+      if (check_attrib(op, std::string("data_layout"))) {
+        return;
+      }
+    }
+  }
+}
+
 inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
   return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
           op->GetAttrIfExists<bool>("use_quantizer"));
......
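RegisterModelLayout scans the prepared operators once: the first op carrying a data_format or data_layout attribute decides the model-wide layout stored in the oneDNN device-context TLS, and if no op carries either attribute the TLS value is left untouched. An illustrative usage, mirroring the Executor/NaiveExecutor changes above (ops and place stand in for the executor's prepared op list and its place):

// Illustrative only; not part of this commit.
#ifdef PADDLE_WITH_MKLDNN
platform::RegisterModelLayout(ops, place);  // does nothing on non-CPU places
auto layout =
    platform::MKLDNNDeviceContext::tls().get_cur_paddle_data_layout();
// layout == framework::DataLayout::kNHWC only if the first op with a
// "data_format"/"data_layout" attribute declares "NHWC"; any other value
// of that attribute registers kNCHW.
#endif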
@@ -25,17 +25,120 @@ from hypothesis import given, settings, seed, example, assume
 import hypothesis.strategies as st
 
-class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
+# the two inputs of elementwise_add are tensor
+class TestConvElementwiseAddMkldnnFusePass1(PassAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
         attrs = [
             program_config.ops[i].attrs
             for i in range(len(program_config.ops))
         ]
-        # If the problem has been fixed, the judgment
-        # needs to be deleted!!!
-        if attrs[1]['data_format'] == "NHWC":
+        if attrs[1]['data_format'] == "NHWC" and attrs[3]['axis'] == 0:
+            return False
+        if attrs[1]['data_format'] == "NCHW" and attrs[3]['axis'] == -1:
             return False
+
+        return True
+
+    def sample_program_config(self, draw):
+        data_format = draw(st.sampled_from(["NCHW", "NHWC"]))
+        dilations = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        padding_algorithm = draw(st.sampled_from(["EXPLICIT", "SAME", "VALID"]))
+        groups = draw(st.sampled_from([1, 2, 4]))
+        paddings = draw(st.sampled_from([[0, 3], [1, 1], [1, 2, 3, 4]]))
+        strides = draw(st.sampled_from([[1, 1], [2, 2], [1, 2]]))
+        axis = draw(st.sampled_from([-1, 0]))
+        batch_size = draw(st.integers(min_value=1, max_value=4))
+
+        def generate_input():
+            if data_format == "NCHW":
+                return np.random.random(
+                    [batch_size, 48, 64, 64]).astype(np.float32)
+            else:
+                return np.random.random(
+                    [batch_size, 64, 64, 48]).astype(np.float32)
+
+        def generate_weight():
+            return np.random.random(
+                [48, int(48 / groups), 3, 3]).astype(np.float32)
+
+        relu_op = OpConfig(
+            type="relu",
+            inputs={"X": ["input_data"]},
+            outputs={"Out": ["relu_out"]},
+            attrs={})
+
+        conv2d_op1 = OpConfig(
+            type="conv2d",
+            inputs={"Input": ["relu_out"],
+                    "Filter": ["conv_weight1"]},
+            outputs={"Output": ["conv_output1"]},
+            attrs={
+                "data_format": data_format,
+                "dilations": dilations,
+                "padding_algorithm": padding_algorithm,
+                "groups": groups,
+                "paddings": paddings,
+                "strides": strides
+            })
+
+        conv2d_op2 = OpConfig(
+            type="conv2d",
+            inputs={"Input": ["input_data"],
+                    "Filter": ["conv_weight2"]},
+            outputs={"Output": ["conv_output2"]},
+            attrs={
+                "data_format": data_format,
+                "dilations": dilations,
+                "padding_algorithm": padding_algorithm,
+                "groups": groups,
+                "paddings": paddings,
+                "strides": strides
+            })
+
+        elt_op = OpConfig(
+            type="elementwise_add",
+            inputs={"X": ["conv_output1"],
+                    "Y": ["conv_output2"]},
+            outputs={"Out": ["elementwise_output"]},
+            attrs={'axis': axis})
+
+        model_net = [relu_op, conv2d_op1, conv2d_op2, elt_op]
+
+        program_config = ProgramConfig(
+            ops=model_net,
+            weights={
+                "conv_weight1": TensorConfig(data_gen=partial(generate_weight)),
+                "conv_weight2": TensorConfig(data_gen=partial(generate_weight))
+            },
+            inputs={
+                "input_data": TensorConfig(data_gen=partial(generate_input))
+            },
+            outputs=["elementwise_output"])
+
+        return program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_mkldnn=True)
+        yield config, ["relu", "conv2d", "conv2d"], (1e-5, 1e-5)
+
+    def test(self):
+        self.run_and_statis(
+            quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])
+
+
+'''
+class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
+    def is_program_valid(self, program_config: ProgramConfig) -> bool:
+        attrs = [
+            program_config.ops[i].attrs
+            for i in range(len(program_config.ops))
+        ]
+        if "elementwise_weight" in program_config.weights:
+            if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[1]:
+                if attrs[2]['axis'] != 1:
+                    return False
+            if program_config.weights["elementwise_weight"].shape[0] == program_config.inputs["input_data1"].shape[3]:
+                if attrs[2]['axis'] != -1:
+                    return False
 
         return True
 
     def sample_program_config(self, draw):
@@ -101,7 +204,7 @@ class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
                 "strides": strides
             })
 
-        if axis == -1 or axis == 0:
+        if axis == 0:
             elt_op = OpConfig(
                 type="elementwise_add",
                 inputs={"X": ["input_data1"],
@@ -118,14 +221,12 @@ class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
 
         model_net = [relu_op, conv2d_op, elt_op]
 
-        if axis == 1:
+        if axis == 0:
             program_config = ProgramConfig(
                 ops=model_net,
                 weights={
                     "conv_weight":
-                    TensorConfig(data_gen=partial(generate_weight1)),
-                    "elementwise_weight":
-                    TensorConfig(data_gen=partial(generate_weight2))
+                    TensorConfig(data_gen=partial(generate_weight1))
                 },
                 inputs={
                     "input_data1":
@@ -137,7 +238,9 @@ class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
                 ops=model_net,
                 weights={
                     "conv_weight":
-                    TensorConfig(data_gen=partial(generate_weight1))
+                    TensorConfig(data_gen=partial(generate_weight1)),
+                    "elementwise_weight":
+                    TensorConfig(data_gen=partial(generate_weight2))
                 },
                 inputs={
                     "input_data1":
@@ -154,7 +257,7 @@ class TestConvElementwiseAddMkldnnFusePass(PassAutoScanTest):
     def test(self):
         self.run_and_statis(
             quant=False, passes=["conv_elementwise_add_mkldnn_fuse_pass"])
-
+'''
 
 if __name__ == "__main__":
     unittest.main()