Blob Blame History Raw
From f46ac074d066884480098c14397f5f3e34475e11 Mon Sep 17 00:00:00 2001
From: Jiri Popelka <jpopelka@redhat.com>
Date: Fri, 22 Jan 2016 16:30:13 +0100
Subject: [PATCH] fix from launchpad #1510950

---
 base/password.py | 17 ++++++++++++++++-
 base/utils.py    | 53 ++++++++++++++++++++++++++++++++++-------------------
 2 files changed, 50 insertions(+), 20 deletions(-)

diff --git a/base/password.py b/base/password.py
index 3ca16ae..6caefdf 100644
--- a/base/password.py
+++ b/base/password.py
@@ -104,6 +104,7 @@ def get_distro_name():
 class Password(object):
     def __init__(self, Mode = INTERACTIVE_MODE):
         self.__password =""
+        self.__password_prompt_str=""
         self.__passwordValidated = False
         self.__mode = Mode
         self.__readAuthType()  #self.__authType   
@@ -201,7 +202,17 @@ class Password(object):
                     
                     cb = child.before
                     if cb:
-
+                        if('true' in cmd and self.__password_prompt_str == ""): #sudo true or su -c "true"
+                            cb = cb.replace("[", "\[")
+                            cb = cb.replace("]", "\]")
+                            self.__password_prompt_str = cb
+                            try:
+                                p = re.compile(cb, re.I)
+                            except TypeError:
+                                self.__expectList.append(cb)
+                            else:
+                                self.__expectList.append(p)
+                            
                         start = time.time()
                         output.write(cb)
 
@@ -355,3 +366,7 @@ class Password(object):
         self.__validatePassword( pswd_msg)
         return self.__password
 
+    def getPasswordPromptString(self):
+        return self.__password_prompt_str
+
+
diff --git a/base/utils.py b/base/utils.py
index f1ec1e1..5d108f0 100644
--- a/base/utils.py
+++ b/base/utils.py
@@ -108,22 +108,21 @@ MAJ_VER = sys.version_info[0]
 MIN_VER = sys.version_info[1]
 
 
-
 EXPECT_WORD_LIST = [
     pexpect.EOF, # 0
     pexpect.TIMEOUT, # 1
-    "Continue?", # 2 (for zypper)
-    "passwor[dt]:", # en/de/it/ru
-    "kennwort", # de?
-    "password for", # en
-    "mot de passe", # fr
-    "contraseña", # es
-    "palavra passe", # pt
-    "口令", # zh
-    "wachtwoord", # nl
-    "heslo", # czech
-    "密码",
-    "Lösenord", #sv
+    u"Continue?", # 2 (for zypper)
+    u"passwor[dt]:", # en/de/it/ru
+    u"kennwort", # de?
+    u"password for", # en
+    u"mot de passe", # fr
+    u"contraseña", # es
+    u"palavra passe", # pt
+    u"口令", # zh
+    u"wachtwoord", # nl
+    u"heslo", # czech
+    u"密码",
+    u"Lösenord", #sv
 ]
 
 
@@ -1260,6 +1259,15 @@ def run(cmd, passwordObj = None, pswd_msg='', log_output=True, spinner=True, tim
     import io
     output = io.StringIO()
 
+    pwd_prompt_str = ""
+    if passwordObj and ('su' in cmd or 'sudo' in cmd):
+        pwd_prompt_str = passwordObj.getPasswordPromptString()
+        log.debug("cmd = %s pwd_prompt_str = [%s]"%(cmd, pwd_prompt_str))
+        if(pwd_prompt_str == ""):
+            passwordObj.getPassword(pswd_msg, 0)
+            pwd_prompt_str = passwordObj.getPasswordPromptString()
+            log.debug("pwd_prompt_str2 = [%s]"%(pwd_prompt_str))
+
     try:
         child = pexpect.spawnu(cmd, timeout=timeout)
     except pexpect.ExceptionPexpect as e:
@@ -1277,15 +1285,22 @@ def run(cmd, passwordObj = None, pswd_msg='', log_output=True, spinner=True, tim
                 continue
 
             if child.before:
+                if(pwd_prompt_str and pwd_prompt_str not in EXPECT_LIST):
+                    log.debug("Adding %s to EXPECT LIST"%pwd_prompt_str)
+                    try:
+                        p = re.compile(pwd_prompt_str, re.I)
+                    except TypeError:
+                        EXPECT_LIST.append(pwd_prompt_str)
+                    else:
+                        EXPECT_LIST.append(p)
+                        EXPECT_LIST.append(pwd_prompt_str)
+
                 try:
                     output.write(child.before)
+                    if log_output:
+                        log.debug(child.before)
                 except Exception:
                     pass
-                if log_output:
-                    try:
-                        log.debug(child.before)
-                    except Exception:
-                        pass
 
             if i == 0: # EOF
                 break
@@ -2018,7 +2033,7 @@ def download_from_network(weburl, outputFile = None, useURLLIB=False):
 
         if useURLLIB:
 		
-            sys.stdout.write("Download in progress..........")
+            #sys.stdout.write("Download in progress..........")
             try:
                 response = urllib2_request.urlopen(weburl)    
                 file_fd = open(outputFile, 'wb')
-- 
2.5.0