Pipeline implementation for BatchGenerator

* Add end_epoch_callback to Trainer
* Fix Layer.ACTIVATION in case of nn.Module
Corentin 2021-03-18 21:30:59 +09:00
commit 86787f6517
3 changed files with 27 additions and 10 deletions

@@ -33,7 +33,12 @@ class Layer(nn.Module):
         self.info = LayerInfo()
         # Preload default
-        self.activation = Layer.ACTIVATION if activation == 0 else activation
+        if activation == 0:
+            activation = Layer.ACTIVATION
+        if isinstance(activation, type):
+            self.activation = activation()
+        else:
+            self.activation = activation
         self.batch_norm = Layer.BATCH_NORM if batch_norm is None else batch_norm

     def forward(self, input_data: torch.Tensor) -> torch.Tensor:
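For illustration, the same dispatch in isolation: Layer.ACTIVATION (or the activation argument) may now be an nn.Module subclass rather than an instance, and the new branch normalizes both cases to an instance. resolve_activation below is a hypothetical standalone helper, not part of the diff:

    import torch.nn as nn

    def resolve_activation(activation):
        # A class such as nn.ReLU gets instantiated; an instance such
        # as nn.LeakyReLU(0.1) (or any nn.Module) is used as-is.
        if isinstance(activation, type):
            return activation()
        return activation

    assert isinstance(resolve_activation(nn.ReLU), nn.ReLU)
    assert isinstance(resolve_activation(nn.LeakyReLU(0.1)), nn.LeakyReLU)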

@@ -99,6 +99,9 @@ class Trainer:
         self.benchmark_step = 0
         self.benchmark_time = time.time()

+    def end_epoch_callback(self):
+        pass
+
     def train_step_callback(
             self,
             batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
@@ -177,6 +180,7 @@ class Trainer:
                 self.benchmark_step += 1
                 self.save_summaries()
                 self.batch_generator_train.next_batch()
+                self.end_epoch_callback()
         except KeyboardInterrupt:
             if self.verbose:
                 print()
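The hook is a deliberate no-op in the base class; per the hunk above it is invoked after the train batch generator advances, so a subclass can override it to run per-epoch logic. A minimal sketch (the scheduler attribute is hypothetical, not from this repository):

    class SchedulerTrainer(Trainer):
        def end_epoch_callback(self):
            # Runs once per epoch, right after
            # self.batch_generator_train.next_batch(); e.g. step an
            # LR scheduler or dump validation metrics.
            self.scheduler.step()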

@@ -11,6 +11,7 @@ class BatchGenerator:
     def __init__(self, data: Iterable, label: Iterable, batch_size: int,
                  data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
+                 pipeline: Optional[Callable] = None,
                  prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
                  flip_data=False, save: Optional[str] = None):
         self.batch_size = batch_size
@@ -18,6 +19,7 @@ class BatchGenerator:
         self.prefetch = prefetch and not preload
         self.num_workers = num_workers
         self.flip_data = flip_data
+        self.pipeline = pipeline

         if not preload:
             self.data_processor = data_processor
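For context, a sketch of how the new keyword might be supplied at construction time; the arrays and identity_pipeline are placeholders, not from the diff:

    import numpy as np

    images = np.random.rand(128, 3, 32, 32).astype(np.float32)  # placeholder data
    targets = np.random.randint(0, 10, size=128)                # placeholder labels

    def identity_pipeline(data_entry, label_entry):
        # Matches the contract assumed in the hunk below: one sample
        # in, one (data, label) pair out.
        return data_entry, label_entry

    generator = BatchGenerator(images, targets, batch_size=32, pipeline=identity_pipeline)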
@@ -261,31 +263,37 @@ class BatchGenerator:
             data = []
             for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 data.append(self.data_processor(self.data[entry]))
-            self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
+            data = np.asarray(data)
         else:
             data = self.data[
                 self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]
             if self.flip_data:
                 flip = np.random.uniform()
                 if flip < 0.25:
-                    self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
+                    data = data[:, :, ::-1]
                 elif flip < 0.5:
-                    self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
+                    data = data[:, :, :, ::-1]
                 elif flip < 0.75:
-                    self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
-                else:
-                    self.cache_data[self.current_cache][:len(data)] = data
-            else:
-                self.cache_data[self.current_cache][:len(data)] = data
+                    data = data[:, :, ::-1, ::-1]

         # Loading label
         if self.label_processor is not None:
             label = []
             for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
                 label.append(self.label_processor(self.label[entry]))
-            self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
+            label = np.asarray(label)
         else:
             label = self.label[
                 self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]]
+
+        # Process through pipeline
+        if self.pipeline is not None:
+            for data_index, data_entry in enumerate(data):
+                piped_data, piped_label = self.pipeline(data_entry, label[data_index])
+                self.cache_data[self.current_cache][data_index] = piped_data
+                self.cache_label[self.current_cache][data_index] = piped_label
+        else:
+            self.cache_data[self.current_cache][:len(data)] = data
+            self.cache_label[self.current_cache][:len(label)] = label

     def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
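To make the new code path concrete: the pipeline is applied per sample, receiving one data entry and its label and returning the processed pair, which is written into the caches index by index; the returned arrays therefore have to match the cache's per-sample shape. A hedged example of an augmentation callable (horizontal_shift and its offset range are hypothetical):

    import numpy as np

    def horizontal_shift(data_entry: np.ndarray, label_entry):
        # Roll one sample along its last (width) axis by a random
        # offset; the label passes through untouched.
        offset = np.random.randint(-4, 5)
        return np.roll(data_entry, offset, axis=-1), label_entry

    # Passed at construction as pipeline=horizontal_shift, exactly like
    # the identity_pipeline sketch above.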