2015-09-02 10 views
1

我正在爲python編寫一個函數來進行科學計算。這個函數的一個參數表示一個實值輸入參數。如果一個複雜的值作爲這個參數傳遞,函數的結果將會不正確,因爲我沒有實現複雜值輸入的情況下需要特別小心,但是函數會返回一個不正確的值而沒有錯誤或異常,因爲每個根據語法,函數中的行是有效的。如何確保在python中輸入參數不是複數值,但實值

例如,請考慮這樣的功能:

import numpy as np 
def foo(vara): 
    """ 
    This function evaluates the Foo formula for the real variable vara. 
    This function does not work for the complex variable vara because I am 
    too lazy to take care of the branch cut of the complex square-root function. 
    """ 
    if vara<0: 
     vv = -0.57386286*vara 
    else: 
     vv = 3.49604327*vara 
    return np.sqrt(vv) 

功能foo將返回即使參數vara是複雜的,因爲numpy.sqrt功能也是複雜的參數定義的複雜值,但返回的值假設函數foo僅在真正的論證中實現,那麼這將是不正確的。

如何檢查函數中的參數是否是實值,以便我可以使該函數拋出異常或以其他方式錯誤退出?

請不就是我想要保持功能,同時爲本地float型蟒蛇以及對float類型元素的數組numpy的工作。我只想禁止使用complex變量或complex元素的numpy數組。

(我想乘1.0j參數,並且檢查結果的實部爲零,但這並不顯得乾淨利落。)如果你僅想禁止複雜的數據類型,這將做

+0

只是要清楚,你是否想禁止複雜的論點,也沒有虛部? '2 + 0j'會被禁止嗎? – DSM

+0

我認爲最好禁止複雜的論點與他們的想象部分零。我想確保該函數的用戶知道該函數僅用於實際變量。如果用戶意識到這一點並且確定被傳遞的變量的虛部爲零,那麼他/她可以通過'.real'來取實部。 – norio

回答

1

絕招:

import types 

scalar_complex_types = [types.ComplexType, np.complex64, np.complex128] 

def is_complex_sequence(vara): 
    return (hasattr(vara, '__iter__') 
      and any(isinstance(v, t) for v in vara for t in complex_types) 

def is_complex_scalar(vara): 
    return any(isinstance(vara, t) for t in complex_types) 

然後在你的功能,你可以..

if is_complex_scalar(vara) or is_complex_sequence(vara): 
    raise ValueError('Argument must not be a complex number') 
+0

第一個會讓複數dtype數組通過,第二個會讓複數的0虛數部分的dtype數組通過,OP澄清的不應該通過。 – DSM

+0

從給出的函數我假設vara是一個標量,而不是一個數組,但你可以添加另一種類型到'complex_types'列表來解釋數組變量。 –

+0

對不起我的問題的文本中的矛盾。我說過我想允許numpy數組,但是在例子中我有'如果vara <0:'。這個例子不好。 – norio

0

(我回答我自己的問題,我不知道這是在b est的方式,但我想留下一個代碼,我試圖記錄。)

基於polpak的答案,我寫了下面的代碼。我想這會滿足我提出的條件。該函數是迂腐的,因爲它拒絕任何其他類型的輸入參數,而不是float scaling或float ndarray。 (也許它甚至不接受各種浮點ndarray。)特別是,它拒絕整數縮放器和整數ndarray以及複雜的縮放器和複雜的ndarray。

#!/usr/bin/python 

import numpy as np 
import types 

def foo(vara): 
    """vara must be a real-valued scaler or ndarray.""" 

    real_types = [types.FloatType, np.float16, np.float32, np.float64, np.float128] 
    print '----------' 
    print 'vara:', vara 
    if isinstance(vara, np.ndarray): 
     if not any(vara.dtype==t for t in real_types): 
      print 'NG.' 
      print ' type(vara)=', type(vara) 
      print ' vara.dtype=', vara.dtype 
      # raise an error here 
     else: 
      print 'OK.' 
      print ' type(vara)=', type(vara) 
      print ' vara.dtype=', vara.dtype 
    else: 
     if not any(isinstance(vara, t) for t in real_types): 
      print 'NG.' 
      print ' type(vara)=', type(vara) 
      # raise an error here 
     else: 
      print 'OK.' 
      print ' type(vara)=', type(vara) 


varalist=[3.0, 
      np.array([0.5, 0.2]), 
      np.array([3, 4, 1]), 
      np.array([3.4+1.2j, 0.8+0.7j]), 
      np.array([3.4+0.0j, 0.8+0.0j]), 
      np.array([1.3, 4.2, 5.9], dtype=complex), 
      np.array([1.3, 4.2, 5.9], dtype=complex).real ] 

for vara in varalist: 
    foo(vara) 

該代碼的輸出如下。

$ ./main003.py 
---------- 
vara: 3.0 
OK. 
    type(vara)= <type 'float'> 
---------- 
vara: [ 0.5 0.2] 
OK. 
    type(vara)= <type 'numpy.ndarray'> 
    vara.dtype= float64 
---------- 
vara: [3 4 1] 
NG. 
    type(vara)= <type 'numpy.ndarray'> 
    vara.dtype= int64 
---------- 
vara: [ 3.4+1.2j 0.8+0.7j] 
NG. 
    type(vara)= <type 'numpy.ndarray'> 
    vara.dtype= complex128 
---------- 
vara: [ 3.4+0.j 0.8+0.j] 
NG. 
    type(vara)= <type 'numpy.ndarray'> 
    vara.dtype= complex128 
---------- 
vara: [ 1.3+0.j 4.2+0.j 5.9+0.j] 
NG. 
    type(vara)= <type 'numpy.ndarray'> 
    vara.dtype= complex128 
---------- 
vara: [ 1.3 4.2 5.9] 
OK. 
    type(vara)= <type 'numpy.ndarray'> 
    vara.dtype= float64 
相關問題