diff --git a/paddle/gserver/layers/SwitchOrderLayer.cpp b/paddle/gserver/layers/SwitchOrderLayer.cpp index 5b2cfe3c48d65d85d792d8817afd444b00561104..92cd61cdd515d5c693df086c9575a5f197c00cee 100644 --- a/paddle/gserver/layers/SwitchOrderLayer.cpp +++ b/paddle/gserver/layers/SwitchOrderLayer.cpp @@ -32,11 +32,11 @@ bool SwitchOrderLayer::init(const LayerMap& layerMap, outDims_ = TensorShape(4); auto& reshape_conf = config_.reshape_conf(); - for (int i = 0; i < reshape_conf.heightaxis_size(); i++) { - heightAxis_.push_back(reshape_conf.heightaxis(i)); + for (int i = 0; i < reshape_conf.height_axis_size(); i++) { + heightAxis_.push_back(reshape_conf.height_axis(i)); } - for (int i = 0; i < reshape_conf.widthaxis_size(); i++) { - widthAxis_.push_back(reshape_conf.widthaxis(i)); + for (int i = 0; i < reshape_conf.width_axis_size(); i++) { + widthAxis_.push_back(reshape_conf.width_axis(i)); } createFunction(nchw2nhwc_, "NCHW2NHWC", FuncConfig()); createFunction(nhwc2nchw_, "NHWC2NCHW", FuncConfig()); diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index e0c14ad5b512c7329062a5426ef34844ec268020..d1f3bc241fa621cb0070125980996e8627e40fd6 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -2019,10 +2019,10 @@ TEST(Layer, SwitchOrderLayer) { img->set_img_size_y(16); ReshapeConfig* reshape = config.layerConfig.mutable_reshape_conf(); - reshape->add_heightaxis(0); - reshape->add_heightaxis(1); - reshape->add_heightaxis(2); - reshape->add_widthaxis(3); + reshape->add_height_axis(0); + reshape->add_height_axis(1); + reshape->add_height_axis(2); + reshape->add_width_axis(3); // config softmax layer config.layerConfig.set_type("switch_order"); diff --git a/proto/ModelConfig.proto b/proto/ModelConfig.proto index 0f44d8cb8d78ed23cc1105ac7aff37de5faeffa1..7d7fc23a4691646dfce4c162a445864c748501d9 100644 --- a/proto/ModelConfig.proto +++ b/proto/ModelConfig.proto @@ -288,8 +288,8 @@ message PadConfig { } message ReshapeConfig { - repeated uint32 heightAxis = 1; - repeated uint32 widthAxis = 2; + repeated uint32 height_axis = 1; + repeated uint32 width_axis = 2; } message MultiBoxLossConfig {