Kerasにおける"Flatten"の役割は何ですか?
KerasのFlatten
関数の役割を理解しようとしています。下記は私のコードで、単純な2層ネットワークです。形状 (3, 2) の2次元データを取り込み、形状 (1, 4) の1次元データを出力する:
model = Sequential()
model.add(Dense(16, input_shape=(3, 2)))
model.add(Activation('relu'))
model.add(Flatten())
model.add(Dense(4))
model.compile(loss='mean_squared_error', optimizer='SGD')
x = np.array([[[1, 2], [3, 4], [5, 6]]])
y = model.predict(x)
print y.shape
これはy
の形状が(1, 4)であることを出力する。しかし、Flatten
行を削除すると、y
の形状は(1, 3, 4)と出力される。
これは理解できない。私のニューラルネットワークの理解では、model.add(Dense(16, input_shape=(3, 2)))
関数は16個のノードを持つ隠れ全結合層を作成している。これらのノードはそれぞれ3x2の入力要素に接続されています。したがって、この第1層の出力の16個のノードは、すでに"flat"である。したがって、第1層の出力形状は(1, 16)となる。そして第2層はこれを入力とし、(1, 4)のデータを出力する。
では、第1層の出力がすでに(1, 16)の形状でquot;flat"なのであれば、なぜさらに平坦化する必要があるのでしょうか?
81
3
ここの
Dense
のドキュメントを読めばわかるだろう:は3つの入力と16の出力を持つ
Dense
ネットワークになり、5つのステップのそれぞれに独立して適用されます。D(x)が3次元のベクトルを16次元のベクトルに変換する場合、そのレイヤーの出力は一連のベクトルになります:D(x[0,:], D(x[1,:],..., D(x[4,:]]
, shape(5, 16)
。指定した振る舞いをするためには、まず入力を15次元のベクトルにFlatten
してからDense
を適用します:EDIT: 理解するのに苦労している人もいるようなので、ここに説明画像がある:
[ここに画像の説明を入力してください]1。
short read:
[ここに画像の説明を入力]1.
これが、Matrixを1つの配列に変換するFlattenの動作です。