未验证 提交 5d29a27c 编写于 作者: O oyjxer 提交者: GitHub

[NPU] fix npu op elementwise_mul_grad (#31592)

上级 09bf2cfc
...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -58,18 +59,22 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel<T> { ...@@ -58,18 +59,22 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
dy->mutable_data<T>(place);
auto stream = auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
if (dx) {
dx->mutable_data<T>(place);
auto dx_runner = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {}); auto dx_runner = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {});
dx_runner.Run(stream); dx_runner.Run(stream);
}
if (dy) {
dy->mutable_data<T>(place);
auto dy_runner = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {}); auto dy_runner = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {});
dy_runner.Run(stream); dy_runner.Run(stream);
} }
}
}; };
} // namespace operators } // namespace operators
...@@ -88,3 +93,4 @@ REGISTER_OP_NPU_KERNEL( ...@@ -88,3 +93,4 @@ REGISTER_OP_NPU_KERNEL(
ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext, float>, ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext, ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>); paddle::platform::float16>);
#endif
...@@ -122,6 +122,7 @@ class TestElementwiseMulNet(unittest.TestCase): ...@@ -122,6 +122,7 @@ class TestElementwiseMulNet(unittest.TestCase):
e = paddle.multiply(a, b) e = paddle.multiply(a, b)
f = paddle.multiply(c, d) f = paddle.multiply(c, d)
f.stop_gradient = True
g = paddle.multiply(e, f) g = paddle.multiply(e, f)
fc_1 = fluid.layers.fc(input=g, size=128) fc_1 = fluid.layers.fc(input=g, size=128)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册