提交 e1bba54c 编写于 作者: xiebaiyuan's avatar xiebaiyuan

shape op

上级 200c2c9e
......@@ -176,6 +176,7 @@ void Executor<Dtype, P>::LoadMemory(const framework::VarDesc var_desc,
type_size = 8;
break;
case framework::VARTYPE_TYPE_INT32:
memory = tensor->mutable_data<int32_t>();
type_size = 4;
break;
case framework::VARTYPE_TYPE_INT64:
......@@ -308,6 +309,9 @@ bool Executor<Dtype, P>::varInputMemory(
}
case framework::VARTYPE_TYPE_INT32: {
tensor = var->template GetMutable<framework::LoDTensor>();
tensor->template mutable_data<int32_t>();
is_mute_match = true;
break;
}
......
......@@ -22,7 +22,15 @@ namespace paddle_mobile {
namespace operators {
template <typename P>
void ShapeCompute(const ShapeParam<CPU>& param) {}
void ShapeCompute(const ShapeParam<CPU>& param) {
auto* in_t = param.InputX();
auto* out_t = param.Out();
auto out_data = out_t->mutable_data<int32_t>();
auto in_dims = in_t->dims();
for (int i = 0; i < in_dims.size(); ++i) {
out_data[i] = static_cast<int32_t>(in_dims[i]);
}
}
} // namespace operators
} // namespace paddle_mobile
......
......@@ -20,6 +20,10 @@ namespace paddle_mobile {
namespace operators {
template <typename DeviceType, typename T>
void ShapeOp<DeviceType, T>::InferShape() const {
PADDLE_MOBILE_ENFORCE(this->param_.InputX() != nullptr,
"Input (Input) of get_shape op should not be null.");
PADDLE_MOBILE_ENFORCE(this->param_.Out() != nullptr,
"Output (Out) of get_shape op should not be null.");
this->param_.Out()->Resize(this->param_.InputX()->dims());
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册