提交 ff6d3d28 编写于 作者: xiebaiyuan's avatar xiebaiyuan

bilinear_interp_impl

上级 9f6f5166
......@@ -20,8 +20,25 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void BilinearOp<DeviceType, T>::InferShape() const {
// todo check
this->param_.Out()->Resize(this->param_.InputX()->dims());
PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr,
"Input(X) of BilinearInterOp should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output(Out) of BilinearInterOp should not be null.");
auto dim_x = this->param_.InputX()->dims(); // NCHW format
int out_h = this->param_.OutH();
int out_w = this->param_.OutW();
PADDLE_MOBILE_ENFORCE(dim_x.size() == 4, "X's dimension must be 4");
if (this->param_.InputOutPutSize() != nullptr) {
auto out_size_dim = this->param_.InputOutPutSize()->dims();
PADDLE_MOBILE_ENFORCE(out_size_dim.size() == 1,
"OutSize's dimension size must be 1");
PADDLE_MOBILE_ENFORCE(out_size_dim[0] == 2, "OutSize's dim[0] must be 2");
}
std::vector<int64_t> dim_out({dim_x[0], dim_x[1], out_h, out_w});
this->param_.Out()->Resize(framework::make_ddim(dim_out));
}
} // namespace operators
......
......@@ -22,7 +22,71 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void BilinearInterpCompute(const BilinearInterpParam<CPU>& param) {}
void BilinearInterpCompute(const BilinearInterpParam<CPU>& param) {
auto out_dims = param.Out()->dims();
auto* input = param.InputX()->data<float>();
auto out_size_t = param.InputOutPutSize();
int out_h = param.OutH();
int out_w = param.OutW();
if (out_size_t != nullptr) {
auto out_size_data = out_size_t->data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
auto* output = param.Out()->mutable_data<float>(
{out_dims[0], out_dims[1], out_h, out_w});
auto batch_size = param.InputX()->dims()[0];
auto channels = param.InputX()->dims()[1];
auto in_h = param.InputX()->dims()[2];
auto in_w = param.InputX()->dims()[3];
auto in_hw = in_h * in_w;
auto out_hw = out_h * out_w;
auto in_chw = channels * in_hw;
auto out_chw = channels * out_hw;
float ratio_h =
(out_h > 1) ? static_cast<float>(in_h - 1) / (out_h - 1) : 0.f;
float ratio_w =
(out_w > 1) ? static_cast<float>(in_w - 1) / (out_w - 1) : 0.f;
if (in_h == out_h && in_w == out_w) {
memcpy(output, input, param.InputX()->numel() * sizeof(float));
} else {
for (int k = 0; k < batch_size; ++k) { // loop for batches
for (int i = 0; i < out_h; ++i) { // loop for images
int h = ratio_h * i;
int hid = (h < in_h - 1) ? 1 : 0;
float h1lambda = ratio_h * i - h;
float h2lambda = 1.f - h1lambda;
for (int j = 0; j < out_w; ++j) {
int w = ratio_w * j;
int wid = (w < in_w - 1) ? 1 : 0;
float w1lambda = ratio_w * j - w;
float w2lambda = 1.f - w1lambda;
// calculate four position for bilinear interpolation
const float* in_pos = &input[k * in_chw + h * in_w + w];
float* out_pos = &output[k * out_chw + i * out_w + j];
for (int c = 0; c < channels; ++c) { // loop for channels
// bilinear interpolation
out_pos[0] = static_cast<float>(
h2lambda * (w2lambda * in_pos[0] + w1lambda * in_pos[wid]) +
h1lambda * (w2lambda * in_pos[hid * in_w] +
w1lambda * in_pos[hid * in_w + wid]));
in_pos += in_hw;
out_pos += out_hw;
}
}
}
}
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -2336,14 +2336,21 @@ class BilinearInterpParam : public OpParam {
input_x_ = InputXFrom<GType>(inputs, scope);
input_outsize_ = InputOutSizeFrom<GType>(inputs, scope);
out_ = OutFrom<GType>(outputs, scope);
out_h_ = GetAttr<int>("out_h", attrs);
out_w_ = GetAttr<int>("out_w", attrs);
}
const RType *InputX() const { return input_x_; }
const RType *InputOutPutSize() const { return input_outsize_; }
RType *Out() const { return out_; }
int OutH() const { return out_h_; }
int OutW() const { return out_w_; }
private:
RType *input_x_;
RType *input_outsize_;
RType *out_;
int out_h_;
int out_w_;
};
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册