ソケットプログラミング - TCP/IPサーバの例

RubyなどのLLでソケット通信のプログラムを書いていると、あまりにも簡単に書けすぎて、ついついC言語でのソケット通信の書き方を忘れてしまいます。

ということで、おさらいのためにシンプルなechoサーバをCで書いてみました。おおまかな説明はソース中にコメントで書いてありますので省きます。

server.c
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <errno.h>
#include <poll.h>
#include <signal.h>
#include <sys/socket.h>
#include <netdb.h>

#define BUFFER_SIZE 512
#define READ_TIMEOUT 10000      // 10 seconds

/* get*info系のエラーメッセージを出力する */
void put_gaierror(const char *tag, int error_code)
{
    // errnoにエラーが格納されている場合もある
    if (error_code == EAI_SYSTEM) {
        perror(tag);
    } else {
        fprintf(stderr, "%s: %s\n", tag, gai_strerror(error_code));
    }
    return;
}

/* アドレス構造体の情報を可読化して出力する */
void print_nameinfo(const char *tag, const struct sockaddr *addr,
        socklen_t addrlen)
{
    char host[NI_MAXHOST];
    char serv[NI_MAXSERV];
    int error_code;

    // ホスト名もサービス名も逆引きしない
    error_code = getnameinfo(addr, addrlen, host, sizeof(host),
            serv, sizeof(serv), NI_NUMERICHOST | NI_NUMERICSERV);
    if (error_code != 0) {
        put_gaierror("getnameinfo", error_code);
        strcpy(host, "***");
        strcpy(serv, "***");
    }

    printf("%s %s (%s)\n", tag, host, serv);

    return;
}

/* 接続待ち受けソケットを開く */
int open_socket(const char *service)
{
    struct addrinfo hints;
    struct addrinfo *addr_set;
    struct addrinfo *addr;
    int error_code;
    int sock = -1;
    int on = 1;

    memset(&hints, 0, sizeof(hints));
    hints.ai_flags = AI_PASSIVE;        // 全てのネットワークアドレスで待ち受ける
    hints.ai_family = AF_UNSPEC;        // 任意のプロトコルファミリーで待ち受ける
    hints.ai_socktype = SOCK_STREAM;    // TCPを指定

    error_code = getaddrinfo(NULL, service, &hints, &addr_set);
    if (error_code != 0) {
        put_gaierror("getaddrinfo", error_code);
        return -1;
    }

    // 順番にソケット作成を試行
    for (addr = addr_set; addr; addr = addr->ai_next) {
        sock = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
        if (sock < 0) {
            perror("socket");
            continue;
        }

        // TIME_WAIT時もすぐにbindできるようにSO_REUSEADDRオプションを指定
        if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) != 0) {
            perror("setsockopt");       // このエラーはスルー
        }

        if (bind(sock, addr->ai_addr, addr->ai_addrlen) != 0) {
            perror("bind");
            close(sock);
            sock = -1;
            continue;
        }

        if (listen(sock, SOMAXCONN) != 0) {
            perror("listen");
            close(sock);
            sock = -1;
            continue;
        }

        // 成功
        print_nameinfo("Listening", addr->ai_addr, addr->ai_addrlen);
        break;
    }

    freeaddrinfo(addr_set);

    return sock;
}

/* クライアントとの通信のメインルーチン */
int run(int sock)
{
    struct pollfd read_fd;
    int poll_count;
    char buf[BUFFER_SIZE];
    char *write_ptr;
    ssize_t read_len;
    ssize_t write_len;

    memset(&read_fd, 0, sizeof(read_fd));
    read_fd.fd = sock;
    read_fd.events = POLLIN;

    // 単純に延々と受け取ったデータをオウム返しするだけ
    while (1) {
        // 30秒間データ受信を待つ
        poll_count = poll(&read_fd, 1, READ_TIMEOUT);
        if (poll_count < 0) {
            if (errno == EINTR) {       // 割り込みによる中断の場合はやり直し
                continue;
            }
            perror("poll");
            close(sock);
            return -1;
        } else if (poll_count == 0) {  // タイムアウトしたら終了
            printf("Connection timeout.\n");
            break;
        }
        // POLLINを一つpollしているだけなので、特にreventsのチェックはしない

        read_len = read(sock, buf, sizeof(buf));
        if (read_len < 0) {
            if (errno == EINTR) {       // 割り込みによる中断の場合はやり直し
                continue;
            }
            perror("read");
            close(sock);
            return -1;
        } else if (read_len == 0) {     // EOF
            printf("Connection closed.\n");
            break;
        }

        // 面倒なのでしていないが、本当はPOLLOUTもpollすべき
        write_ptr = buf;
        while (read_len > 0) {
            write_len = write(sock, write_ptr, read_len);
            if (write_len < 0) {
                if (errno == EINTR) {   // 割り込みによる中断の場合はやり直し
                    continue;
                }
                perror("write");
                close(sock);
                return -1;
            }
            write_ptr += write_len;
            read_len -= write_len;
        }
    }

    close(sock);

    return 0;
}

/* シグナルによる中断を検出するためのフラグ */
static int _interrupted = 0;

/* シグナルハンドラ内では中断フラグをセットするだけ */
void sigint_action(int signum)
{
    _interrupted = 1;
    return;
}

/* サーバのメインルーチン */
int start_server(int accept_socket)
{
    struct sigaction action;

    // Ctrl+Cで終了するようにシグナルの動作を設定する
    memset(&action, 0, sizeof(action));
    action.sa_handler = sigint_action;
    action.sa_flags = SA_RESETHAND;     // 一度呼ばれたらデフォルト動作に戻す
    if (sigaction(SIGINT, &action, NULL) != 0) {
        perror("sigaction");
        return -1;
    }

    // 終了した子プロセスがゾンビプロセスにならないように設定する
    memset(&action, 0, sizeof(action));
    action.sa_handler = SIG_IGN;
    action.sa_flags = SA_NOCLDWAIT;
    if (sigaction(SIGCHLD, &action, NULL) != 0) {
        perror("sigaction");
        return -1;
    }

    // シグナルで中断されるまでループを続ける
    while (!_interrupted) {
        // sockaddr_storage構造体は
        // 全てのアドレスファミリー用のsockaddr構造体を
        // 収容できるサイズを持つことが保証されているため、
        // プロトコル非依存の通信ではこれを用いる
        struct sockaddr_storage addr;
        socklen_t addrlen;
        int connection_socket;
        pid_t pid;

        // クライアントからの接続待ち受け
        addrlen = sizeof(addr);
        connection_socket = accept(accept_socket, (struct sockaddr*)&addr,
                &addrlen);
        if (connection_socket < 0) {
            if (errno == EINTR) {       // 割り込みによる中断の場合はやり直し
                continue;
            }
            perror("accept");
            break;
        }
        print_nameinfo("Connected from", (struct sockaddr*)&addr, addrlen);

        // クライアントとの通信は別プロセスで処理
        pid = fork();
        if (pid < 0) {
            perror("fork");
            close(connection_socket);
            break;
        } else if (pid == 0) {
            // 子プロセスは接続待ち受けソケットを閉じて
            // メインルーチンを開始する
            // 処理が終わったらプロセス終了
            int status;
            close(accept_socket);
            status = run(connection_socket);
            exit(status);
        }
        close(connection_socket);       // 親プロセスは引き続き接続待ち受け
    }

    if (_interrupted) {
        printf("Interrupted.\n");
    }

    // まだ生きている子プロセスは全部終了させる
    // 自分がSIGTERMで死なないように一回だけ無視させる
    memset(&action, 0, sizeof(action));
    action.sa_handler = SIG_IGN;
    action.sa_flags = SA_RESETHAND;
    if (sigaction(SIGTERM, &action, NULL) != 0) {
        perror("sigaction");
    }
    if (killpg(0, SIGTERM) != 0) {
        perror("killpg");
    }

    return 0;
}

int main(int argc, char **argv)
{
    int sock;
    int status;

    if (argc < 2) {
        printf("usage: %s <service>\n", argv[0]);
        return -1;
    }

    sock = open_socket(argv[1]);
    if (sock < 0) {
        fprintf(stderr, "cannot open socket.\n");
        return -1;
    }

    status = start_server(sock);

    close(sock);

    printf("Done.\n");

    return status;
}

長っ!

しかもこれでもまだポーリングやシグナルの扱いに穴があるので、常時起動させておくサーバ用途なんかには使えません。

ちなみに

似たようなサーバをRubyで書いてみると…。

tcpserver.rb
#!/usr/bin/env ruby
# coding: utf-8

require 'socket'

READ_TIMEOUT = 10

if ARGV.empty?
  puts "Usage: #{$0} <service>"
  exit -1
end

server = TCPServer.open(ARGV.shift)
addr = server.addr
puts "Listening #{addr[3]} (#{addr[1]})"

begin
  loop do
    Thread.start(server.accept) do |io|
      peer = io.peeraddr
      puts "Connected from #{peer[3]} (#{peer[1]})"
      begin
        loop do
          reads, writes, excepts = IO.select([io], nil, nil, READ_TIMEOUT)
          raise 'Connection timeout.' unless reads
          raise 'Connection closed.' unless line = io.gets
          io.write line
        end
      rescue RuntimeError
        puts $!.message
      end
      io.close
    end
  end
rescue Interrupt
  puts 'Interrupted.'
end

server.close
puts 'Done.'

こんなに短く書けちゃいます。*1

または、

#!/usr/bin/env ruby
# coding: utf-8

require 'gserver'

READ_TIMEOUT = 10

if ARGV.empty?
  puts "Usage: #{$0} <service>"
  exit -1
end

class EchoServer < GServer
  def serve(io)
    loop do
      reads, writes, excepts = IO.select([io], nil, nil, READ_TIMEOUT)
      raise 'Connection timeout.' unless reads
      raise 'Connection closed.' unless line = io.gets
      io.write line
    end
  rescue RuntimeError
    puts $!.message
  end
end

server = EchoServer.new(ARGV.shift)
server.audit = true
server.start
begin
  sleep
rescue Interrupt
  puts 'Interrupted.'
end
server.stop
puts 'Done.'

こういう書き方もあります。

個人的には後者のほうが読みやすいかな?と思いますが、このレベルの違いなら好みで選んでもいいかもしれませんね。

*1:マルチプロセス/マルチスレッドの違いなどの細かな相違点はありますが。