由于 SELECT ValueError,继承 Airflow OracleHook 以便设置 cx_Oracle outputtypehandler

Inheriting Airflow OracleHook in order to set cx_Oracle outputtypehandler due to SELECT ValueError

我正在查询很旧的完整 table。它包含一些损坏的数据,我无权更改或创建视图。每当我 SELECT 数据使用 Airflow OracleHook get_records I get the error message "ValueError: year -4712 is out of range". I would like to handle this by returning None for this particular record as shown in this cx_Oracle solution: Problem empty date cause ValueError: year -9999 is out of range. This requires setting the cx_Oracle outputtypehandler 属性时。当我 运行 下面的代码时,OutputHandler 或 DateTimeConverter 函数都没有被调用,并且代码失败并出现与我使用基本 OracleHook class 时相同的 ValueError。任何帮助将不胜感激!

from airflow.hooks.oracle_hook import OracleHook
import cx_Oracle

from datetime import datetime
import os
os.environ['NLS_DATE_FORMAT'] = 'YYYY-MM-DD HH24:MI:SS'

class OracleHookTypeHandler(OracleHook):
    def __init__(self, oracle_conn_id):
        OracleHook.__init__(self, oracle_conn_id)
        self.cursor = OracleHook.get_cursor
        self.cursor.outputtypehandler = self.OutputHandler

    # Dealing with invalid years in the database
    def DateTimeConverter(self, value):
        print('DateTimeConverter was called')
        if value.startswith('4712'):
            return None
        return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

    def OutputHandler(self, cursor, name, defaulttype, length, precision, scale):
        print('OutputHandler was called')
        if defaulttype == cx_Oracle.DATETIME:
            return cursor.var(cx_Oracle.STRING, arraysize=cursor.arraysize,
                              outconverter=self.DateTimeConverter)

def extract(extract_connection)
    # Return the extracted records
    extract_records_query = 'SELECT col1, col2, col3 FROM table'
    o_extract_hook = OracleHookTypeHandler(oracle_conn_id=extract_connection)
    print('Extract started')
    extract_records = o_extract_hook.get_records(sql=extract_records_query)
    return extract_records

更新: 我使用下面@joebeeson 的回答解决了这个问题。工作代码:

from airflow.hooks.oracle_hook import OracleHook
import cx_Oracle
import sys

from datetime import datetime
from contextlib import closing
import os
os.environ['NLS_DATE_FORMAT'] = 'YYYY-MM-DD HH24:MI:SS'

class OracleHookTypeHandler(OracleHook):
    def __init__(self, oracle_conn_id):
        OracleHook.__init__(self, oracle_conn_id)

    # Override get_records from inherited class dbapihook
    def get_records(self, sql, parameters=None):
        """
        Executes the sql and returns a set of records.

        :param sql: the sql statement to be executed (str) or a list of
            sql statements to execute
        :type sql: str or list
        :param parameters: The parameters to render the SQL query with.
        :type parameters: mapping or iterable
        """
        if sys.version_info[0] < 3:
            sql = sql.encode('utf-8')

        with closing(self.get_conn()) as conn:
            with closing(conn.cursor()) as cur:
    
                cur.outputtypehandler = self.OutputHandler

                if parameters is not None:
                    cur.execute(sql, parameters)
                else:
                    cur.execute(sql)
                return cur.fetchall()

    # Dealing with invalid years in the database
    def DateTimeConverter(self, value):
        print('DateTimeConverter was called')
        if value.startswith('4712'):
            return None
        return datetime.strptime(value, '%Y-%m-%d %H:%M:%S')

    def OutputHandler(self, cursor, name, defaulttype, length, precision, scale):
        print('OutputHandler was called')
        if defaulttype == cx_Oracle.DATETIME:
            return cursor.var(cx_Oracle.STRING, arraysize=cursor.arraysize,
                              outconverter=self.DateTimeConverter)

def extract(extract_connection)
    # Return the extracted records
    extract_records_query = 'SELECT col1, col2, col3 FROM table'
    o_extract_hook = OracleHookTypeHandler(oracle_conn_id=extract_connection)
    print('Extract started')
    extract_records = o_extract_hook.get_records(sql=extract_records_query)
    return extract_records

您想覆盖 airflow.hooks.dbapi_hook.DbapiHook class 中的 get_records 方法;它不会调用 OracleHook.get_cursor,因此您的作业将无法工作:

def get_records(self, sql, parameters=None):
    """
    Executes the sql and returns a set of records.

    :param sql: the sql statement to be executed (str) or a list of
        sql statements to execute
    :type sql: str or list
    :param parameters: The parameters to render the SQL query with.
    :type parameters: mapping or iterable
    """
    if sys.version_info[0] < 3:
        sql = sql.encode('utf-8')

    with closing(self.get_conn()) as conn:
        with closing(conn.cursor()) as cur:
 
            # You have access to the `Cursor` (named "cur") object.

            if parameters is not None:
                cur.execute(sql, parameters)
            else:
                cur.execute(sql)
            return cur.fetchall()

虽然将您需要的代码部分提升到需要这些修改的流程文件中可能更清晰。