Commit f12355f7 authored by Megvii Engine Team, committed by huangxinda

fix(imperative/grad): fix hardcode dtype in subtensor_grad_rule

GitOrigin-RevId: 50da4af26dd4f0f0efe38f07573d704ea2fbe841
Parent: 4e4497b9
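Context for the diff below: the old make_tensor helper hardcoded dtype::Float32() for the zeros tensor that pads the gradient of Subtensor / IndexingMultiAxisVec, so backpropagating through a non-float32 tensor (float16, int32, ...) could feed a mismatched-dtype operand into the backward op. The fix threads the gradient's own dtype through and zero-fills the one-element buffer bytewise instead of assigning a float literal. A minimal standalone sketch of that byte-level pattern (plain C++, not MegEngine's actual API; make_zero_scalar and its raw-byte buffer are hypothetical):

#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical sketch: allocate one element of elem_size bytes and zero it
// with memset -- the dtype-agnostic replacement for writing a float literal
// through a typed pointer (scalar.ptr<float>()[0] = v).
std::vector<std::uint8_t> make_zero_scalar(std::size_t elem_size) {
    std::vector<std::uint8_t> buf(elem_size);
    std::memset(buf.data(), 0, buf.size());  // all-zero bit pattern
    return buf;
}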
@@ -35,9 +35,9 @@ std::shared_ptr<Tensor> broadcast_to(Tensor* x, Tensor* s) {
     return python::apply(op, x, s)[0];
 }
-std::shared_ptr<Tensor> make_tensor(CompNode cn, Tensor* shape, float v = 0) {
-    HostTensorND scalar{cn, {{1}, dtype::Float32()}};
-    scalar.ptr<float>()[0] = v;
+std::shared_ptr<Tensor> make_empty_tensor(CompNode cn, Tensor* shape, DType dtype) {
+    HostTensorND scalar{cn, {{1}, dtype}};
+    std::memset(scalar.raw_ptr(), 0, dtype.size());
     interpreter::Interpreter::Handle handle = interpreter_for_py->put(scalar, false);
     auto&& t = std::make_shared<Tensor>(handle);
     auto res = broadcast_to(t.get(), shape);
@@ -117,7 +117,7 @@ apply_result_t subtensor_grad_rule(ApplyContext& ctx, CustomBackward::Maker& mak
     apply_result_t ret(1);
     if (grad && inputs[0]) {
         SmallVector<Tensor*> args_(inputs.size()+1);
-        auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+        auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
         args_[0] = zeros.get();
         args_[1] = grad;
         for (size_t i = 1; i < inputs.size(); ++i) {
@@ -147,7 +147,7 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
     apply_result_t ret(1);
     if (grad && inputs[0]) {
         SmallVector<Tensor*> args_(inputs.size()+1);
-        auto&& zeros = make_tensor(grad->comp_node(), inputs[0].get());
+        auto&& zeros = make_empty_tensor(grad->comp_node(), inputs[0].get(), grad->dtype());
         args_[0] = zeros.get();
         args_[1] = grad;
         for (size_t i = 1; i < inputs.size(); ++i) {
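The memset in make_empty_tensor assumes that an all-zero bit pattern reads back as numeric zero, which holds for IEEE-754 floats and two's-complement integers. A quick self-contained check of that assumption:

#include <cassert>
#include <cstring>

int main() {
    float f; std::memset(&f, 0, sizeof f);  // IEEE-754: zero bits == 0.0f
    int   i; std::memset(&i, 0, sizeof i);  // two's complement: zero bits == 0
    assert(f == 0.0f && i == 0);
    return 0;
}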