提交 4cff8798 编写于 作者: X xiaolil1

Revert "fix bugs for int8 run"

This reverts commit 053b0428.
上级 053b0428
......@@ -62,18 +62,6 @@ inline MKLDNNDataType ToMKLDNNDataType(const std::type_index type) {
return MKLDNNDataType::data_undef;
}
inline std::type_index MKLDNNToTypeIndex(const MKLDNNDataType data_type) {
static const std::map<MKLDNNDataType, std::type_index> dict{
{MKLDNNDataType::f32, std::type_index(typeid(float))}, // NOLINT
{MKLDNNDataType::s8, std::type_index(typeid(char))}, // NOLINT
{MKLDNNDataType::u8, std::type_index(typeid(unsigned char))},
{MKLDNNDataType::s16, std::type_index(typeid(int16_t))},
{MKLDNNDataType::s32, std::type_index(typeid(int32_t))}};
auto iter = dict.find(data_type);
if (iter != dict.end()) return iter->second;
return std::type_index(typeid(float));
}
#endif
void TransDataLayoutFromMKLDNN(const OpKernelType& kernel_type_for_var,
......
......@@ -39,15 +39,9 @@ class Tensor {
public:
inline mkldnn::memory::format format() const { return format_; }
inline mkldnn::memory::data_type data_type() const { return data_type_; }
// Records the MKLDNN memory format of this tensor's data so that downstream
// MKLDNN ops can consume it without an extra layout probe/reorder.
inline void set_format(const mkldnn::memory::format format) {
format_ = format;
}
// Records the MKLDNN data type of this tensor's buffer (e.g. u8/s8 for INT8
// kernels) so later ops can build matching memory descriptors.
inline void set_data_type(const mkldnn::memory::data_type data_type) {
data_type_ = data_type;
}
protected:
/**
......@@ -60,8 +54,6 @@ class Tensor {
*/
mkldnn::memory::format format_ = mkldnn::memory::format::format_undef;
mkldnn::memory::data_type data_type_ = mkldnn::memory::data_type::f32;
#endif
public:
......
......@@ -14,7 +14,6 @@
#include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle {
namespace operators {
......@@ -322,7 +321,6 @@ std::cout<<"this is conv kernel op....................."<<std::endl;
bool fuse_relu = ctx.Attr<bool>("fuse_relu");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
std::cout<<"fuse_relu = "<<fuse_relu<<" fuse_residual_conn = "<<fuse_residual_conn<<std::endl;
// TODO(tpatejko): add support for dilation
PADDLE_ENFORCE(
......@@ -431,20 +429,15 @@ std::cout<<"log3....."<<std::endl;
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<float>(), chosen_memory_format);
//memory::data_type dst_dt = memory::data_type::f32;
//auto dst_dt = std::type_index(typeid(float));
if(is_INT8){
src_md = platform::MKLDNNMemDesc(
src_tz, memory::data_type::u8, chosen_memory_format);
weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8,
(g == 1) ? chosen_memory_format : mkldnn::memory::format::goihw);
//dst_dt = fuse_relu?memory::data_type::u8:memory::data_type::s8;
//dst_dt = fuse_relu? std::type_index(typeid(unsigned char)) : std::type_index(typeid(char));
dst_md = platform::MKLDNNMemDesc(
dst_tz,
fuse_relu? paddle::framework::ToMKLDNNDataType(std::type_index(typeid(unsigned char))) :
paddle::framework::ToMKLDNNDataType(std::type_index(typeid(char))),
fuse_relu?memory::data_type::u8:memory::data_type::s8,
chosen_memory_format);
}
......@@ -509,7 +502,7 @@ std::cout<<"log3....."<<std::endl;
std::shared_ptr<mkldnn::memory> dst_memory_p;
if(is_INT8){
//T* output_data = nullptr;
int8_t* output_data = nullptr;
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
PADDLE_ENFORCE_EQ(output->dims(), residual_param->dims(),
......@@ -517,28 +510,15 @@ std::cout<<"log3....."<<std::endl;
"same dimension sizes");
output->ShareDataWith(*residual_param);
if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
} else{
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
}
output_data = output->mutable_data<int8_t>(ctx.GetPlace());
} else {
std::cout<<"conv log 1 ....................."<<std::endl;
if(fuse_relu){
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<uint8_t>(output_data));
} else{
int8_t* output_data = output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
}
output_data =
output->mutable_data<int8_t>(ctx.GetPlace(), handler.GetDstMemorySize());
std::cout<<"conv log 2 //////////////////////"<<std::endl;
}
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<int8_t>(output_data));
std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<std::endl;
} else{
T* output_data = nullptr;
......@@ -602,10 +582,7 @@ std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
//output->set_data_type(paddle::framework::MKLDNNToTypeIndex(dst_dt));
//output->set_data_type(dst_dt);
std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<"output dt = "<<paddle::framework::ToMKLDNNDataType(output->type())<<"dst dt = "<<dst_memory_p->get_primitive_desc().desc().data.data_type<<std::endl;
std::cout<<"this is conv end!!!!!!!!!!!!!!!!!!!!"<<std::endl;
std::cout<<"input fmt = "<<input->format()<<" output fmt = "<<output->format()<<" dst fmt = "<<dst_memory_p->get_primitive_desc().desc().data.format<<std::endl;
}
private:
......@@ -939,7 +916,7 @@ namespace ops = paddle::operators;
REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNOpKernel<float>,
ops::ConvMKLDNNOpKernel<uint8_t>);
ops::ConvMKLDNNOpKernel<int8_t>);
REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConvMKLDNNGradOpKernel<float>);
......@@ -46,9 +46,9 @@ std::cout<<"this is dequant op ***********"<<std::endl;
const auto& engine = dev_ctx.GetEngine();
const T* input_data = input->data<T>();
float* output_data = output->mutable_data<float>(ctx.GetPlace());
T* output_data = output->mutable_data<T>(ctx.GetPlace());
//T scale_data = *(scale->data<T>());
std::vector<float> scale_data = {*(scale->data<float>())};
std::vector<T> scale_data = {*(scale->data<T>())};
std::vector<primitive> pipeline;
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
......@@ -69,7 +69,7 @@ std::cout<<"this is dequant op ***********"<<std::endl;
auto dst_md = platform::MKLDNNMemDesc(
{dst_tz}, memory::data_type::f32, memory::format::nchw);
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<float>(output_data));
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<T>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(dst_pd, src_pd, attri));
......@@ -112,5 +112,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker, paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<uint8_t>);
REGISTER_OP_KERNEL(dequantize, MKLDNN, ::paddle::platform::CPUPlace, ops::DeQuantOpKernel<float>);
......@@ -14,7 +14,6 @@ limitations under the License. */
#include "paddle/fluid/operators/pool_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/framework/data_layout_transform.h"
namespace paddle {
namespace operators {
......@@ -72,7 +71,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
"It must use CPUPlace.");
std::cout<<"this is pool op"<<std::endl;
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
......@@ -130,16 +129,14 @@ std::cout<<"this is pool op"<<std::endl;
CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
padding_right_bottom);
}
mkldnn::memory::data_type dt = paddle::framework::ToMKLDNNDataType(input->type());
std::cout<<"input type = "<<dt<<std::endl;
auto src_md = platform::MKLDNNMemDesc(
src_tz, dt, input_format);
src_tz, platform::MKLDNNGetDataType<T>(), input_format);
/* create memory descriptor for pooling without specified format
* ('any') which lets a primitive (pooling in this case) choose
* the memory format preferred for best performance
*/
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt,
auto dst_md = platform::MKLDNNMemDesc(dst_tz, mkldnn::memory::f32,
mkldnn::memory::format::any);
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
......@@ -402,9 +399,6 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(pool2d, MKLDNN, ::paddle::platform::CPUPlace,
ops::PoolMKLDNNOpKernel<float>,
ops::PoolMKLDNNOpKernel<int8_t>,
ops::PoolMKLDNNOpKernel<uint8_t>);
ops::PoolMKLDNNOpKernel<float>);
REGISTER_OP_KERNEL(pool2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::PoolMKLDNNGradOpKernel<float>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册