Timee Product Team Blog

タイミー開発者ブログ

Vertex AI PipelinesとCloud Run jobsを使って機械学習バッチ予測とA/Bテストをシンプルに実現した話

こんにちは、タイミーでデータサイエンティストとして働いている小栗です。

今回は、機械学習バッチ予測およびA/BテストをVertex AI PipelinesとCloud Run jobsを使ってシンプルに実現した話をご紹介します。

経緯

タイミーのサービスのユーザーは2種類に大別されます。お仕事内容を掲載して働く人を募集する「事業者」と、お仕事に申し込んで働く「働き手」です。

今回、事業者を対象に機械学習を用いた予測を行い、予測結果を元にWebアプリケーション上で特定の処理を行う機能を開発することになりました。

要件としては以下を実現する必要がありました。

  1. 定期的なバッチ処理でのMLモデルの学習・予測
  2. MLモデルのA/Bテスト

最終的に、Vertex AI PipelinesとCloud Run jobsを活用したシンプルな構成でバッチ予測とA/Bテストを実現することにしました。

本記事では主に構成とA/B割り当ての仕組みをご紹介します。

構成

まず、全体構成とその構成要素についてご紹介します。

構成図

Webアプリケーション側の構成・実装についてもご紹介したいところですが、今回は機械学習に関係する部分に絞ってお話しします。

前提として、データサイエンス(以下DS)グループはGoogle CloudをベースとしたML基盤を構築しています。MLパイプライン等はCloud Composerに載せて統一的に管理しており、今回も例に漏れずワークフロー管理ツールとして採用しています。

MLパイプラインはVertex AI Pipelinesで実装しています。MLモデルのA/Bテストを実現するため、MLモデルごとにパイプラインを構築し、並行で稼働させています。同時に、それぞれのMLモデルの予測値と付随情報をBigQueryの予測結果テーブルに蓄積する責務もMLパイプラインに持たせています。

もちろんそれだけでは予測結果がテーブルに蓄積されるだけでA/Bテストは実現できないので、各事業者に対する予測結果のA/B割り当てをCloud Run jobsの責務とし、MLパイプライン実行の後段タスクとして実行しています。同時に、割り当て結果をBigQueryテーブルに出力する処理も実施します。

当初はA/B割り当てを含めたすべての責務をVertex AI Pipelinesに集約する案も議論の中で出たのですが、将来的に類似の取り組みにて実装や思想を使い回せそう等の理由から、取り回しのしやすいCloud Run jobsを採用しました。

Cloud Run jobsの利用はDSグループ内でも初めてではありましたが、グループ内のMLOpsエンジニアに相談・依頼してCloud Run jobs用CI/CDの導入などML基盤のアップデートを並行して進めていただくことで、スムーズに開発を進められました。

今回、モジュール間のデータやり取りのIFとしてBigQueryを採用していますが、読み込み・書き込みの操作に関しては、以前ご紹介した社内ライブラリを活用することでサクッと実装できました。

また、各処理には実行日時などの情報が必要なため、Cloud Composerのオペレータからパラメータを渡してキックする形にしています。

例えば、Cloud Run jobsは2023年のアップデートからジョブ構成のオーバーライドが可能になっており、それに併せてCloud Composer側でもCloudRunExecuteJobOperatorを介したオーバーライドが可能になったため、そちらを利用して必要なパラメータを実行時に渡しています。

さて、A/B割り当ての結果が出力されたのち、Webアプリケーション側はデータ連携用テーブルを参照して、事業者に対してバッチ処理を行います。残念ながら機能や施策の具体についてはご紹介できないのですが、機械学習の予測結果を元に事業者ごとに特定の処理を行う仕組みになっています。

A/B割り当ての仕組み

次に、Cloud Run jobsの中身で実施しているA/B割り当てについて、より具体的にご紹介します。

どう設定を管理するか

A/B割り当てに必要なパラメータはyamlファイルで指定する形にしています。例えば、実験期間や各MLモデルへの割り当て割合などです。

- experiment_name: str # 実験名。割り当てに用いるキーも兼ねる e.g. 'experiment_1'
  start_date: str # A/Bテストの開始日 e.g. '2024-08-01'
  end_date: str # A/Bテストの終了日 e.g. '2024-08-31'
  groups:
    - model_name: str # MLモデルの名前 e.g. 'model_1'
      weight: float # このMLモデルに割り当てる割合 e.g. 0.5
    - model_name: ...
- experiment_name: ...
  ...

この方法を用いる問題点として、”PyYAML”というライブラリを使えばyamlを読み込むこと自体は可能なのですが、開発者が想定していない形式でyamlが記述されるとエラーや予期せぬ挙動に繋がります。

当初の開発者以外がyamlファイルを更新することを見越して、ファイルの中身をバリデーションすることが望ましいと考えました。そこで、型・データのバリデーションが可能なライブラリである”Pydantic”を活用することにしました。

上記の形式のyamlファイルを安全にパースするために、以下のようなPydanticモデルクラスを定義しています。

# コードの一部を抜粋・簡略化して記載しています

import datetime

from pydantic import BaseModel, Field, field_validator

class ABTestGroup(BaseModel):
    model_name: str
    weight: float = Field(..., ge=0.0, le=1.0)

class ABTestExperiment(BaseModel):
    experiment_name: str
    start_date: datetime.date
    end_date: datetime.date
    groups: list[ABTestGroup]

    @field_validator('groups')
    @classmethod
    def validate_total_weight(cls, v: list[ABTestGroup]) -> list[ABTestGroup]:
        """
        各groupのweightの合計が1.0であることを確認する。
        """
        total_weight = sum(group.weight for group in v)
        if not math.isclose(total_weight, 1.0, rel_tol=1e-9):
            raise ValueError('Total weight of groups must be 1.0')
        return v

class ABTestConfig(BaseModel):
    experiments: list[ABTestExperiment]

例えば、weight(各グループに割り当てる事業者の割合)に対しては、型アノテーションとFieldを使って型と数値の範囲のバリデーションを実施しています。

加えて、各グループの割合の合計が1.0を超えることを避けるため、フィールドごとのカスタムバリデーションを定義可能なfield_validatorを使用し、独自のロジックでバリデーションを実装しました。

このようなPydanticを使ったバリデーション処理をML基盤のCIを通して呼び出すことにより、不適切なyamlファイルを事前に検知できるようにしています。

どう割り当てるか

A/B割り当てに関しては、事業者を適切にバケット(=グループ)に割り当てるために、事業者のIDとキーを使用しています。

# コードの一部を抜粋・簡略化して記載しています

import hashlib

def _compute_allocation_bucket(company_id: int, key: str, bucket_size: int) -> int:
    """
    company_idとkeyに基づき、company_idに対してバケットを割り当てる。
    keyはexperiment_nameなどを想定。
    """
    hash_key = key + str(company_id)
    hash_str = hashlib.sha256(hash_key.encode("utf-8")).hexdigest()
    return int(hash_str, 16) % bucket_size + 1

まず、実験名をkeyとして事業者のIDと結合し、ハッシュ化の元となる文字列を生成します。次に文字列をハッシュ化し、16進数の文字列を取得します。ハッシュ値を整数に戻した後、バケットサイズで割った余りをバケット番号とします。

その後、各MLモデルに対してweightに基づいてバケット範囲を割り当てることでMLモデルと事業者を紐付けます(コードは省略します)。

ややこしい点はありつつも上記のロジックにより、同じ事業者とキーの組み合わせに対して一貫して同じバケットを割り当てることができます。

実行のタイミングによって割り当てが変化する等の問題が生じず、A/Bテストの管理が容易になります。また、ハッシュ関数を使うことで入力値をほぼ均等に分散させることができ、実質的にランダムにグループ分けすることができます。*1

おわりに

機械学習バッチ予測とA/Bテストをシンプルに実現した話をご紹介しました。

今回、データサイエンティストとバックエンドエンジニアの共同開発については黎明期といった状況での開発でしたが(そこがまた楽しいのですが)、主にPdM含めた3人で協力しながら手探りで設計・実装を進めていきました。

その他にもMLOpsエンジニア、データエンジニアなど多くのポジションの方々に協力いただいており、部門を跨いでスムーズに協業できる組織体制が整ってきたことを開発を通して感じました。

今回ご紹介した構成にはやや課題が残っていたりするのですが、部門を横断しつつ解決を図っていけるのではと考えています。

We’re Hiring!

タイミーのデータエンジニアリング部・データアナリティクス部では、ともに働くメンバーを募集しています!!

現在募集中のポジションはこちらです!*2

「話を聞きたい」と思われた方は、是非一度カジュアル面談でお話ししましょう!

*1:ハッシュ関数を活用したA/B割り当てに関してはGunosyさんのブログ記事が分かりやすいです。

*2:募集中のエンジニア系のポジションはこちらです!