Commit 94b0ebd7 authored by weihaoji

[XPU] bugfix on pow, ew_div and reshape

test=develop test=xpu
Parent aff26fa1
@@ -36,11 +36,10 @@ void FillConstantCompute::Run() {
if (param.dtype == static_cast<int32_t>(lite::core::FluidType::FP32)) {
auto data = param.out->template mutable_data<float>(TARGET(kXPU));
value.fp32 = param.value;
write_size = write_size * sizeof(float);
-int r = xdnn::memset(ctx.GetRawContext(), /* context */
-reinterpret_cast<void*>(data),
-value.int32,
-write_size);
+int r = xdnn::memset_4_byte(ctx.GetRawContext(), /* context */
+reinterpret_cast<void*>(data),
+value.int32,
+write_size);
CHECK_EQ(r, 0);
} else if (param.dtype ==
@@ -48,22 +47,17 @@ void FillConstantCompute::Run() {
auto data = param.out->template mutable_data<int32_t>(TARGET(kXPU));
value.int32 = param.value;
write_size = write_size * sizeof(int32_t);
-int r = xdnn::memset(ctx.GetRawContext(), /* context */
-reinterpret_cast<void*>(data),
-value.int32,
-write_size);
+int r = xdnn::memset_4_byte(ctx.GetRawContext(), /* context */
+reinterpret_cast<void*>(data),
+value.int32,
+write_size);
CHECK_EQ(r, 0);
} else if (param.dtype == static_cast<int32_t>(lite::core::FluidType::INT8)) {
auto data = param.out->template mutable_data<int8_t>(TARGET(kXPU));
-value.int32 = 0;
-for (int i = 0; i < 4; i++) {
-value.int32 += static_cast<int32_t>(param.value);
-value.int32 = value.int32 << 8;
-}
int r = xdnn::memset(ctx.GetRawContext(), /* context */
reinterpret_cast<void*>(data),
-value.int32,
+param.value,
write_size);
CHECK_EQ(r, 0);
} else {
......
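Note on the INT8 branch above: the deleted loop tried to replicate the int8 fill value across all four bytes of an int32 so that a word-wide memset could be used, but it shifted after each add, leaving the lowest byte zero and pushing the first copy out of the 32-bit word; the fixed code passes param.value through instead. A correct byte-replication helper, shown purely for reference (a hedged sketch, not part of this commit; ReplicateByte is a hypothetical name), would be:

```cpp
#include <cstdint>

// Hypothetical helper, not in this commit: replicate an int8 value into all
// four bytes of an int32 for a word-wide fill. Unlike the removed loop, the
// byte is OR-ed into every position, so no copy is lost and the low byte is set.
inline int32_t ReplicateByte(int8_t v) {
  const uint32_t b = static_cast<uint8_t>(v);  // zero-extend to avoid sign smear
  return static_cast<int32_t>(b | (b << 8) | (b << 16) | (b << 24));
}
```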
@@ -11,15 +11,44 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/xpu/reshape_compute.h"
#include "lite/backends/xpu/xpu_header_sitter.h"
#include "lite/core/op_registry.h"
+namespace paddle {
+namespace lite {
+namespace kernels {
+namespace xpu {
+void ReshapeCompute::Run() {
+auto& param = this->Param<param_t>();
+auto& ctx = this->ctx_->As<XPUContext>();
+auto x = param.x;
+auto output = param.output;
+auto output_dims = output->dims();
+if (param.inplace) {
+output->ShareDataWith(*x);
+output->Resize(output_dims);
+} else {
+int r = xdnn::memcpy_device(ctx.GetRawContext(),
+param.output->mutable_data<float>(TARGET(kXPU)),
+param.x->data<float>(),
+param.x->numel() * sizeof(float));
+CHECK_EQ(r, 0);
+}
+}
+} // namespace xpu
+} // namespace kernels
+} // namespace lite
+} // namespace paddle
REGISTER_LITE_KERNEL(reshape2,
kXPU,
kFloat,
kNCHW,
-paddle::lite::kernels::xpu::Reshape2Compute<float>,
+paddle::lite::kernels::xpu::ReshapeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kXPU))})
@@ -27,3 +56,15 @@ REGISTER_LITE_KERNEL(reshape2,
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kXPU))})
.Finalize();
+REGISTER_LITE_KERNEL(reshape,
+kXPU,
+kFloat,
+kNCHW,
+paddle::lite::kernels::xpu::ReshapeCompute,
+def)
+.BindInput("X", {LiteType::GetTensorTy(TARGET(kXPU))})
+.BindInput("ShapeTensor", {LiteType::GetTensorTy(TARGET(kXPU))})
+.BindInput("Shape", {LiteType::GetTensorTy(TARGET(kXPU))})
+.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kXPU))})
+.Finalize();
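The new ReshapeCompute::Run() above either aliases the input buffer (param.inplace) or copies it on the device with xdnn::memcpy_device; since a reshape never changes the element count, the copy moves exactly numel() * sizeof(float) bytes and only the dims metadata differs between input and output. A minimal host-side analogue of that branch (a sketch assuming plain float buffers; ReshapeSketch and its signature are illustrative, not Paddle Lite APIs):

```cpp
#include <cstddef>
#include <cstring>

// Host-side sketch of the inplace-vs-copy decision in ReshapeCompute::Run().
// In the real kernel the copy is xdnn::memcpy_device and buffer sharing is
// Tensor::ShareDataWith; here plain pointers stand in for XPU device buffers.
const float* ReshapeSketch(const float* x, float* out, std::size_t numel,
                           bool inplace) {
  if (inplace) {
    return x;  // reuse the source buffer; only the dims metadata changes
  }
  std::memcpy(out, x, numel * sizeof(float));  // element count is unchanged by a reshape
  return out;
}
```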
@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
@@ -20,29 +21,13 @@ namespace lite {
namespace kernels {
namespace xpu {
-template <typename T>
-class Reshape2Compute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
+class ReshapeCompute : public KernelLite<TARGET(kXPU), PRECISION(kFloat)> {
public:
using param_t = operators::ReshapeParam;
-void Run() override {
-auto& param = *param_.get_mutable<param_t>();
-auto x = param.x;
-auto output = param.output;
-auto xshape = param.xshape;
-auto x_dims = x->dims();
-auto x_dims_data = x_dims.Vectorize();
-auto out_dims = output->dims();
-output->ShareDataWith(*x);
-output->Resize(out_dims);
-auto* xshape_data = xshape->mutable_data<int64_t>(TARGET(kXPU));
-TargetWrapperXPU::MemcpySync(xshape_data,
-x_dims_data.data(),
-x_dims.size() * sizeof(int64_t),
-IoDirection::HtoD);
-}
+virtual void Run();
-virtual ~Reshape2Compute() = default;
+virtual ~ReshapeCompute() = default;
};
} // namespace xpu
......