5月末までに、FrancoisScholléの著書「 Deep Learning in Python 」の翻訳版(KerasおよびTensorflowライブラリを使用した例)があります。 お見逃しなく!
しかし、もちろん、私たちは差し迫った未来に目を向け、さらに革新的なPyTorchライブラリを詳しく調べ始めます。 今日、ピーター・ゴールズボロの記事を翻訳するように注意を促します。ピーター・ゴールズボロは、このライブラリの慣れ親しんだツアーに
過去2年間、私はTensorFlowに真剣に取り組んでいます。このライブラリに関する記事を書いたり 、バックエンドの拡張に関する講義を行ったり、ディープラーニングに関連する自分の研究でそれを使用したりしました 。 この作業中に、TensorFlowの長所と短所をよく理解しました。また、競争の余地を残す特定のアーキテクチャソリューションについても知りました。 この荷物で、私は最近Facebook(FAIR)の人工知能研究部門のPyTorchチームに参加しました。これはおそらく現在TensorFlowの最強のライバルです。 現在、PyTorchは研究コミュニティで非常に人気があります。 理由-次の段落で説明します。
この記事では、PyTorchライブラリーの概要を明示し、その作成目的を説明し、そのAPIを紹介します。
全体像と哲学
まず、基本的な観点からPyTorchが何であるか、PyTorchを操作する際に適用する必要があるプログラミングモデル、および現代のディープラーニングツールのエコシステムにどのように適合するかを検討します。
本質的に、PyTorchはNumPyのようなGPU加速テンソル計算を提供するPythonライブラリです。 さらに、PyTorchは、ニューラルネットワークに関連するアプリケーションを解決するための豊富なAPIを提供します。
PyTorchは他の機械学習フレームワークとは異なり、TensorFlow、 Caffe2、またはMXNetのように、事前に即座にかつ明確に定義された静的な計算グラフを使用しません。 対照的に、PyTorchで計算されたグラフは動的であり、その場で決定されます。 したがって、PyTorchモデルでレイヤーを呼び出すたびに、新しい計算グラフが動的に決定されます。 このグラフは暗黙的に作成されます。つまり、ライブラリ自体がプログラムを通過するデータの流れを記録し、関数呼び出し(ノード)を(エッジを介して)計算グラフにリンクします。
動的グラフと静的グラフの比較
静的グラフと動的グラフの違いを詳しく見てみましょう。 一般に、ほとんどのプログラミング環境では、数値を表す2つの変数xとyを加算すると、それらの合計値が得られます(加算の結果)。 たとえば、Pythonの場合:
In [1]: x = 4 In [2]: y = 2 In [3]: x + y Out[3]: 6
しかし、TensorFlowではそうではありません。 TensorFlowでは、xとyはそのような数値ではなく、これらの値を表すが明示的に含まれていないグラフノードの記述子です。 さらに(さらに重要です)、
x
と
y
追加すると、これらの数値の合計ではなく、計算されたグラフの記述子を取得します。これにより、実行後にのみ目的の値が得られます。
In [1]: import tensorflow as tf In [2]: x = tf.constant(4) In [3]: y = tf.constant(2) In [4]: x + y Out[4]: <tf.Tensor 'add:0' shape=() dtype=int32>
基本的に、TensorFlowコードを記述するとき、これは実際にはプログラミングではなく、 メタプログラミングです。 私たちはプログラム(私たちのコード)を書き、それは別のプログラム(計算されたTensorFlowグラフ)を作成します。 当然、最初のプログラミングモデルは2番目のプログラミングモデルよりもはるかに単純です。 それらの表現よりむしろ実際の現象の文脈で話して、推論するのははるかに便利です。
PyTorchの最も重要な利点は、その実行モデルが2番目のパラダイムよりも最初のパラダイムにはるかに近いことです。 コアでは、PyTorchはテンソルコンピューティング(NumPyなど)をサポートする最も一般的なPythonですが、GPUアクセラレーションテンソル操作を備えており、最も重要なこととして、組み込みの自動微分 (AD)を備えています。 最新の機械学習アルゴリズムのほとんどは線形代数(行列とベクトル)のデータ型に大きく依存しており、勾配情報を使用して推定値を調整するため、これら2つのPyTorch柱は任意の大規模な機械学習問題に対処するのに十分です。
上記の単純なケースの分析に戻ると、PyTorchを使用したプログラミングが「自然な」Pythonのように感じることを確認できます。
In [1]: import torch In [2]: x = torch.ones(1) * 4 In [3]: y = torch.ones(1) * 2 In [4]: x + y Out[4]: 6 [torch.FloatTensor of size 1]
PyTorchは、特定の1つの側面でPythonの基本的なプログラミングロジックとわずかに異なります。ライブラリは実行中のプログラムの実行を記録します。 つまり、PyTorchは、そのデータ型に対して、そして舞台裏で、あなたが実行する操作を静かに「追跡」しています。 -決済グラフを収集します。 このような計算グラフは、微分を計算するために結果の値を生成する一連の操作に沿って反対方向に進む必要があるため、自動微分に必要です(逆自動微分の場合)。 TensorFlowまたはMXNetのバージョンとこの計算グラフ(または、むしろこの計算グラフを組み立てる方法)の重大な違いは、各コードフラグメントを解釈するときに、新しいグラフが「貪欲に」オンザフライで収集されることです。
それどころか、Tensorflowでは、計算されたグラフは一度だけ構築され、メタプログラム(コード)がこれを担当します。 さらに、値の導関数を要求するたびにPyTorchはグラフを反対方向に動的にトラバースしますが、TensorFlowはグラフに追加ノードを挿入し、これらの導関数を(暗黙的に)計算し、他のすべてのノードとまったく同じように解釈します。 ここでは、動的グラフと静的グラフの違いが特に顕著です。
どの計算グラフを使用するか(静的または動的)を選択すると、これらの環境のいずれかでのプログラミングプロセスが大幅に簡素化されます。 制御の流れは、この選択によって特に影響を受ける側面です。 静的グラフを使用する環境では、制御フローはグラフレベルで特殊なノードの形式で表す必要があります。 たとえば、
tf.cond()
では、分岐を確実にするために、3つのサブグラフを入力として取る
tf.cond()
操作があります。条件付きサブグラフと、2つの条件開発ブランチ(
if
および
else
2つのサブグラフ
else
。 同様に、Ternsorflowグラフのループはtf.while
tf.while()
操作として表す必要があります。これは、
condition
として
body
サブグラフを入力として受け入れます。 動的なグラフがある状況では、これはすべて単純化されます。 グラフはPythonコードからそのまま表示されるため、解釈ごとに、他のプログラムと同様に、
if
条件と
while
ループを使用してフロー制御を言語にネイティブに実装できます。 したがって、不器用で混乱を招くTensorflowコード:
import tensorflow as tf x = tf.constant(2, shape=[2, 2]) w = tf.while_loop( lambda x: tf.reduce_sum(x) < 100, lambda x: tf.nn.relu(tf.square(x)), [x])
自然で理解可能なPyTorchコードに変わります:
import torch.nn from torch.autograd import Variable x = Variable(torch.ones([2, 2]) * 2) while x.sum() < 100: x = torch.nn.ReLU()(x**2)
当然、プログラミングの容易さの観点から、動的グラフの使用はそれをはるかに超えています。 (
tf.Print()
ノードを使用するのではなく)
print
ステートメントを使用して、またはデバッガーで中間値をチェックできることは、すでに大きなプラスです。 もちろん、ダイナミズムはプログラマビリティを最適化し、パフォーマンスを低下させる可能性があります。つまり、そのようなグラフを最適化することはより困難です。 したがって、PyTorchとTensorFlowの違いとトレードオフは、Pythonなどの動的な解釈言語と、CまたはC ++などの静的なコンパイル言語との間とほぼ同じです。 前者は作業がより簡単で高速であり、2番目と3番目から非常に最適化可能なエンティティを組み立てる方が便利です。 これは、柔軟性とパフォーマンスのトレードオフです。
PyTorch APIノート
PyTorch APIについて、特にTensorFlowやMXNetなどの他のライブラリと比較したニューラルネットワークの計算に関して、一般的な発言をしたいと思います-このAPIは多くのモジュール(いわゆる「バッテリー付属」)にハングアップします。 同僚の一人が指摘したように、Tensorflow APIは実際には「アセンブリ」レベルを超えていませんでした。というのは、このAPIは計算グラフ(加算、乗算、点ごとの関数など)の作成に必要な最も単純なアセンブリ命令のみを提供するという意味でです。 d。)。 しかし、プログラマーが操作中に何千回も再現しなければならない最も一般的なプログラムフラグメント用の「標準ライブラリ」が欠けています。 したがって、Tensorflowの上に高レベルのAPIを構築するには、コミュニティの助けに頼る必要があります。
実際、コミュニティはそのような高レベルのAPIを作成しました。 真実、残念ながら、1つではなく、1ダース-ライバル関係で。 したがって、悪い日には、専門分野に関する5つの記事を読むことができます。5つすべてで、TensorFlowのさまざまな「フロントエンド」を見つけることができます。 原則として、これらのAPIにはほとんど共通点がないため、実際にはTensorFlowだけでなく5つの異なるフレームワークを学習する必要があります。 これらのAPIの中で最も人気のあるものを以下に示します。
PyTorchは、ディープラーニングの分野での日々の研究に必要な最も一般的な要素をすでに備えています。 原則として、torch.nnパッケージにKerasのような「ネイティブ」APIがあり、ニューラルネットワークの高レベルモジュールの結合を提供します。
共通生態系におけるPyTorchの位置
したがって、PyTorchがMXNet、TensorFlow、Theanoなどの静的グラフフレームワークとどのように異なるかを説明すると、実際、PyTorchはニューラルネットワークを計算するアプローチにおいてユニークではない、と言わざるを得ません。 PyTorchより前には、 ChainerやDyNetなど、同様の動的APIを提供する他のライブラリがすでに存在していました。 しかし、今日、PyTorchはこれらの選択肢よりも人気があります。
さらに、Facebookで使用されるフレームワークはPyTorchだけではありません。 現在、本番環境の主なワークロードはCaffe2に当てはまります。これは、 Caffeに基づいて構築された静的なグラフフレームワークです。 PyTorchの研究者が生産最適化の分野で静的グラフの利点を提供する柔軟性を友人に提供するために、FacebookはONNXを開発しています。
最後に、小さな歴史的余談:PyTorchの前に、 Torchがありました-Lua 言語で書かれた科学計算用の非常に古い(2000年代初期のサンプル)ライブラリです。 Torchは、Cで記述されたコードベースをラップするため、高速かつ効率的になります。 原則として、PyTorchはまったく同じベースのCコードをラップします(ただし、抽象化の中間レベルが追加されます )。PythonでユーザーAPIを公開します。 次に、このAPIについてPythonで話しましょう。
PyTorchを使用する
次に、PyTorchライブラリの基本概念と主要コンポーネントについて説明し、その基本データ型、自動差別化メカニズム、ニューラルネットワークに関連する特定の機能、およびデータのロードと処理のためのユーティリティを調べます。
テンソル
PyTorchで最も基本的なデータ型は
tensor
です。 テンソルデータ型の値と機能は、NumPyの
ndarray
と非常に似ています。 さらに、PyTorchはNumPyとの合理的な相互運用性を目指しているため、
tensor
APIも
ndarray
APIに似て
ndarray
ます(ただし、これと同一ではありません)。 PyTorchテンソルは
torch.Tensor
コンストラクターを使用して作成できます。この
torch.Tensor
ターは、入力としてテンソル次元を取り、 初期化されていないメモリ領域を占めるテンソルを返します。
import torch x = torch.Tensor(4, 4)
実際には、多くの場合、何らかの方法で初期化されたテンソルを返す次のPyTorch関数のいずれかを使用する必要があります。
-
torch.rand
:値はランダムな一様分布から初期化され、 -
torch.randn
:値はランダムな正規分布から初期化され、 -
torch.eye(n)
:n×nn×n
形式の単位行列 -
torch.from_numpy(ndarray)
:NumPyのndarray
基づくPyTorchテンソル -
torch.linspace(start, end, steps)
:start
とend
間で均等に分布するsteps
値を持つ1次元テンソル、 -
torch.ones
:1つの単位テンソル、 -
torch.zeros_like(other)
:other
と同じ形状で、同じゼロを持つテンソル、 -
torch.arange(start, end, step)
:値が範囲外に埋められた1次元テンソル。
NumPyの
ndarray
と同様に、PyTorchテンソルは、状況の変化だけでなく、他のテンソルと組み合わせるための非常に豊富なAPIを提供します。 NumPyの場合と同様に、単項演算と二項演算は通常、
torch
モジュールの関数
torch.add(x, y)
など
torch.add(x, y)
を使用するか、テンソルオブジェクトのメソッド
torch.add(x, y)
など
torch.add(x, y)
直接使用して実行できます。 最も一般的な場所には、たとえば
x + y
ようなオーバーロード演算子があります。 さらに、多くの関数には、新しいテンソルを作成せず、受信者インスタンスを変更する状況的代替手段があります。 これらの関数は標準のバリアントと同じように呼び出されますが、名前にアンダースコアが含まれます
x.add_(y)
例:
x.add_(y)
。
選択された操作:
torch.add(x, y)
:要素ごとの加算
torch.mm(x, y)
:行列の乗算(
matmul
または
dot
はない)、
torch.mul(x, y)
:要素ごとの乗算
torch.exp(x)
:
torch.exp(x)
ごとの指数
torch.pow(x, power)
:段階的なべき乗
torch.sqrt(x)
:要素ごとの二乗
torch.sqrt_(x)
:状況に
torch.sqrt_(x)
ごとの二乗
torch.sigmoid(x)
:要素単位のシグモイド
torch.cumprod(x)
:すべての値の積
torch.sum(x)
:すべての値の合計
torch.std(x)
:すべての値の標準偏差
torch.mean(x)
:すべての値の平均
テンソルは、翻訳、複雑な(気まぐれな)インデックス(
x[x > 5]
)、要素ごとの関係演算子(
x > y
)など、NumPyのndarrayでよく知られているセマンティクスを大幅にサポートしています。 PyTorchテンソルは、
torch.Tensor.numpy()
関数を使用して
ndarray
NumPyに直接変換することもできます。 最後に、ndarray NumPyと比較したPyTorchテンソルの主な利点はGPUアクセラレーションであるため、CUDA対応のGPUにテンソルメモリをコピーする
torch.Tensor.cuda()
関数もあります。
オートグラッド
最新の機械学習技術の中心にあるのは、勾配の計算です。 これは特に、重みを更新するために逆伝播アルゴリズムが使用されるニューラルネットワークに当てはまります。 これが、Pytorchがフレームワーク内で定義された関数と変数の勾配計算を強力にネイティブにサポートしている理由です。 勾配が任意の計算のために自動的に計算されるこの手法は、自動(場合によってはアルゴリズム )微分と呼ばれます。
グラフの計算に静的モデルを使用するフレームワークでは、グラフを分析して追加の計算ノードを追加することにより、自動微分が実現されます。ここで、別の値に対する1つの値の勾配が段階的に計算され、これらの追加の勾配ノードをエッジで接続するチェーンルールが部分的に接続されます。
ただし、PyTorchには静的に計算されたグラフはないため、他の計算が定義された後に勾配ノードを追加する余裕はありません。 代わりに、PyTorchは、プログラムが到着するときにプログラムを通過する値のフローを記録または追跡する必要があります。つまり、計算グラフを動的に構築します。 そのようなグラフが作成されるとすぐに、PyTorchはこの計算フローをバイパスし、入力に基づいて出力値の勾配を計算するために必要な情報を取得します。
PyTorchの
Tensor
は、自動分化に参加するための完全なメカニズムがまだありません。 テンソルを書くには、
torch.autograd.Variable
でラップする必要があります。
Variable
クラスは
Tensor
とほぼ同じAPIを提供しますが、自動差別化のために
torch.autograd.Function
正確に対話する機能で補完します。 より正確には、
Variable
は
Tensor
での操作の履歴を記録します。
torch.autograd.Variable
使用
torch.autograd.Variable
非常に簡単です。
Tensor
を渡して、この変数に勾配を書き込む必要があるかどうかを
torch
伝えるだけです。
x = torch.autograd.Variable(torch.ones(4, 4), requires_grad=True)
requires_grad
関数は、データの入力時やラベルの操作時など、通常はそのような情報が区別されないため、
False
requires_grad
する場合があります。 ただし、自動識別に適したものにするためには、
Variables
でなければなりません。 注意:require_gradのデフォルトは
False
であるため、トレーニングされたパラメーターの場合は
True
に設定する必要があります。
勾配を計算して自動微分を実行するには、
backward()
関数が
Variable
適用されます。 したがって、計算されたグラフの葉(このグラフに影響を与えたすべての入力値)に対するこのテンソルの勾配が計算されます。 次に、これらのグラデーションが
Variable
クラスの
grad
クラスのメンバーに収集されます。
In [1]: import torch In [2]: from torch.autograd import Variable In [3]: x = Variable(torch.ones(1, 5)) In [4]: w = Variable(torch.randn(5, 1), requires_grad=True) In [5]: b = Variable(torch.randn(1), requires_grad=True) In [6]: y = x.mm(w) + b # mm = matrix multiply In [7]: y.backward() # perform automatic differentiation In [8]: w.grad Out[8]: Variable containing: 1 1 1 1 1 [torch.FloatTensor of size (5,1)] In [9]: b.grad Out[9]: Variable containing: 1 [torch.FloatTensor of size (1,)] In [10]: x.grad None
入力値を除くすべての
Variable
は演算の結果であるため、
grad_fn
各変数に関連付けられています。これは、逆ステップを計算するための
torch.autograd.Function
関数です。 入力値の場合、
None
です。
In [11]: y.grad_fn Out[11]: <AddBackward1 at 0x1077cef60> In [12]: x.grad_fn None torch.nn
torch.nn
モジュールは、
torch.nn
ユーザーにニューラルネットワーク固有の機能を提供します。 その最も重要なメンバーの1つは
torch.nn.Module
。これは、ニューラルネットワークのレイヤーで最もよく使用される、再利用可能な操作の単位であり、関連する(トレーニングされた)パラメーターです。 モジュールには他のモジュールを含めることができ、バックディストリビューションの
backward()
関数を暗黙的に受け取ります。 モジュールの例は
torch.nn.Linear()
。これは、線形(密/完全に接続された)レイヤー(つまり、アフィン変換
Wx+bWx+b
)を
Wx+bWx+b
ます。
In [1]: import torch In [2]: from torch import nn In [3]: from torch.autograd import Variable In [4]: x = Variable(torch.ones(5, 5)) In [5]: x Out[5]: Variable containing: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 [torch.FloatTensor of size (5,5)] In [6]: linear = nn.Linear(5, 1) In [7]: linear(x) Out[7]: Variable containing: 0.3324 0.3324 0.3324 0.3324 0.3324 [torch.FloatTensor of size (5,1)]
トレーニング中、多くの場合、モジュールの
backward()
関数を呼び出して、変数の勾配を計算する必要があります。
Variables
の
grad
メンバー
Variables
いる
backward()
呼び出すと、すべての
Variable
grad
メンバーをゼロにリセットする
nn.Module.zero_grad()
メソッドもあります。 トレーニングループは通常、
zero_grad()
を呼び出すか、
backward()
を呼び出す直前に呼び出して、次の最適化ステップの勾配をリセットします。
ニューラルネットワーク用の独自のモデルを作成する場合、PyTorchと統合する一般的な機能をカプセル化するために、モジュールの独自のサブクラスを作成する必要があります。 これは非常に簡単に行われます
torch.nn.Module
からクラスを継承し、
forward
メソッドを与えます。 たとえば、ここに私のモデルの1つのために書いたモジュールがあります(入力情報にガウスノイズが追加されます)。
class AddNoise(torch.nn.Module): def __init__(self, mean=0.0, stddev=0.1): super(AddNoise, self).__init__() self.mean = mean self.stddev = stddev def forward(self, input): noise = input.clone().normal_(self.mean, self.stddev) return input + noise
モジュールを完全に機能するモデルに接続または結合するには、
torch.nn.Sequential()
コンテナーを使用できます。このコンテナーには、一連のモジュールが渡されます-そして、それは、独立したモジュールとして機能し始め、呼び出されるたびに、渡されたモジュールを計算します 例:
In [1]: import torch In [2]: from torch import nn In [3]: from torch.autograd import Variable In [4]: model = nn.Sequential( ...: nn.Conv2d(1, 20, 5), ...: nn.ReLU(), ...: nn.Conv2d(20, 64, 5), ...: nn.ReLU()) ...: In [5]: image = Variable(torch.rand(1, 1, 32, 32)) In [6]: model(image) Out[6]: Variable containing: (0 ,0 ,.,.) = 0.0026 0.0685 0.0000 ... 0.0000 0.1864 0.0413 0.0000 0.0979 0.0119 ... 0.1637 0.0618 0.0000 0.0000 0.0000 0.0000 ... 0.1289 0.1293 0.0000 ... ⋱ ... 0.1006 0.1270 0.0723 ... 0.0000 0.1026 0.0000 0.0000 0.0000 0.0574 ... 0.1491 0.0000 0.0191 0.0150 0.0321 0.0000 ... 0.0204 0.0146 0.1724
損失
torch.nn
は、機械学習アプリケーションにとって当然重要な多くの損失関数も提供します。 そのような機能の例:
-
torch.nn.MSELoss
:torch.nn.MSELoss
平均平方根損失関数 -
torch.nn.BCELoss
:バイナリ相互エントロピー損失関数、 -
torch.nn.KLDivLoss
:torch.nn.KLDivLoss
-Leibler情報発散の損失関数
PyTorchのコンテキストでは、損失関数はしばしば基準と呼ばれます 。 基本的に、基準は作成後すぐにパラメーター化できる非常に単純なモジュールであり、それ以降は通常の関数として使用できます。
In [1]: import torch In [2]: import torch.nn In [3]: from torch.autograd import Variable In [4]: x = Variable(torch.randn(10, 3)) In [5]: y = Variable(torch.ones(10).type(torch.LongTensor)) In [6]: weights = Variable(torch.Tensor([0.2, 0.2, 0.6])) In [7]: loss_function = torch.nn.CrossEntropyLoss(weight=weights) In [8]: loss_value = loss_function(x, y) Out [8]: Variable containing: 1.2380 [torch.FloatTensor of size (1,)]
オプティマイザー
ニューラルネットワーク(
nn.Module
)および損失関数の「プライマリエレメント」の後、確率的勾配降下(オプション)を開始するオプティマイザーのみを考慮する必要があります。 PyTorch
torch.optim
, , :
-
torch.optim.SGD
: , -
torch.optim.Adam
: , -
torch.optim.RMSprop
: , Coursera, -
torch.optim.LBFGS
: ---
-,
parameters()
nn.Module
, , . , . 例:
In [1]: import torch In [2]: import torch.optim In [3]: from torch.autograd import Variable In [4]: x = Variable(torch.randn(5, 5)) In [5]: y = Variable(torch.randn(5, 5), requires_grad=True) In [6]: z = x.mm(y).mean() # Perform an operation In [7]: opt = torch.optim.Adam([y], lr=2e-4, betas=(0.5, 0.999)) In [8]: z.backward() # Calculate gradients In [9]: y.data Out[9]: -0.4109 -0.0521 0.1481 1.9327 1.5276 -1.2396 0.0819 -1.3986 -0.0576 1.9694 0.6252 0.7571 -2.2882 -0.1773 1.4825 0.2634 -2.1945 -2.0998 0.7056 1.6744 1.5266 1.7088 0.7706 -0.7874 -0.0161 [torch.FloatTensor of size 5x5] In [10]: opt.step() # y Adam In [11]: y.data Out[11]: -0.4107 -0.0519 0.1483 1.9329 1.5278 -1.2398 0.0817 -1.3988 -0.0578 1.9692 0.6250 0.7569 -2.2884 -0.1775 1.4823 0.2636 -2.1943 -2.0996 0.7058 1.6746 1.5264 1.7086 0.7704 -0.7876 -0.0163 [torch.FloatTensor of size 5x5]
PyTorch , .
torch.utils.data module
. :
-
Dataset
, , -
DataLoader
, , , .
torch.utils.data.Dataset
__len__
, , ,
__getitem__
. , , :
import math class RangeDataset(torch.utils.data.Dataset): def __init__(self, start, end, step=1): self.start = start self.end = end self.step = step def __len__(self, length): return math.ceil((self.end - self.start) / self.step) def __getitem__(self, index): value = self.start + index * self.step assert value < self.end return value
__init__
- .
__len__
,
__getitem__
,
__getitem__
, , .
, , ,
for i in range
__getitem__
. , , ,
for sample in dataset
. ,
DataLoader
.
DataLoader
, . , , .
DataLoader
num_workers
. :
DataLoader
,
batch_size
. 簡単な例:
dataset = RangeDataset(0, 10) data_loader = torch.utils.data.DataLoader( dataset, batch_size=4, shuffle=True, num_workers=2, drop_last=True) for i, batch in enumerate(data_loader): print(i, batch)
batch_size
4, .
shuffle=True
, , .
drop_last=True
, , ,
batch_size
, . ,
num_workers
«», , . ,
DataLoader
, , , , .
, :
DataLoader
, , ,
__getitem__
, ,
DataLoader
. ,
__getitem__
,
DataLoader
, , . , ,
__getitem__
dict(example=example, label=label)
, ,
DataLoader
,
dict(example=[example1, example2, ...], label=[label1, label2, ...])
, , , . ,
collate_fn
DataLoader
.
:
torchvision
, ,
torchvision.datasets.CIFAR10
.
torchaudio
torchtext
.
おわりに
, PyTorch, API, , PyTorch. PyTorch, , PyTorch. , PyTorch LSGAN, TensorFlow , . , .