Re:ゼロから始めるML生活

どちらかといえばエミリア派です

BigQueryで始めるt検定

f:id:nogawanogawa:20211017144231p:plain

BQを使っているときに、かんたんな検定であればBQ内で完結したくなります。 これが実現できないものかと調べてみたところ、こちらの記事を拝見しました。

lab.mo-t.com

ちょっとこれを実際に試してみたくなったので、実際にやってみたいと思います。

やりたいこと

なぜBQで完結させたいか?

A/Bテスト時のログの集計をしていると、評価指標の大小は四則演算の延長で計算できますが、評価指標が有意に差がついているかどうかを確認したくなります。 普通にやろうとすると、

  1. これらのデータをローカルにダウンロード
  2. PythonやRなどを使用して検定のスクリプトを実行

みたいなことが必要になるわけですが、この作業はちょっと手間です。

また、BQでの集計結果をそのままダッシュボードに連携させて表示することもあるので、検定の処理もBQで処理できると集計結果と検定の結果を合わせてダッシュボード上で管理できるため管理コストが下がります。

そんなこんなで、BQ上で統計検定ができると非常に便利だというわけです。

理屈: UDFを使う

ただ、普通のSQL処理で統計検定を計算するのはちょっと難しそうです。 そこで、jstatを使用したUDFを作って対応します。

cloud.google.com

UDFはユーザー定義関数のことで、SQLにおける関数です。 BigQueryでは、この関数定義の中でJavaScriptを呼び出すことができます。 そして、都合が良いことに、JavaScriptには統計検定を行うことができるライブラリがあり、それを呼び出すことで実現できそうです。

github.com

(ほんと、世の中便利なものがあるもんだなあ)

要するに、UDFの中で統計検定のライブラリの呼び出しを定義してあげて、そのUDFをクエリの中で呼び出すことが、今回やりたいことになります。

やってみる

準備

今回はjstat.jsのパッケージを使用することにします。

jstat.github.io

このパッケージをダウンロードして、jstat.jsという名前で保存します。

次に、これをBQからアクセスできるGCSの適当なバケットに格納します。 そして、このファイルのパスのところを控えておきます。

f:id:nogawanogawa:20211018214035p:plain

これをBQからアクセスして使うことになります。後にあるSQLの"your_bucket"のところを、ご自分のパスに書き換えれば使えるはずです。

対応のあるt検定

さて、準備は終わったので実際に検定をやってみたいと思います。 とりあえず必要になりそうだったt検定をやってみます。

t検定には対応のあるt検定対応のないt検定があり、 まずは対応のあるt検定をやってみます。 詳しい違いはこちらがわかりやすかったので、ご参照ください。

d1.gmobb.jp

クエリ

こんな感じのtemporary functionを書いてみます。

CREATE TEMPORARY FUNCTION tscore_to_p(a FLOAT64, b FLOAT64, c FLOAT64)
 RETURNS FLOAT64
 LANGUAGE js AS
"""
  return jStat.ttest(a,b,c); //jStat.ttest( tscore, n, sides)
"""
OPTIONS (
    library=["gs://your_bucket/jstat.js"]
);

CREATE TEMPORARY FUNCTION tTest (data ARRAY<FLOAT64>, data2 ARRAY<FLOAT64>)  RETURNS FLOAT64
AS ((
    WITH dataset1 AS (
        SELECT
            d AS A,
            i
        FROM UNNEST(data) as d WITH OFFSET AS i
        ORDER BY i
    )
    ,dataset2 AS (
        SELECT
            d AS B,
            i
        FROM UNNEST(data2) as d WITH OFFSET AS i
        ORDER BY i
    )
    , dataset AS (
        SELECT
          dataset1.A AS A,
          dataset2.B AS B
        FROM dataset1
        LEFT JOIN dataset2 on dataset1.i = dataset2.i
    )
    , test AS (
        SELECT 
        COUNT(*) n
        , COUNT(*)-1 dof
        , AVG(difference) mean
        , STDDEV_SAMP(difference) SD 
        , STDDEV_SAMP(difference)/SQRT(COUNT(*)) SE
        , AVG(difference)/ (STDDEV_SAMP(difference)/SQRT(COUNT(*))) t 
        , tscore_to_p(((AVG(A) - AVG(B)) / (STDDEV_SAMP(difference)/SQRT(COUNT(*)))), COUNT(*), 2) p_value 
        FROM (SELECT *, (A-B) difference FROM dataset)
    )
    SELECT p_value FROM test
));

WITH test_data AS ( 
 SELECT * FROM 
 (SELECT 9.96 AS A, 3.96 AS B) UNION ALL
 (SELECT 3.76 AS A, 5.76 AS B) UNION ALL
 (SELECT 1.17 AS A, 7.17 AS B) UNION ALL
 (SELECT 8.66 AS A, 7.66 AS B) UNION ALL
 (SELECT 5.25 AS A, 9.25 AS B) UNION ALL
 (SELECT 7.61 AS A, 3.61 AS B) UNION ALL
 (SELECT 5.80 AS A, 4.80 AS B) UNION ALL
 (SELECT 1.84 AS A, 8.84 AS B) UNION ALL
 (SELECT 7.06 AS A, 6.06 AS B) UNION ALL
 (SELECT 9.40 AS A, 4.40 AS B) UNION ALL
 (SELECT 2.99 AS A, 1.99 AS B) UNION ALL
 (SELECT 9.30 AS A, 8.30 AS B) UNION ALL
 (SELECT 9.01 AS A, 9.01 AS B) UNION ALL
 (SELECT 4.24 AS A, 1.24 AS B) UNION ALL
 (SELECT 3.52 AS A, 5.52 AS B) UNION ALL
 (SELECT 9.60 AS A, 8.60 AS B) UNION ALL
 (SELECT 7.59 AS A, 5.59 AS B) UNION ALL
 (SELECT 6.99 AS A, 1.99 AS B) UNION ALL
 (SELECT 9.62 AS A, 7.62 AS B) UNION ALL
 (SELECT 2.18 AS A, 3.18 AS B)
 ) 
SELECT tTest(ARRAY_AGG(A), ARRAY_AGG(B)) AS pvalue FROM test_data

結果として

pvalue=0.4871933190426815

という結果が得られています。

確認

同じ検定をpythonでやってみて、値が同じことを確認します。

p値が0.4871になっているので同じ結果が得られていそうなので、動作確認できました。 

対応のないt検定

次に、対応のないt検定をやってみます。

クエリ

CREATE TEMPORARY FUNCTION studentt_cdf(t FLOAT64, dof FLOAT64)
 RETURNS FLOAT64
 LANGUAGE js AS
"""
  return jStat.studentt.cdf(-Math.abs(t), dof) *2; 
"""
OPTIONS (
    library=["gs://your_bucket/jstat.js"]
);

CREATE TEMPORARY FUNCTION tTest (data ARRAY<FLOAT64>, data2 ARRAY<FLOAT64>)  
AS ((
    WITH dataset1 AS (
        SELECT
            d AS A
        FROM UNNEST(data) as d
    )
    ,dataset2 AS (
        SELECT
            d AS B
        FROM UNNEST(data2) as d
    )
    , mean AS (
        SELECT 
            AVG(A) AS ma, 
            AVG(B) AS mb
        FROM dataset1, dataset2
    )
    , lena AS (
        SELECT 
            COUNT(A) AS len_a
        FROM dataset1
    )
    , lenb AS (
        SELECT 
            COUNT(B) AS len_b
        FROM dataset2
    )
    , Ama AS (
        SELECT 
            A,
            ma,
            A - ma AS A_ma,
        FROM dataset1, mean
    )
    , bmb AS (
        SELECT 
            B,
            mb,
            B - mb AS B_mb,
        FROM dataset2, mean
    )
    , pow_Ama AS (
        SELECT
            SUM(A_ma * A_ma) AS A_ma_2
        FROM Ama
    )
    , pow_Bmb AS (
        SELECT
            SUM(B_mb * B_mb) AS B_mb_2
        FROM bmb
    )
    , S2 AS (
        SELECT 
            (A_ma_2 + B_mb_2) / (len_a + len_b - 2) AS s_2
        FROM pow_Ama, pow_Bmb, lena, lenb
    )
    , t AS (
        SELECT 
            len_a,
            len_b,
            (ma - mb) / SQRT((s_2/len_a) + (s_2/len_b)) AS t_value
        FROM mean, S2, lena, lenb
    )
    SELECT 
        studentt_cdf(t_value, len_a + len_b-2) AS p_value
    FROM t
));

WITH test_data AS ( 
    SELECT 
        [0.0,5.0,29.0,3.0,4.0] AS A,
        [9.0,4.0,5.0,6.0,4.0,2.0,3.0,1.0,2.0,4.0] AS B
) 
SELECT tTest(A, B) AS p_value FROM test_data

これを計算すると、

p_value=0.2804860758986702

となります。

確認

同じ検定をpythonでやってみて、値が同じことを確認します。

p値が0.28048になっているので同じ結果が得られていそうなので、こちらも動作確認できました。 

参考文献

こちらの記事を参考にさせていただきました。

lab.mo-t.com

stackoverflow.com

github.com

d1.gmobb.jp

感想

疲れた。以上。