未验证 提交 5d29a27c 编写于 作者: O oyjxer 提交者: GitHub

[NPU] fix npu op elementwise_mul_grad (#31592)

上级 09bf2cfc
......@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_ASCEND_CL
#include <memory>
#include <string>
......@@ -58,17 +59,21 @@ class ElementwiseMulGradNPUKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
dx->mutable_data<T>(place);
dy->mutable_data<T>(place);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
auto dx_runner = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {});
dx_runner.Run(stream);
auto dy_runner = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {});
dy_runner.Run(stream);
if (dx) {
dx->mutable_data<T>(place);
auto dx_runner = NpuOpRunner("Mul", {*dout, *y}, {*dx}, {});
dx_runner.Run(stream);
}
if (dy) {
dy->mutable_data<T>(place);
auto dy_runner = NpuOpRunner("Mul", {*x, *dout}, {*dy}, {});
dy_runner.Run(stream);
}
}
};
......@@ -88,3 +93,4 @@ REGISTER_OP_NPU_KERNEL(
ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::ElementwiseMulGradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -122,6 +122,7 @@ class TestElementwiseMulNet(unittest.TestCase):
e = paddle.multiply(a, b)
f = paddle.multiply(c, d)
f.stop_gradient = True
g = paddle.multiply(e, f)
fc_1 = fluid.layers.fc(input=g, size=128)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册