我尝试设置 stateful = True 来训练我的 LSTM 模型，训练可以正常运行。
但是我必须将输入的形状调整为我为第一层设置的相同batch_size,这对于有状态RNN是必须的,否则我将收到错误消息:InvalidArgumentError:Invalid input_h shape。
我将batch_size设置为64,但是我只想输入一个开始句子来生成文本。如果我必须提供batch_size = 64的输入,则需要准备64个句子,这很荒谬。
如果不设置 stateful = True，模型可以正常工作，但我需要提高模型的性能。在这种情况下，如何在输入的 batch 大小与所设置的 batch_size 不一致时，仍然使用有状态的 LSTM 模型？
我定义的模型
# Sequence and training settings. batch_size is baked into the stateful
# model's input shape, so every call to that model must use this batch size.
seq_length, batch_size, epochs = 100, 64, 3

# Model dimensions. `vocab` is defined elsewhere in the script; the original
# comment records len(vocab) == 65.
vocab_size = len(vocab)
embedding_dim, rnn_units = 256, 1024
def bi_lstm(vocab_size, embedding_dim, batch_size, rnn_units, stateful=True):
    """Build a Sequential Embedding -> Bidirectional(LSTM) -> Dense model.

    Args:
        vocab_size: Number of token ids accepted by the embedding layer, and
            number of logits produced per time step by the final Dense layer.
        embedding_dim: Width of each embedding vector.
        batch_size: Batch dimension fixed in the input shape
            ``(batch_size, None)``. When ``stateful`` is True the LSTM keeps
            one state per batch row. Every call must therefore use exactly
            this batch size; on the CuDNN path a mismatch surfaces as
            "Invalid input_h shape".

            To feed a different batch size, build a second model with the
            desired ``batch_size`` and copy the trained weights into it. A
            single seed sentence for text generation is one such case. Weight
            shapes do not depend on the batch size::

                gen_model = bi_lstm(vocab_size, embedding_dim, 1, rnn_units)
                gen_model.set_weights(model.get_weights())

        rnn_units: Number of units in the wrapped LSTM.
        stateful: Whether the LSTM carries its state across calls. Defaults
            to True, which keeps the original behaviour.

    Returns:
        An uncompiled ``keras.models.Sequential`` that outputs one
        ``vocab_size``-wide logit vector per input time step
        (``return_sequences=True``).
    """
    return keras.models.Sequential([
        # output_dim (embedding_dim) is a required argument of Embedding.
        # Without it the embedding_dim parameter went unused.
        keras.layers.Embedding(
            vocab_size,
            embedding_dim,
            batch_input_shape=(batch_size, None),
        ),
        keras.layers.Bidirectional(
            keras.layers.LSTM(
                units=rnn_units,
                return_sequences=True,
                stateful=stateful,
                recurrent_initializer="glorot_uniform",
            )
        ),
        keras.layers.Dense(vocab_size),
    ])
我做了一个简单的测试，结果报出了下面的错误：
# A stateful model only accepts batches of exactly the size it was built
# with. Calling `model` directly on a sliced batch (64 -> 54 rows) therefore
# fails with "Invalid input_h shape". Instead, build a model sized to the
# batch actually being fed and copy the trained weights into it; weight shapes
# do not depend on the batch size. Use the same approach with a batch size of
# 1 to generate text from a single seed sentence.
for x, y in seq_dataset.take(1):
    x = x[:-10, :]  # shrink the batch from 64 to 54 rows
    print(x.shape)
    infer_model = bi_lstm(vocab_size, embedding_dim, x.shape[0], rnn_units)
    infer_model.set_weights(model.get_weights())
    pred = infer_model(x)
    print(pred.shape)
InvalidArgumentError Traceback (most recent call last)
<ipython-input-98-99323ee3e09d> in <module>()
2 x = x[:-10,:]
3 print(x.shape)
----> 4 pred = model(x)
5 print(pred.shape)
14 frames
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self,inputs,*args,**kwargs)
889 with base_layer_utils.autocast_context_manager(
890 self._compute_dtype):
--> 891 outputs = self.call(cast_inputs,**kwargs)
892 self._handle_activity_regularization(inputs,outputs)
893 self._set_mask_metadata(inputs,outputs,input_masks)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/sequential.py in call(self,training,mask)
254 if not self.built:
255 self._init_graph_network(self.inputs,self.outputs,name=self.name)
--> 256 return super(Sequential,self).call(inputs,training=training,mask=mask)
257
258 outputs = inputs # handle the corner case where self.layers is empty
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/network.py in call(self,mask)
706 return self._run_internal_graph(
707 inputs,mask=mask,--> 708 convert_kwargs_to_constants=base_layer_utils.call_context().saving)
709
710 def compute_output_shape(self,input_shape):
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/network.py in _run_internal_graph(self,mask,convert_kwargs_to_constants)
858
859 # Compute outputs.
--> 860 output_tensors = layer(computed_tensors,**kwargs)
861
862 # Update tensor_dict.
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/wrappers.py in __call__(self,initial_state,constants,**kwargs)
526
527 if initial_state is None and constants is None:
--> 528 return super(Bidirectional,self).__call__(inputs,**kwargs)
529
530 # Applies the same workaround as in `RNN.__call__`
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self,input_masks)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/wrappers.py in call(self,constants)
640
641 y = self.forward_layer(forward_inputs,--> 642 initial_state=forward_state,**kwargs)
643 y_rev = self.backward_layer(backward_inputs,644 initial_state=backward_state,**kwargs)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/recurrent.py in __call__(self,**kwargs)
621
622 if initial_state is None and constants is None:
--> 623 return super(RNN,**kwargs)
624
625 # If any of `initial_state` or `constants` are specified and are Keras
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/engine/base_layer.py in __call__(self,input_masks)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/recurrent_v2.py in call(self,initial_state)
959 if can_use_gpu:
960 last_output,new_h,new_c,runtime = cudnn_lstm(
--> 961 **cudnn_lstm_kwargs)
962 else:
963 last_output,runtime = standard_lstm(
/tensorflow-2.0.0/python3.6/tensorflow_core/python/keras/layers/recurrent_v2.py in cudnn_lstm(inputs,init_h,init_c,kernel,recurrent_kernel,bias,time_major,go_backwards)
1172 outputs,h,c,_ = gen_cudnn_rnn_ops.cudnn_rnn(
1173 inputs,input_h=init_h,input_c=init_c,params=params,is_training=True,-> 1174 rnn_mode='lstm')
1175
1176 last_output = outputs[-1]
/tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/gen_cudnn_rnn_ops.py in cudnn_rnn(input,input_h,input_c,params,rnn_mode,input_mode,direction,dropout,seed,seed2,is_training,name)
107 input_mode=input_mode,direction=direction,dropout=dropout,108 seed=seed,seed2=seed2,is_training=is_training,name=name,--> 109 ctx=_ctx)
110 except _core._SymbolicException:
111 pass # Add nodes to the TensorFlow graph.
/tensorflow-2.0.0/python3.6/tensorflow_core/python/ops/gen_cudnn_rnn_ops.py in cudnn_rnn_eager_fallback(input,name,ctx)
196 "is_training",is_training)
197 _result = _execute.execute(b"CudnnRNN",4,inputs=_inputs_flat,--> 198 attrs=_attrs,ctx=_ctx,name=name)
199 _execute.record_gradient(
200 "CudnnRNN",_inputs_flat,_attrs,_result,name)
/tensorflow-2.0.0/python3.6/tensorflow_core/python/eager/execute.py in quick_execute(op_name,num_outputs,attrs,ctx,name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code,message),None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value,from_value)
InvalidArgumentError: Invalid input_h shape: [1,64,1024] [1,54,1024] [Op:CudnnRNN]