F68b8ce494847fc75548023d5f4c15d9
从零实现机器学习参数服务器框架(二)

简介

序号 名称 简介
1 框架架构 介绍整体框架架构、涉及模块
2 模型、数据计算并行 实现参数服务器中的Worker与Server
3 节点间通信接口 实现应用常用接口,使用ZeroMQ实现节点间通信
4 一致性控制 在Server端实现常见的一致性控制模式:BSP/ASP/SSP
5 容错实现 使用CheckPoint Rollback、心跳监测等方式实现计算容错
6 Yarn资源管理系统支持 编写Yarn Application以使用Yarn进行资源分配
7 机器学习算法实现 实现Logistic Regression与KMeans两种算法

这一篇将会依照之前的写作顺序详细介绍参数服务器中的模型并行与数据并行的实现,主要涉及到的功能如下:

  • WorkerThread/ServerThread: 参数服务器中的最重要的两个角色
  • IdMapper: 实现集群配置文件的读取与解析,通过唯一的节点id为Worker与Server等重要线程分配全局id
  • HDFSManager: 多个节点并行从HDFS中读取数据并保存在内存中
  • MapStorage/VectorStroage: 多个节点根据指定存储两种不同类型的参数 (vector与matrix)

Worker与Server

pslite architecture

pslite architecture

介绍WorkerThread/ServerThread之前,先看看ps-lite的设计,这里Worker Node对应着WorkerThread,ServerThread对应着Server Node,当一个机器学习应用启动后,会经过以下步骤:

  1. Worker Node开始从HDFS/文件等数据源中读取数据到内存中
  2. Worker Node开始梯度计算
  3. Worker Node将梯度通过Push推送到Server Node
  4. Server Node收到梯度,将内存中保存的参数进行更新
  5. Worker Node开始一个epoch并向Server Node发起Pull请求
  6. Server Node回应请求,将最新的参数返回给所有的Worker Node

这是参数服务器在训练时的典型步骤,我们要编写的框架也是严格遵循着这种设计。唯一不同的在我们的框架中,WorkerThread与ServerThread以线程的形式运行在相同的节点上,而ps-lite的设计是Worker与Server是以进程的形式运行在不同的节点上。

WorkerThread

worker/worker_thread.hpp

WorkerThread本质上是一个线程,它使用到了生产者-消费者的并发设计模式

class WorkerThread : public Actor, public AbstractCallbackRunner {
    public:
    WorkerThread(uint32_t worker_id) : Actor(worker_id) {}
    ...
    // 更新梯度时/Push会调用到这个接口    
    virtual void NewRequest(uint32_t app_thread_id, uint32_t model_id, uint32_t expected_responses) override;
    ...
    // Server/Pull时会通知Worker调用这个接口    
    virtual void AddResponse(uint32_t app_thread_id, uint32_t model_id, Message& msg) override;

    ...
    protected:
    void Main() override;

    private:
    ...
    // 用来保存当前Worker训练进度
    std::map<uint32_t, std::map<uint32_t, std::pair<uint32_t, uint32_t>>> tracker_;
    // 用来保存Worker的回调句柄
    std::map<uint32_t, std::map<uint32_t, std::function<void(Message& message)>>> recv_handle_;
    std::map<uint32_t, std::map<uint32_t, std::function<void()>>> recv_finish_handle_;
};

这里需要强调一下线程使用到的生产者-消费者模式,之后的ServerThread,SenderThread, MasterThread都使用到了这个模式进行线程间通信。

当WorkerThread线程启动后,将会调用到Main方法,随后将会进入一个无限循环:

worker/worker_thread.cpp

void WorkerThread::Main() {
    // 无限循环,只有收到了kExit的flag才会退出
    while (true) {
        Message msg;
        // 阻塞等待消息,如果新的消息该线程将会拿不到CPU时间片
        work_queue_.WaitAndPop(&msg);

        if (msg.meta.flag == Flag::kExit)
            break;

        ...
        // 处理消息
        AddResponse(msg.meta.recver, msg.meta.model_id, msg);
    }
}

目前WorkerThread收到的消息只有从ServerThread发来的Pull回应,用来获取到最新的参数数据,以便Worker根据数据继续计算梯度。

ServerThread

server/server_thread.hpp

class ServerThread : public Actor {
    public:
    ServerThread(uint32_t server_id) : Actor(server_id) {}

    void RegisterModel(uint32_t model_id, std::unique_ptr<AbstractModel> &&model);

    AbstractModel *GetModel(uint32_t model_id);

    // 在收到Push请求后,Server更新参数数据
    void UpdateModel(int failed_node_id, std::vector<Node> &nodes, third_party::Range &range);

    void RollbackModel() {
        for (auto it = models_.begin(); it != models_.end(); it++) {
            it->second->Restore();
        }
    }

    protected:
    virtual void Main() override;

    // 单个Server可以维护着多个不同的参数模型
    std::unordered_map<uint32_t, std::unique_ptr<AbstractModel>> models_;
};

和WorkerThread一样,这里详细分析下它的Main方法:

server/server_thread.cpp

void ServerThread::Main() {
    while (true) {
        Message msg;
        work_queue_.WaitAndPop(&msg);

        if (msg.meta.flag == Flag::kExit)
            break;

        uint32_t model_id = msg.meta.model_id;
        ...
        switch (msg.meta.flag) {
            case Flag::kClock: {
                models_[model_id]->Clock(msg);
                break;
            }
            case Flag::kAdd: {
                models_[model_id]->Add(msg);
                break;
            }
            case Flag::kGet: {
                models_[model_id]->Get(msg);
                break;
            }
            case Flag::kResetWorkerInModel: {
                models_[model_id]->ResetWorker(msg);
                break;
            }
            case Flag::kCheckpoint: {
                models_[model_id]->Dump(msg);
                break;
            }
            default:
                CHECK(false) << "Unknown flag in msg: " << FlagName[static_cast<int>(msg.meta.flag)];
        }
    }
}

可以看到ServerThread需要处理的信息种类比WorkerThread多的多。这里分别做下介绍,但是万变不离其宗,这些消息本质上都是围绕着参数的获取、新增、更新来定制的。

操作类型 简介
Add Push操作后更新参数梯度
Get Pull操作后获取到当前Server的参数
Clock Push操作完成后,通知所有节点进入下一个epoch
Reset 初始化各节点参数
Dump 用于容错的CheckPoint操作,将当前参数进行备份

到这里参数服务器中两个重要的角色Worker与Server就介绍完成了。我们现在已经知道了Worker与Server使用到了生产者消费者模式进行通信,这其中必然会涉及到唯一识别符的使用,这样才能够确保指定的消费者能够处理到对应的消息。IdMapper用来负责这项工作的。

IdMapper

IdMapper

IdMapper

IdMapper的实现简而言之是根据指定文件的节点id来划分各个节点中线程id的实现。前面已经提及,在我们的框架中,Worker与Server是以线程的形式运行在每个机器上的,这也意味着每个机器可以并行的运行多个Worker与Server。如上图所示id的划分,每个节点允许运行最多50个Server与Worker。

作为使用这个框架开发机器学习算法应用的开发者,只要指定节点id, ip address及对应port即可:

config/localnodes

0:localhost:14560
1:localhost:14561
2:localhost:14562
3:localhost:14563
4:localhost:14564

上面这是在本地下的伪集群环境,开启5个不同的进程担任不同的节点,每个节点拥有不同的port与id。框架会自动的解析配置文件,以此来初始化集群中各个节点的线程id。有了唯一识别符,不同节点间的Worker与Server就可以正常的进行通信了。

driver/simple_id_mapper.cpp

// 初始化Server与Worker Helper Id
void SimpleIdMapper::Init(int num_server_threads_per_node) {
    ...
    for (const auto &node : nodes_) {
        CHECK_LT(node.id, kMaxNodeId);

        for (int i = 0; i < num_server_threads_per_node; i++) {
            node2server_[node.id].push_back(node.id * kMaxThreadsPerNode + i);
        }

        node2worker_helper_[node.id].push_back(node.id * kMaxThreadsPerNode + kWorkerHelperThreadId);
    }
}

// 初始化Worker的Id
uint32_t SimpleIdMapper::AllocateWorkerThread(uint32_t node_id) {
    CHECK(node2worker_helper_.find(node_id) != node2worker_helper_.end());
    CHECK_LE(node2worker_[node_id].size(), kMaxThreadsPerNode - kMaxBgThreadsPerNode);

    for (int i = kMaxBgThreadsPerNode; i < kMaxThreadsPerNode; i++) {
        int tid = i + node_id * kMaxThreadsPerNode;
        if (node2worker_[node_id].find(tid) == node2worker_[node_id].end()) {
            node2worker_[node_id].insert(tid);
            return tid;
        }
    }
    CHECK(false);
    return -1;
}

值得注意的是这里node2worker_helper_的Id是直接分配给WorkerThread使用的。而node2worker_的id并没有直接的作用。

driver/engine.cpp

```cpp

top Created with Sketch.