Pipeline implementation for BatchGenerator
* Add end_epoch_callback in Trainer * Fix Layer.ACTIVATION in cas of nn.Module
This commit is contained in:
parent
51776d6999
commit
86787f6517
3 changed files with 27 additions and 10 deletions
|
|
@ -33,7 +33,12 @@ class Layer(nn.Module):
|
|||
self.info = LayerInfo()
|
||||
|
||||
# Preload default
|
||||
self.activation = Layer.ACTIVATION if activation == 0 else activation
|
||||
if activation == 0:
|
||||
activation = Layer.ACTIVATION
|
||||
if isinstance(activation, type):
|
||||
self.activation = activation()
|
||||
else:
|
||||
self.activation = activation
|
||||
self.batch_norm = Layer.BATCH_NORM if batch_norm is None else batch_norm
|
||||
|
||||
def forward(self, input_data: torch.Tensor) -> torch.Tensor:
|
||||
|
|
|
|||
|
|
@ -99,6 +99,9 @@ class Trainer:
|
|||
self.benchmark_step = 0
|
||||
self.benchmark_time = time.time()
|
||||
|
||||
def end_epoch_callback(self):
|
||||
pass
|
||||
|
||||
def train_step_callback(
|
||||
self,
|
||||
batch_inputs: torch.Tensor, processed_inputs: torch.Tensor,
|
||||
|
|
@ -177,6 +180,7 @@ class Trainer:
|
|||
self.benchmark_step += 1
|
||||
self.save_summaries()
|
||||
self.batch_generator_train.next_batch()
|
||||
self.end_epoch_callback()
|
||||
except KeyboardInterrupt:
|
||||
if self.verbose:
|
||||
print()
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ class BatchGenerator:
|
|||
|
||||
def __init__(self, data: Iterable, label: Iterable, batch_size: int,
|
||||
data_processor: Optional[Callable] = None, label_processor: Optional[Callable] = None,
|
||||
pipeline: Optional[Callable] = None,
|
||||
prefetch=True, preload=False, num_workers=1, shuffle=True, initial_shuffle=False,
|
||||
flip_data=False, save: Optional[str] = None):
|
||||
self.batch_size = batch_size
|
||||
|
|
@ -18,6 +19,7 @@ class BatchGenerator:
|
|||
self.prefetch = prefetch and not preload
|
||||
self.num_workers = num_workers
|
||||
self.flip_data = flip_data
|
||||
self.pipeline = pipeline
|
||||
|
||||
if not preload:
|
||||
self.data_processor = data_processor
|
||||
|
|
@ -261,31 +263,37 @@ class BatchGenerator:
|
|||
data = []
|
||||
for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||
data.append(self.data_processor(self.data[entry]))
|
||||
self.cache_data[self.current_cache][:len(data)] = np.asarray(data)
|
||||
data = np.asarray(data)
|
||||
else:
|
||||
data = self.data[
|
||||
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
|
||||
if self.flip_data:
|
||||
flip = np.random.uniform()
|
||||
if flip < 0.25:
|
||||
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1]
|
||||
data = data[:, :, ::-1]
|
||||
elif flip < 0.5:
|
||||
self.cache_data[self.current_cache][:len(data)] = data[:, :, :, ::-1]
|
||||
data = data[:, :, :, ::-1]
|
||||
elif flip < 0.75:
|
||||
self.cache_data[self.current_cache][:len(data)] = data[:, :, ::-1, ::-1]
|
||||
else:
|
||||
self.cache_data[self.current_cache][:len(data)] = data
|
||||
else:
|
||||
self.cache_data[self.current_cache][:len(data)] = data
|
||||
data = data[:, :, ::-1, ::-1]
|
||||
|
||||
# Loading label
|
||||
if self.label_processor is not None:
|
||||
label = []
|
||||
for entry in self.index_list[self.step * self.batch_size:(self.step + 1) * self.batch_size]:
|
||||
label.append(self.label_processor(self.label[entry]))
|
||||
self.cache_label[self.current_cache][:len(label)] = np.asarray(label)
|
||||
label = np.asarray(label)
|
||||
else:
|
||||
label = self.label[
|
||||
self.index_list[self.step * self.batch_size: (self.step + 1) * self.batch_size]]
|
||||
|
||||
# Process through pipeline
|
||||
if self.pipeline is not None:
|
||||
for data_index, data_entry in enumerate(data):
|
||||
piped_data, piped_label = self.pipeline(data_entry, label[data_index])
|
||||
self.cache_data[self.current_cache][data_index] = piped_data
|
||||
self.cache_label[self.current_cache][data_index] = piped_label
|
||||
else:
|
||||
self.cache_data[self.current_cache][:len(data)] = data
|
||||
self.cache_label[self.current_cache][:len(label)] = label
|
||||
|
||||
def next_batch(self) -> Tuple[np.ndarray, np.ndarray]:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue