Commit 3cf50b74 authored by seiriosPlus

fix fuse

Parent 66321576
......
@@ -82,7 +82,7 @@ class LargeScaleFuseAdamOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Beta1Pow", "(Tensor) Input beta1 power accumulator");
AddInput("Beta2Pow", "(Tensor) Input beta2 power accumulator");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddOutput("Beta1PowOut", "(Tensor) Output beta1 power accumulator");
AddOutput("Beta2PowOut", "(Tensor) Output beta2 power accumulator");
......
@@ -150,6 +150,4 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_fuse_adam,
- ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>,
- ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext,
-                                 double>);
+ ops::LargeScaleFuseAdamOpKernel<paddle::platform::CPUDeviceContext, float>);
......
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
......
@@ -37,8 +38,9 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
- const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
+ using paddle::framework::LoDTensor;
+ const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
const auto *grad_var = ctx.InputVar("Grad");
PADDLE_ENFORCE(
......
@@ -56,8 +58,8 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows *grad_merge_ptr;
- math::scatter::MergeAdd<DeviceContext, T> merge_func;
- merge_func(context.template device_context<DeviceContext>(), *in_grad,
+ math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+ merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
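An aside on the merge step above: MergeAdd collapses a SelectedRows gradient whose row list may name the same id several times into one slice per unique id by summing the duplicates, so the per-row optimizer update that follows touches each table slice exactly once. Below is a minimal sketch of that semantics under an assumed flat (row ids, value buffer) layout; MergeAddSketch and its signature are illustrative only, not Paddle's actual API.

#include <cstdint>
#include <map>
#include <vector>

// Sketch of MergeAdd semantics: a SelectedRows gradient may list the same
// row id several times; merging sums the duplicate slices so every unique
// id appears once. `width` is the number of floats per row slice.
void MergeAddSketch(const std::vector<int64_t> &rows,
                    const std::vector<float> &values, int64_t width,
                    std::vector<int64_t> *out_rows,
                    std::vector<float> *out_values) {
  std::map<int64_t, std::vector<float>> merged;
  for (size_t i = 0; i < rows.size(); ++i) {
    auto &acc = merged[rows[i]];
    acc.resize(width, 0.0f);
    for (int64_t x = 0; x < width; ++x) {
      acc[x] += values[i * width + x];  // duplicate ids accumulate here
    }
  }
  out_rows->clear();
  out_values->clear();
  for (const auto &kv : merged) {
    out_rows->push_back(kv.first);
    out_values->insert(out_values->end(), kv.second.begin(), kv.second.end());
  }
}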
......
@@ -71,8 +73,8 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
auto grad_width = grad_v.dims()[1];
// auto is_entry = context.Attr<bool>("is_entry");
- auto tablename = context.Attr<std::string>("tablename");
- auto value_names = Attr<std::vector<std::string>>("value_names");
+ auto tablename = ctx.Attr<std::string>("tablename");
+ auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
auto *beta1_pow = ctx.Input<LoDTensor>("Beta1Pow");
auto *beta2_pow = ctx.Input<LoDTensor>("Beta2Pow");
......
@@ -116,11 +118,11 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
auto &moment_1 = values[1];
auto &moment_2 = values[2];
- T lr = *lr_;
+ T lr_ = lr[0];
T beta1_ = beta1_pow->data<T>()[0];
T beta2_ = beta2_pow->data<T>()[0];
- lr *= sqrt(1 - beta1_) / (1 - beta2_);
+ lr_ *= sqrt(1 - beta1_) / (1 - beta2_);
for (size_t i = 0; i < in_rows.size(); i++) {
auto *m1_data = moment_1[i]->data();
......
@@ -131,7 +133,7 @@ class LargeScaleFuseAdamOpKernel<platform::CPUDeviceContext, T>
auto g = grad_v.data<T>()[grad_width * i + x];
m1_data[x] = beta1_ * m1_data[x] + (1 - beta1_) * g;
m2_data[x] = beta2_ * m2_data[x] + (1 - beta2_) * g * g;
- p_data[x] -= lr * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
+ p_data[x] -= lr_ * (m1_data[x] / (sqrt(m2_data[x]) + epsilon));
}
}
}
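For reference, the loop above is the standard Adam update of Kingma and Ba applied independently to each selected row, with m and v the two moment slices and Beta1Pow/Beta2Pow holding the accumulated powers $\beta_1^t$ and $\beta_2^t$:

$$
\begin{aligned}
m &\leftarrow \beta_1 m + (1-\beta_1)\,g, &
v &\leftarrow \beta_2 v + (1-\beta_2)\,g^2, \\
\alpha_t &= \alpha\,\frac{\sqrt{1-\beta_2^t}}{1-\beta_1^t}, &
p &\leftarrow p - \alpha_t\,\frac{m}{\sqrt{v}+\epsilon}.
\end{aligned}
$$

One observation, outside the scope of this commit: the kernel scales the learning rate by $\sqrt{1-\beta_1^t}/(1-\beta_2^t)$, which swaps the two bias-correction factors relative to the textbook form above; the rename from lr to lr_ in this diff leaves that expression unchanged.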
......
......
@@ -79,7 +79,7 @@ class LargeScaleFuseSGDOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad",
"(SelectedRows) Ids's type should be SelectedRows"
"THe ids to be looked up in W.");
AddInput("LearningRate", "(Tensor) Learning rate of SGD");
AddAttr<bool>("is_entry",
"(bool)"
"sparse table need entry");
......
@@ -117,5 +117,4 @@ REGISTER_OPERATOR(
REGISTER_OP_CPU_KERNEL(
lookup_sparse_table_fuse_sgd,
- ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>,
- ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, double>);
+ ops::LargeScaleFuseSGDOpKernel<paddle::platform::CPUDeviceContext, float>);
......
@@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/distributed/large_scale_kv.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/selected_rows_functor.h"
namespace paddle {
namespace operators {
......
@@ -56,8 +57,8 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
framework::SelectedRows tmp_grad_merge;
const framework::SelectedRows *grad_merge_ptr;
- math::scatter::MergeAdd<DeviceContext, T> merge_func;
- merge_func(context.template device_context<DeviceContext>(), *in_grad,
+ math::scatter::MergeAdd<platform::CPUDeviceContext, T> merge_func;
+ merge_func(ctx.template device_context<platform::CPUDeviceContext>(), grad,
&tmp_grad_merge, true);
grad_merge_ptr = &tmp_grad_merge;
......
@@ -71,8 +72,8 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
auto grad_width = grad_v.dims()[1];
// auto is_entry = context.Attr<bool>("is_entry");
- auto tablename = context.Attr<std::string>("tablename");
- auto value_names = Attr<std::vector<std::string>>("value_names");
+ auto tablename = ctx.Attr<std::string>("tablename");
+ auto value_names = ctx.Attr<std::vector<std::string>>("value_names");
std::vector<std::vector<std::vector<float> *>> values;
std::vector<int64_t> dims;
......
@@ -88,15 +89,16 @@ class LargeScaleFuseSGDOpKernel<platform::CPUDeviceContext, T>
auto &params = values[0];
- auto blas = math::GetBlas<DeviceContext, T>(context);
+ auto blas = math::GetBlas<platform::CPUDeviceContext, T>(ctx);
std::vector<T> grads;
- framework::TensorToVector(grad_v, context.device_context(), &grads);
+ framework::TensorToVector(grad_v, ctx.device_context(), &grads);
- blas.VMUL(grads, lr[0], grads);
+ blas.SCAL(grads.size(), lr[0], grads.data());
for (int x = 0; x < static_cast<int>(in_rows.size()); ++x) {
- blas.VSUB(grad_width, params[x], grads.data() + grad_width * x, params);
+ blas.VSUB(grad_width, params[x]->data(), grads.data() + grad_width * x,
+           params[x]->data());
}
}
};
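Two of the hunks above are behavioral fixes rather than renames, as the call shapes in the diff itself suggest: VMUL takes two vectors for an element-wise product, so blas.VMUL(grads, lr[0], grads) could not scale the gradient by a scalar, while blas.SCAL(n, alpha, x) scales a buffer in place; and VSUB works on raw element pointers, hence params[x]->data() instead of the container params. The resulting per-row update is plain SGD, p <- p - lr * g. A self-contained sketch of the same arithmetic without the BLAS wrapper follows; the names and the flat gradient layout are assumptions for illustration.

#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch of the fused SGD row update: for each selected row x,
// params[x][i] -= lr * grads[x * width + i].
void SgdRowUpdateSketch(float lr, int64_t width, std::vector<float> *grads,
                        const std::vector<float *> &params) {
  // Equivalent of blas.SCAL(grads->size(), lr, grads->data()):
  // scale the whole merged-gradient buffer by the learning rate.
  for (float &g : *grads) g *= lr;
  // Equivalent of blas.VSUB(width, p, g_row, p) for each row.
  for (size_t x = 0; x < params.size(); ++x) {
    float *p = params[x];
    const float *g_row = grads->data() + width * x;
    for (int64_t i = 0; i < width; ++i) p[i] -= g_row[i];
  }
}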
......
......
@@ -657,13 +657,15 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
if op.type == "sgd":
grad = main_program.global_block().vars[op.input("Grad")[0]]
+ lr = main_program.global_block().vars[op.input("LearningRate")[0]]
# remove origin optimizer op
block._remove_op(opt_idx)
block._insert_op(
opt_idx,
type="lookup_sparse_table_fuse_sgd",
inputs={"Grad": grad},
inputs={"Grad": grad,
"LearningRate": lr},
attrs={
"is_entry": is_entry,
"tablename": table_name,
......
@@ -672,6 +674,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
elif op.type == "adam":
grad = main_program.global_block().vars[op.input("Grad")[0]]
+ lr = main_program.global_block().vars[op.input("LearningRate")[0]]
beta1_pow = main_program.global_block().vars[op.input("Beta1Pow")[
0]]
beta2_pow = main_program.global_block().vars[op.input("Beta2Pow")[
......
@@ -693,6 +696,7 @@ def large_scale_sparse_pass(program, main_program, config, is_startup=False):
type="lookup_sparse_table_fuse_adam",
inputs={
"Grad": grad,
"LearningRate": lr,
"Beta1Pow": beta1_pow,
"Beta2Pow": beta2_pow
},
......