Commit 053b0428 authored by xiaolil1

fix bugs for int8 run

Parent 751a826c
@@ -62,6 +62,18 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
   return MKLDNNDataType::data_undef;
 }

+inline std::type_index MKLDNNToTypeIndex(const MKLDNNDataType data_type) {
+  static const std::map<MKLDNNDataType, std::type_index> dict{
+      {MKLDNNDataType::f32, std::type_index(typeid(float))},  // NOLINT
+      {MKLDNNDataType::s8, std::type_index(typeid(char))},    // NOLINT
+      {MKLDNNDataType::u8, std::type_index(typeid(unsigned char))},
+      {MKLDNNDataType::s16, std::type_index(typeid(int16_t))},
+      {MKLDNNDataType::s32, std::type_index(typeid(int32_t))}};
+  auto iter = dict.find(data_type);
+  if (iter != dict.end()) return iter->second;
+  return std::type_index(typeid(float));
+}
+
 #endif

 void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
...
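For context on the hunk above: MKLDNNToTypeIndex is the inverse of the existing ToMKLDNNDataType, mapping an MKLDNN data-type tag back to a C++ std::type_index and falling back to float for unknown tags. A minimal self-contained sketch of that round trip, using a stand-in enum instead of the real MKLDNN headers (so the enum name and values here are assumptions):

#include <cstdint>
#include <iostream>
#include <map>
#include <typeindex>
#include <typeinfo>

// Stand-in for mkldnn::memory::data_type; values mirror the ones handled above.
enum class MKLDNNDataType { f32, s8, u8, s16, s32, data_undef };

std::type_index MKLDNNToTypeIndex(MKLDNNDataType dt) {
  static const std::map<MKLDNNDataType, std::type_index> dict{
      {MKLDNNDataType::f32, std::type_index(typeid(float))},
      {MKLDNNDataType::s8, std::type_index(typeid(char))},
      {MKLDNNDataType::u8, std::type_index(typeid(unsigned char))},
      {MKLDNNDataType::s16, std::type_index(typeid(int16_t))},
      {MKLDNNDataType::s32, std::type_index(typeid(int32_t))}};
  auto it = dict.find(dt);
  // Unknown tags fall back to float, matching the hunk above.
  return it != dict.end() ? it->second : std::type_index(typeid(float));
}

int main() {
  // u8 maps back to unsigned char, the element type of a ReLU-fused int8 output.
  std::cout << (MKLDNNToTypeIndex(MKLDNNDataType::u8) ==
                std::type_index(typeid(unsigned char)))
            << std::endl;  // prints 1
}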
@@ -39,10 +39,16 @@ class Tensor {
  public:
   inline mkldnn::memory::format format() const { return format_; }

+  inline mkldnn::memory::data_type data_type() const { return data_type_; }
+
   inline void set_format(const mkldnn::memory::format format) {
     format_ = format;
   }

+  inline void set_data_type(const mkldnn::memory::data_type data_type) {
+    data_type_ = data_type;
+  }
+
  protected:
   /**
    * @brief the detail format of memory block which have layout as kMKLDNN
@@ -54,6 +60,8 @@ class Tensor {
    */
   mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;

+  mkldnn::memory::data_type data_type_ = mkldnn::memory::data_type::f32;
+
 #endif

  public:
...
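The Tensor change above simply caches an MKLDNN data-type tag next to the existing format tag, defaulting to f32 so fp32 kernels are unaffected. A stripped-down sketch of the pattern (stand-in enums rather than the real mkldnn::memory types, so those names are assumptions):

#include <iostream>

// Stand-ins for mkldnn::memory::format / mkldnn::memory::data_type.
enum class Format { format_undef, nchw, nhwc };
enum class DataType { f32, s8, u8 };

class Tensor {
 public:
  Format format() const { return format_; }
  DataType data_type() const { return data_type_; }
  void set_format(Format f) { format_ = f; }
  void set_data_type(DataType dt) { data_type_ = dt; }

 private:
  Format format_ = Format::format_undef;
  DataType data_type_ = DataType::f32;  // fp32 unless an int8 kernel overrides it
};

int main() {
  Tensor t;
  t.set_format(Format::nchw);
  t.set_data_type(DataType::u8);  // e.g. output of a ReLU-fused int8 conv
  std::cout << (t.data_type() == DataType::u8) << std::endl;  // prints 1
}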
@@ -14,6 +14,7 @@
 #include "paddle/fluid/operators/conv_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/framework/data_layout_transform.h"

 namespace paddle {
 namespace operators {
@@ -321,6 +322,7 @@ std::cout<<"this is conv kernel op....................."<<std::endl;
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
     int groups = ctx.Attr<int>("groups");
+    std::cout<<"fuse_relu = "<<fuse_relu<<" fuse_residual_conn = "<<fuse_residual_conn<<std::endl;

     // TODO(tpatejko): add support for dilation
     PADDLE_ENFORCE(
@@ -429,15 +431,20 @@ std::cout<<"log3....."<<std::endl;
         (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
     auto dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
+    //memory::data_type dst_dt = memory::data_type::f32;
+    //auto dst_dt = std::type_index(typeid(float));
     if(is_INT8){
       src_md = platform::MKLDNNMemDesc(
           src_tz, memory::data_type::u8, chosen_memory_format);
       weights_md = platform::MKLDNNMemDesc(
           weights_tz, memory::data_type::s8,
           (g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
+      //dst_dt = fuse_relu?memory::data_type::u8:memory::data_type::s8;
+      //dst_dt = fuse_relu? std::type_index(typeid(unsigned char)) : std::type_index(typeid(char));
       dst_md = platform::MKLDNNMemDesc(
           dst_tz,
-          fuse_relu?memory::data_type::u8:memory::data_type::s8,
+          fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) :
+          paddle::framework::ToMKLDNNDataType(std::type_index(typeid(char))),
           chosen_memory_format);
     }
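The destination data type above hinges on one observation: with a fused ReLU the conv output is non-negative, so it fits in u8; without it the output can be negative and needs s8. A small plain-C++ sketch of the resulting saturation ranges (the helper name is hypothetical, not from the diff):

#include <algorithm>
#include <cstdint>
#include <iostream>

// Clamp an accumulator value into the narrow output range an int8 conv uses:
// a fused ReLU guarantees non-negative results, so u8 [0, 255];
// otherwise the result may be negative and needs s8 [-128, 127].
int32_t saturate(int32_t v, bool fuse_relu) {
  const int32_t lo = fuse_relu ? 0 : -128;
  const int32_t hi = fuse_relu ? 255 : 127;
  return std::min(std::max(v, lo), hi);
}

int main() {
  std::cout << saturate(-7, /*fuse_relu=*/true) << std::endl;   // 0: ReLU output is u8
  std::cout << saturate(-7, /*fuse_relu=*/false) << std::endl;  // -7: fits in s8
  std::cout << saturate(300, /*fuse_relu=*/true) << std::endl;  // 255: clamped to u8 max
}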
@@ -502,7 +509,7 @@ std::cout<<"log3....."<<std::endl;
     std::shared_ptr<mkldnn::memory> dst_memory_p;

     if(is_INT8){
-      int8_t* output_data = nullptr;
+      //T* output_data = nullptr;
       if (fuse_residual_conn) {
         auto residual_param = ctx.Input<Tensor>("ResidualData");
         PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
@@ -510,15 +517,28 @@ std::cout<<"log3....."<<std::endl;
                           "same dimension sizes");
         output->ShareDataWith(*residual_param);
-        output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+        if(fuse_relu){
+          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
+          dst_memory_p =
+              handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+        } else{
+          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
+          dst_memory_p =
+              handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+        }
       } else {
         std::cout<<"conv log 1 ....................."<<std::endl;
-        output_data =
-            output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
-        std::cout<<"conv log 2 //////////////////////"<<std::endl;
-      }
-      dst_memory_p =
-          handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+        if(fuse_relu){
+          uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
+          dst_memory_p =
+              handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
+        } else{
+          int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
+          dst_memory_p =
+              handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
+        }
+        std::cout<<"conv log 2 //////////////////////"<<std::endl;
+      }
 std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<std::endl;
     } else{
       T* output_data = nullptr;
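The allocation logic above repeats the same u8-versus-s8 split: the typed pointer only exists inside each branch, and both branches hand the primitive a type-erased pointer. A plain-C++ sketch of that pattern, with a hypothetical attach helper standing in for AcquireDstMemoryFromPrimitive and a byte buffer standing in for mutable_data:

#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in: "attaching" just records the type-erased pointer
// the destination memory primitive would wrap.
struct DstMemory { void* handle = nullptr; };

DstMemory attach(void* data) { return DstMemory{data}; }

int main() {
  bool fuse_relu = true;
  std::vector<uint8_t> buf(16);  // 16 one-byte elements either way

  DstMemory dst;
  if (fuse_relu) {
    // ReLU fused: interpret the buffer as unsigned 8-bit.
    uint8_t* output_data = buf.data();
    dst = attach(static_cast<void*>(output_data));
  } else {
    // No ReLU: signed 8-bit view of the same storage.
    int8_t* output_data = reinterpret_cast<int8_t*>(buf.data());
    dst = attach(static_cast<void*>(output_data));
  }
  std::cout << (dst.handle == buf.data()) << std::endl;  // prints 1
}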
@@ -582,7 +602,10 @@ std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<
     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory_p));
+    //output->set_data_type(paddle::framework::MKLDNNToTypeIndex(dst_dt));
+    //output->set_data_type(dst_dt);
-std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<std::endl;
+std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<"output dt = "<<paddle::framework::ToMKLDNNDataType(output->type())<<"dst dt = "<<dst_memory_p->get_primitive_desc().desc().data.data_type<<std::endl;
+std::cout<<"this is conv end!!!!!!!!!!!!!!!!!!!!"<<std::endl;
   }

 private:
@@ -916,7 +939,7 @@ namespace ops = paddle::operators;
 REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::ConvMKLDNNOpKernel<float>,
-                   ops::ConvMKLDNNOpKernel<int8_t>);
+                   ops::ConvMKLDNNOpKernel<uint8_t>);

 REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::ConvMKLDNNGradOpKernel<float>);
@@ -46,9 +46,9 @@ std::cout<<"this is dequant op ***********"<<std::endl;
     const auto& engine = dev_ctx.GetEngine();

     const T* input_data = input->data<T>();
-    T* output_data = output->mutable_data<T>(ctx.GetPlace());
+    float* output_data = output->mutable_data<float>(ctx.GetPlace());
     //T scale_data = *(scale->data<T>());
-    std::vector<T> scale_data = {*(scale->data<T>())};
+    std::vector<float> scale_data = {*(scale->data<float>())};
     std::vector<primitive> pipeline;
     std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
@@ -69,7 +69,7 @@ std::cout<<"this is dequant op ***********"<<std::endl;
     auto dst_md = platform::MKLDNNMemDesc(
         {dst_tz}, memory::data_type::f32, memory::format::nchw);
     auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
-    auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
+    auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<float>(output_data));

     auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
         new reorder::primitive_desc(dst_pd, src_pd, attri));
@@ -112,5 +112,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);

-REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<float>);
+REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<uint8_t>);
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/pool_op.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/fluid/framework/data_layout_transform.h"

 namespace paddle {
 namespace operators {
@@ -71,7 +72,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
     PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
                    "It must use CPUPlace.");
+std::cout<<"this is pool op"<<std::endl;
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
     const auto& mkldnn_engine = dev_ctx.GetEngine();
@@ -129,14 +130,16 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
                         padding_right_bottom);
     }
+    mkldnn::memory::data_type dt = paddle::framework::ToMKLDNNDataType(input->type());
+    std::cout<<"input type = "<<dt<<std::endl;
     auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), input_format);
+        src_tz, dt, input_format);

     /* create memory descriptor for pooling without specified format
      * ('any') which lets a primitive (pooling in this case) choose
      * the memory format preferred for best performance
      */
-    auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
+    auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt,
                                           mkldnn::memory::format::any);

     std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
@@ -399,6 +402,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
 namespace ops = paddle::operators;

 REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::PoolMKLDNNOpKernel<float>);
+                   ops::PoolMKLDNNOpKernel<float>,
+                   ops::PoolMKLDNNOpKernel<int8_t>,
+                   ops::PoolMKLDNNOpKernel<uint8_t>);

 REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
                    ops::PoolMKLDNNGradOpKernel<float>);
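Pooling itself is data-type agnostic: the change above reuses the input's MKLDNN data type for both the src and dst descriptors, so one kernel body serves float, int8_t, and uint8_t (hence the three registrations). A templated 1-D max-pool sketch showing why a single body covers all element types:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// 1-D max pooling with window w and stride w; compiles for any comparable
// element type, which is why float/int8_t/uint8_t can share one kernel.
template <typename T>
std::vector<T> max_pool(const std::vector<T>& in, size_t w) {
  std::vector<T> out;
  for (size_t i = 0; i + w <= in.size(); i += w) {
    T m = in[i];
    for (size_t j = 1; j < w; ++j) m = std::max(m, in[i + j]);
    out.push_back(m);
  }
  return out;
}

int main() {
  std::vector<uint8_t> u = {3, 9, 1, 4};
  std::vector<int8_t> s = {-3, -9, -1, -4};
  std::cout << int(max_pool(u, 2)[0]) << " "        // 9
            << int(max_pool(s, 2)[1]) << std::endl;  // -1
}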