C402c39fbf5b03e76f032afe4a449cbe
从零实现机器学习参数服务器框架(三)

简介

序号 名称 简介
1 框架架构 介绍整体框架架构、涉及模块
2 模型、数据计算并行 实现参数服务器中的Worker与Server
3 节点间通信接口 实现应用常用接口,使用ZeroMQ实现节点间通信
4 一致性控制 在Server端实现常见的一致性控制模式:BSP/ASP/SSP
5 容错实现 使用CheckPoint Rollback、心跳监测等方式实现计算容错
6 Yarn资源管理系统支持 编写Yarn Application以使用Yarn进行资源分配
7 机器学习算法实现 实现Logistic Regression与KMeans两种算法

上一篇实现了参数服务器中最重要的两个角色: Worker与Server,它们在线程间是通过使用生产者消费者来实现的,然而对于处于不同节点之间的通信,本质上是通过ZeroMQ包装的网络Socket来实现的。

这一篇就会介绍节点间通信的基本实现,主要涉及的功能如下:

  • 使用ZeroMQ进行节点间的网络互连
  • 使用锁实现节点间的通信栏栅,以保证集群中各节点训练状态的同步

节点通信基础

cluster communication

cluster communication

以上是参数服务器框架各节点之间的通信链接,从图中可以看到,各个节点是互相持有对方的socket fd,也就是说每个节点即是客户端也是服务端。

Mailbox

comm/mailbox.hpp

class Mailbox : public AbstractMailbox {
    public:
    Mailbox(const Node &node, const std::vector<Node> &nodes, AbstractIdMapper *id_mapper, Engine *engine = nullptr);

    ...

    virtual int Send(const Message &msg) override;

    int Recv(Message *msg);

    // 启动mailbox,初始化通信
    void Start(const Node &master_node = {});

    void Stop(bool barrier = true);

    ...

    void Barrier(bool send = true);

    // 进入无限循环,接收其它节点的通信信息
    void StartReceiving();

    ...

    private:
    void Connect(const Node &node);

    void Bind(const Node &node);

    void Receiving();

    std::map<uint32_t, ThreadsafeQueue<Message> *const> queue_map_;
    ...
    // 使用map存储各个节点的socket fd
    std::unordered_map<uint32_t, void *> senders_;
    ...

};

Mailbox是节点间基础通信的抽象,它实现了建立网络连接、发送信息、接收信息等功能。网络通信的底层是依赖着ZeroMQ来实现的。

Mailbox::Start

comm/mailbox.cpp

```cpp
// 启动Mailbox
void Mailbox::Start(const Node &master_node) {
ConnectAndBind(master_node);
StartReceiving();
}

// 在各个节点建立连接
void Mailbox::ConnectAndBind(const Node &master_node) {
context_ = zmq_ctx_new();
CHECK(context_ != nullptr) << "create zmq context failed";
zmq_ctx_set(context_, ZMQ_MAX_SOCKETS, 65536);

Bind(node_);
VLOG(1) << "Finished binding";
for (const auto &node : nodes_) {
    Connect(node);
}
if (master_node.is_master) {
    Connect(master_node);
}
VLOG(1) << "Finished connecting";

}

void Mailbox::Bind(const Node &node) {
receiver_ = zmq_socket(context_, ZMQ_ROUTER);
CHECK(receiver_ != nullptr) << "create receiver socket failed: " << zmq_strerror(errno);
std::string address = "tcp://*:" + std::to_string(node.port);
// 使用ZMQ进行bind,实际上是使用到了socket的bind
if (zmq_bind(receiver_, address.c_str()) != 0) {
LOG(FATAL) << "bind to " + address + " failed: " << zmq_strerror(errno);
}
}

void Mailbox::Connect(const Node &node) {
auto it = senders_.find(node.id);
if (it != senders_.end()) {
zmq_close(it->second);
}
void *sender = zmq_socket(context_, ZMQ_DEALER);
CHECK(sender != nullptr) << zmq_strerror(errno);
std::string my_id = "ps" + std::to_string(node_.id);
zmq_setsockopt(sender, ZMQ_IDENTITY, my_id.data(), my_id.size());
std::string addr = "tcp://" + node.hostname + ":" + std::to_string(node.port);
if (zmq_connect(sender, addr.c_str()) != 0) {
LOG(FATAL) << "connect to " + addr + " failed: " << zmq_strerror(errno);
}
// 将通信句柄以node id为识别符保存在map中进行缓存,在下次通信时会使用到
senders_[node.id] = sender;
}

// 启动一个新的线程并调用Receiving方法
void Mailbox::StartReceiving() {
receiver_thread_ = std::thread(&Mailbox::Receiving, this);
}

void Mailbox::Receiving() {
VLOG(1) << "Start receiving";
while (true) {
Message msg;
int recv_bytes = Recv(&msg);

top Created with Sketch.