[T-ITS 2021] Spatio-Temporal Knowledge Transfer for Urban Crowd Flow Prediction via Deep Attentive Adaptation Networks
0. Overview
- Title : Spatio-Temporal Knowledge Transfer for Urban Crowd Flow Prediction via Deep Attentive Adaptation Networks
- Authors : Senzhang Wang, Hao Miao, Jiyue Li, Jiannong Cao
- Year : 2021
- Publish : TITS (IEEE Transactions on Intelligent Transportation Systems)
1. Introduction
1) Why do we need it?
- Deep learning์ด ๋ค์ํ spatio-temporal(์๊ณต๊ฐ) prediction task์ ์ฌ์ฉ๋๊ณ ์์
- ST-ResNet(2017, Cit. 1606) : forecast crowds inflow & outflow in each region of a city
- STDN(2018, Cit. 521) : road network based traffic prediction
- predict passenger pickup/demand demands (Attention+ConvLSTM)
- DeepTransport : predict the traffic data within a transport network (CNN+RNN)
- ์ต๊ทผ์๋ transfer learning์ ์ฌ์ฉํด ์๊ธฐ ๋ฌธ์ ๋ฅผ ํ์ด๋ณด๊ณ ์ ํ์
- RegionTrans(2019, Cit. 88) : source, target city์ ๋น์ทํ ์ง์ญ์ ๋งค์นญ โ ์ด ์์ ํ๋ ค๋ฉด other service data๊ฐ ๋ ํ์ (data ๊ด์ = region level)
- MetaST(2019, Cit. 166) : ์ฌ๋ฌ ๋์์ ์ฅ๊ธฐ์ ์ถ์ธ๋ฅผ ๋ฝ์๋ด์ target city์ ์จ๋ณด์ โ ์ด๊ฑธ automatically ํด์ฃผ๋ ํตํฉ ๋ชจ๋ธ์ ์์
- ์ฐ๋ฆฌ๋ data ๊ด์ = distribution ์์ ํ๊ณ , unified framework๋ฅผ ๋ง๋ค์ด๋ณด๊ฒ ๋ค.
2) Related works & Core things
- Urban Crowd Flow Prediction : ๋์/๊ตํต ๋ถ์ผ์ ํฐ ์ฃผ์ . ์ ํต์ ์ผ๋ก๋ ARIMA ๊ฐ์ ํต๊ณ based methods๋ฅผ ์ฃผ๋ก ์ฌ์ฉํ์ผ๋, ์ต๊ทผ์๋ DL methods๊ฐ ๋ง์ด ์ฐ์ด๋ ํธ
- DNN, ST-ResNet, SeqST-GAN, ConvLSTM, MT-ASTN, DCRNN, RegionTrans, MetaST ๋ฑ
- Transfer Learning : ML์ scarce labeled data problem์ ํด๊ฒฐํ๊ธฐ ์ํด ์ ์๋ ๋ฐฉ๋ฒ๋ก
- TCA, TLDA, JAN, JMMD ๋ฑ
- DAN(2015, Cit. 4413) : CNN์ domain adaptation task์ ๋ง๊ฒ ์ผ๋ฐํ, ์ปดํจํฐ ๋น์ ๋ถ์ผ์์ ํฐ ์ฑ๊ณต
- Neural Net์ด general feature ์ ์ก์๋ด๊ณ ์ฑ๋ฅ ์ข๋ค๋ง, labeled data ๋ณ๋ก ์๋ target domain์ ๋ฐ๋ก CNN ์ฐ๋ ๋ฌธ์ ๊ฐ ๋ง์
- ์ค์ ๋ก Yosinski et al.(2014, Cit. 8740) ๋ณด๋ Conv 1-3๊น์ง OK, Conv 4-5๋ถํฐ ์ด์ํด์ง๋๋, FC 6-8์์ ์์ ํ ๋ฉ๋กฑ
- DAN ์ ์๋ค์ Conv 1-3์ ๊ทธ๋๋ก ๋๊ณ (freeze), Conv 4-5 ๋จ๊ณ์ fine-tuning ์ ์ฉ, FC 6-8์ CNN parameter optimizing์ multi-kernel MMD๋ฅผ regularizer๋ก ๋ฃ๋ ์์ผ๋ก ๊ฐ์
- Sejdinovic et al.(2013, Cit. 610) : two samples์ distribution์ด ๊ฐ์์ง ํ๊ฐํ ๋งํ ํต๊ณ๋์ผ๋ก MMD(Maximum Mean Discrepancies)๋ฅผ ์ ์ํ ๋ฐ ์์
- ์์ฝํ๋ฉด CNN parameter๋ฅผ ์ฐพ๋, FC-layers ๋จ์์ ๋ง๋ค์ด์ง๋ source์ target์ hidden representation์ด ๋น์ทํด์ง๋๋ก ์ถ๊ฐ ์ ํ์ ์ค์ ํ ๊ฒ
- ConvLSTM(2015, Cit. 6876) : ๊ธฐ์กด Fully Connected LSTM์ 1์ฐจ์ time-series โ ๊ณต๊ฐ์ ๋ณด(row, column)์ ๋ฃ์ด์ 3์ฐจ์ ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃจ๋๋ก ํ์ฅ
- ํ์ฝฉ ๊ธฐ์์ฒญ์์ radar echo images๋ก ๊ฐ์ ์๋ณด๋ฅผ ํ๋ ค๋, ๊ธฐ์กด LSTM์ผ๋ก ๊ณต๊ฐ์ฑ์ ๋ด์๋ผ ์ ์์ด์ ์ง ์ฑ๋ฅ์ด ์ ์ข๋๋ผ โ image๋ฅผ LSTM์ ๋ฃ๊ธฐ ์ CNN์ผ๋ก ์ด๋ฒ๊ตฌ์ดํ๋ ๋ฐฉ์์ ์ ์
3) Formulationss
- Spatio-Temporal Data : 2์ฐจ์ ๊ณต๊ฐ ์์์ ๊ธฐ๋ก๋๋, ์๊ฐ์ ๋ฐ๋ผ ๋ณํ๋ feature๋ฅผ ๋งํ๋ค. ๋ฐ๋ผ์ ๋จ์ผ feature๋ผ๋ฉด ๊ธฐ๋ณธ์ ์ผ๋ก 3์ฐจ์ ๋ฐ์ดํฐ.
- ๋ณธ ๋
ผ๋ฌธ์์๋ ์๋ก ๋ค๋ฅธ ์ง์ญ์์ ๋ง๋ค์ด์ง ๋ฐ์ดํฐ๋ฅผ ๋ค๋ฃจ๋ฉฐ, ์ด๋ค์ ๊ฐ์ ์์ grid cell๋ก ๋๋ ์์
ํ๋ค.
- ์์ธ, ๋์ , ๋ด์, โฆ ๋์์ ํฌ๊ธฐ/ํํ๋ ์ ๊ฐ๊ฐ์ด์ง๋ง cell ์๊ฐ ๊ฐ๋๋ก ๊ฒฉ์๋ฅผ ๋ง๋ค์ด์ค๋ค.
๋ฐ์ดํฐ๊ฐ coverํ๋ ๊ณต๊ฐ์ m*n๊ฐ์ grid cell๋ก ๋๋๋ค. each cell region์ด t์์ ์ ๊ฐ๋ ์ ๋ณด(๊ตํต๋, ๊ฐ์ ๋ฑ)๊ฐ ์์ ํ ๋ฐ, ์ด๋ค์ด ์ด๋ค ๊ฐ์ ๊ฐ๋์ง ํํํ ๊ฒ spatio-temporal image (matrix)๋ผ ํ๋ค.
- ๊ฒฉ์ ํํ matrix๋ฅผ image๋ผ ํ ๋, ๋งค ์์ ๋ง๋ค ๊ธฐ๋ก๋ image๋ค์ time-series๋ฅผ ๋ชจ์ผ๋ฉด 3์ฐจ์ tensor๊ฐ ๋๋ค.
- ์์ธ์ ๋ฐ๋ฆ์ด ํตํ๋(a feature)์ ์ด๋ ์๊ฐ์ฏค ๊ด์ฐฐํ๋ค๋ฉด, ํด๋น ๋ฐ์ดํฐ๋ ์๋์ ๊ฐ์ spatio-temporal tensor๋ก ๋ฌ์ฌํ ์ ์๊ฒ ๋ค.
image๋ ์๊ฐ์ ๋ฐ๋ผ ๋ณํ๋ฉฐ, t์์ ๊ธฐ์ค์ผ๋ก ๊ณผ๊ฑฐ k๊ฐ image๋ฅผ ์ถ์ ํ๋ฉด, ์์ ๊ฐ์ 3์ฐจ์ tensor๋ฅผ ์ป์ ์ ์๋ค. ์ด tensor๊ฐ ์์ผ๋ก ์ ๊ฐํ ๋ ผ๋ฆฌ์ ๊ธฐ๋ณธ ๋จ์๋ก ์์ฃผ ์ฐ์ธ๋ค.
- tensor๋ค์ ์ต์๋จ(latest) image๋ฅผ ๊ธฐ์ค์ผ๋ก ์ถ๋ ค๋ธ ์ต๊ทผ k๊ฐ images์ธ ์
์ธ๋ฐ, ์ด ๊ฐ์ ๋ญ์น๋ฅผ 1-step after ๋ง๋ค ๊ณ์ ๋ฝ์๋ธ๋ค๋ฉด, ํด๋น tensors๋ก ์ด๋ค 4์ฐจ์ ๋ฆฌ์คํธ๋ฅผ ๋ง๋ค ์ ์๊ฒ ๋ค.
- List with parameters : Row(m) * Column(n) * Accumulation(k) * Time-stamp(t)
- ์ด ๋ฆฌ์คํธ๋ฅผ tensor set, ๊ธธ์ด๋ฅผ โLโ์ด๋ผ ํ์.
- ๋ฐ์ดํฐ๊ฐ ๋ง์(์ฅ๊ธฐ๊ฐ) domain์์๋ ์งํฉ์ด ๊ธธ์ญํ๊ฒ, ๋ฐ๋๋ก ๋ฐ์ดํฐ๊ฐ ๋ถ์กฑํ domain์์๋ ์งค๋งํ ์งํฉ์ด ๋์จ๋ค.
tensor๋ ์ ๋ณด๋ฅผ ์๋ฏธํ๋ฉฐ, domain์ ๋ฐ๋ผ ์ ๋ณด๋์ ๋ค๋ฅผ ํ ๋ค. ์์ปจ๋ ์ฌ๊ธฐ์ ์์ธ์ ํ์ ์น๊ฐ ๋ฐ์ดํฐ๋ ๋ํ(์ต์ข ์ ๋ฐ์ดํธ ๊ธฐ์ค) ์ ๋๋ก ๊ธธ์ง๋ง, ๋ฐ๋ฆ์ด ํตํ๋ ๋ฐ์ดํฐ๋ ๊ธฐ๊ปํด์ผ ๋ฐ๋์ ์ฏค ๋ผ์, ๋ค๋ฅธ domain์ธ ํ์ ์ ๋ณด๋ฅผ ์ด๋ป๊ฒ ์ ๊ฐ์ ธ์ฌ ์ ์์๊น ๊ณ ๋ฏผํ๊ฒ ๋๋ค. ๊ทธ๊ฒ ์ด ๋ ผ๋ฌธ์ ํต์ฌ ์ฃผ์ .
2. Main Architecture
- ๊ธฐ๋ณธ์ ์ธ ํน์ง์ stacked ConvLSTM ์ผ๋ก ์ก์๋ด๋ฉฐ, ๋ง๋ค์ด์ง hidden state์ DAN(generalized CNN), ๋ง์ง๋ง์ Global Attention ์ ์ฉ & ๊ธฐํ features ์ถ๊ฐํ๋ ๊ตฌ์ฑ์ด๋ค
๋ ผ๋ฌธ์ main figure. ํฌ๊ฒ 1) ConvLSTM, 2) CNN with MMD (DAN), 3) Global spatial attention ๊ตฌ๊ฐ์ผ๋ก ๋๋๋ค.
1) Representaion Learning (ConvLSTM)
- Input = Tensor set(4D) ์ด์ง๋ง, ์์ ์ ๋งค image(2D) ๋ง๋ค ์งํ โ ํ ์ฅ์ฉ CNN์ ๊ฑฐ์ณ ์๋ก์ด tensor set์ ๋ง๋ค์ด ๋ผ ์ ์์ โ ๋ค์ LSTM์ Input gate์ ํฌ์ + ์ด์ hidden state tensor set๊ณผ ๊ฒฐํฉ + โฆ (๋ง์ฐฌ๊ฐ์ง๋ก 2D ๋จ์๋ก ์งํ) โ ๋ฐ๋ณต
- ๋ชจ๋ stacked LSTM์ ํต๊ณผํด ๋ง๋ค์ด์ง ์ต์ข ๊ฒฐ๊ณผ๋ฌผ์ โHโ๋ผ ํ์
2) Knowledge Transfer (DAN)
- two different domainsโ distributions์ด ์ผ๋ง๋ ๋ค๋ฅธ์ง, distance๋ก ํ๊ฐํ ๊ฒ์ MMD๋ผ ํ๋ค.
- ๋๋ฉ์ธ ๋ณ๋ก hidden state์ CNN์ ์ ์ฉํ๋, CNN layer ๋ง๋ค mmd loss๋ฅผ ์ฐ์ถํด ํ๊ท ์ ๋ธ๋ค.
- Parameter set ฮ = argmin Loss Function of (GT vs ConvLSTM & CNN & mmd_loss & โฆ )
3) Global Spatial Attention
- local spatial correlations๋ CNN ๋จ๊ณ์์ ์กํ์ง๋ง, ๋ณด๋ค ๋์ ๋ฒ์์์ geographical dependencies๋ ์ ํฌ์ฐฉ๋์ง ์๋๋ค.
- ์ง๋ฆฌ์์ผ๋ก๋ ๋ฉ๋ฆฌ ๋จ์ด์ง ๋ ์ง์ญ์ด ์ ์ฌํ Point of Interest distribution์ ๊ฐ์ง๋ ๊ฒฝ์ฐ๊ฐ ๋ง๋ค
- ์ด๋ taxi-trip, crowd flow ๊ฐ์ ์๊ณต๊ฐ ์ ๋ณด๋ ๋ง์ฐฌ๊ฐ์ง
- source domain ๋ฐ์ดํฐ๋ฅผ ํ์ฉํ ๋, attention score๋ฅผ ๊ณฑํด์ ๊ฐ์ ธ์ค๋ฉด global relation์ ์ฒดํฌํ๋ ํจ๊ณผ๋ฅผ ๋ผ ์ ์์ง ์์๊น
์์นจ ํ๋์ ํ์ ์น๊ฐ(source)์, ๊ฐ์ ์๊ฐ ํ๋์ ๋ ธ์์ ์์ ๊ฑฐ ํตํ๋(target)๊ณผ ๋ฎ์์๋ค. domain์ ๋ค๋ฅด์ง๋ง, โ์ถํด๊ทผ/ํตํโ ์ด๋ผ๋ ์์๊ฐ ์ ๋ณ์ ๊น๋ ค์์์ attention mechanism์ ํตํด ํ์ ํ๋ ์ . ์ฑ์๋ ๋ ธ์๋ณด๋ค ํ๋์ ๊ฐ๊น์ด ์์ง๋ง, ์ฃผ๊ฑฐ/์ ๋ฌด/ํ๊ตฐ ๋ณด๋จ โ๋ฌธํ์์ โ ์ง์ญ์ด๋ผ ์์นจ์ ์์ ๊ฑฐ ํ๋ ์ฌ๋์ด ์ ๋ค๊ณ ํด์ํ ์ ์๊ฒ ๋ค.
- ๊ตฌ์ฒด์ ์ผ๋ก๋ source domain์ 2D image์ ํน์ ๋ถ๋ถ Region (i, j)๊ฐ, target domain์ ๋ชจ๋ m*n๊ฐ region๊ณผ ์ผ๋ง๋ ๋ฎ์์๋์ง ์ฒดํฌํ๋ค
- ๋ณธ ๋ ผ๋ฌธ์์ ๋ค๋ฃจ๋ image๋ ๋ชจ๋ ๊ฐ์ m*n ์ฌ์ด์ฆ grid cell๋ก ๋๋ ์ ธ ์์ผ๋ ํ๋ ฌ ๊ณ์ฐ์ด ์ฉ์ดํ๋ค.
- dot-product, softmax ์ทจํด์ attention matrix ๋ง๋๋ ๋ฑ ๋๋ฆฌ ์๋ ค์ง attention mechanism๊ณผ ํฌ๊ฒ ๋ค๋ฅธ ์ ์ ๋ณด์ด์ง ์์๋ค
3. Algorithm & Code
1) Algorithm
2) Real Code
https://github.com/MiaoHaoSunny/ST-DAAN
4. Evaluation
- ๊ณผ๊ฑฐ Taxi, Bike ๋ฐ์ดํฐ๋ก Crowd flow prediction ํ๋ task๋ก ST-DAAN ์ฑ๋ฅ์ ํ๊ฐํด๋ณด์
์ฌ๋ฌ ๋์์์ ์์ง๋ taxi, bike ๋ฐ์ดํฐ์ ์ผ๋ก, ๊ฐ๊ฐ GPS ๊ฒฝ๋ก, ์ถ๋ฐ/๋์ฐฉ์ง, ์๊ฐ, ID ๋ฑ ๋ค์ํ variables๋ก ๊ตฌ์ฑ๋ผ์๋ค. number of trips, time span์ ๋น๊ตํ๋ฉด DIDI๋ ๊ฐ์ ํ์ ๋ฐ์ดํฐ์ ์ธ TaxiNYC๋ณด๋ค data scarce ํ๋ค๊ณ ๋ณผ ์ ์๋ค.
- Intra-city(TaxiNYC โ BikeNYC), Cross-city(BikeChicago โ BikeNYC, DIDI โ TaxiBJ) transfer case๋ฅผ ๋ชจ๋ ๋ค๋ค๋ณด์๋ค
- Baseline model์ non-transfer learning, ์ต๊ทผ์ transfer leaning based์์ ๊ณ ๋ฃจ ๊ณจ๋๋ค
- non-transfer learning based : ARIMA, ConvLSTM, DCRNN, DeepST, ST-ResNet
- transfer learning based : (์ ๋ชจ๋ธ๋ค์ fine-tuning), RegionTrans, MetaST
1) Comparison With Baselines
- ARIMA < non-transfer < non-transfer with fine-tuning < transfer < ST-DAAN ์์ผ๋ก ์ฑ๋ฅ Good
-
ST-DAAN full version๊ณผ Attention & External features์ ๊ฐ๊ฐ ๋นผ๋ณธ variation์ ๋น๊ตํด๋ณด๋, ์ด๋ค ์ญ์ ์ฑ๋ฅ ํฅ์์ ๋์์ด ๋์
Intra-city, Cross-city ๋ฌด๊ดํ๊ฒ ST-DAAN์ด ์ข์ ์ฑ๋ฅ์ ๋ณด์. nonAtt, nonExt๋ ๊ฐ๊ฐ global spatial attention, inserting external feature์ ์์ค ๋ฒ์ ์ ST-DAAN
-
2) Effect of Data Amount
- ๋ฐ์ดํฐ๊ฐ ๋ง์ ์๋ก ์ข๊ธด ํ๋๋ผ. Source/Target ๋ ๋ค ๋ฐ์ดํฐ๊ฐ ๋ง์ผ๋ฉด ์ฑ๋ฅ ์ข์
๋์ฒด๋ก ๋ฐ์ดํฐ length ๊ธธ์๋ก ์์ธก ์ฑ๋ฅ์ด ์ข์์ง. ์ญ์ ๋ค๋ค์ต์
3) Parameter Sensitivity Analysis
- Scarce data ๋ค๋ฃจ๋ transfer learning, ์ ๊ฒฝ๋ง ๊น๊ฒ ์์ผ๋ฉด ์คํ๋ ค overfitting ๋ฌธ์ ๊ฐ ๋ฐ์
- Domain discrepancy์ ์ ๋นํ penalty ์ค์ผ ํจ. ์๊ฒ ์ฃผ๋ฉด common knowledge๊ฐ ์ ๋ฌ๋์ง ์๊ณ , ๋๋ฌด ํฌ๊ฒ ์ฃผ๋ฉด only domain-specific feature๋ง ์ ๋ฌ๋จ
ConvLSTM, CNN ๋จ๊ณ์์ number of layers ๋๋ฌด ๋ง์ผ๋ฉด ๋ฌธ์ , penalty hyper-parameter gamma๋ ์ ๋นํ ์ค์ ํ ํ์
5. Others
- TaxiBJ์ crowd flows๋ฅผ RegionTrans, ST-DAAN์ผ๋ก ์์ธกํด๋ณด์๋๋ฐ, ํ์ ๋ง์ด ์ก๋ Rush hour์์ ST-DAAN์ด RegionTrans ๋๋น ์ฐ์ โ ๋ณธ ๋ชจ๋ธ์ ์ดํดํ๋ ๋ฐ ๋์๋ ๋งํ ์ง๊ด์ ์์?
- ๊ธฐ์กด ๋ชจ๋ธ์ time invariant, ํน์ง์ ์ ๋๋ก ๊ตฌ๋ถํ์ง ๋ชปํ์ง๋ง, ST-DAAN์ ์ผ์ ๋ถ๋ถ GT์ ๋ค๊ฐ์๋ ๋ชจ์ต์ ๋ณด์๋ค๋ ์์ผ๋ก ์ดํดํจ
ํ์ ๋ง์ด ์ ์ก๋ ์ฌ์ผ ์๊ฐ์๋ RegionTrans, ST-DAAN ๋ ๋ค ๋น์ทํ์ง๋ง, Rush hour์์ ๊ฝค ๋น์ทํ๊ฒ capture